# Source code for deepblink.networks.unet

"""UNet architecture."""

import math

import tensorflow as tf

from ._networks import OPTIONS_CONV
from ._networks import conv_block
from ._networks import inception_block
from ._networks import residual_block
from ._networks import squeeze_block
from ._networks import upconv_block

def __block(inputs, filters, block, l2):
    """Apply one feature-extraction block followed by a squeeze block.

    Args:
        inputs: Input tensor of the block.
        filters: Number of filters passed to the selected block.
        block: Block type, one of "convolutional", "inception" or "residual".
        l2: L2 factor for kernel and bias regularization; a falsy value
            disables regularization.

    Returns:
        Output tensor of ``squeeze_block`` applied to the selected block.

    Raises:
        ValueError: If ``block`` is not a known block type.
    """
    # Work on a copy: assigning into OPTIONS_CONV directly would mutate the
    # shared module-level defaults and leak this call's regularizers into
    # every other user of OPTIONS_CONV.
    opts_conv = dict(OPTIONS_CONV)
    opts_conv["kernel_regularizer"] = tf.keras.regularizers.l2(l2) if l2 else None
    opts_conv["bias_regularizer"] = tf.keras.regularizers.l2(l2) if l2 else None

    if block == "convolutional":
        x = conv_block(inputs=inputs, filters=filters, n_convs=3, opts_conv=opts_conv)
    elif block == "inception":
        # The inception block takes the raw l2 factor instead of opts_conv.
        x = inception_block(inputs=inputs, filters=filters, l2_regularizer=l2)
    elif block == "residual":
        x = residual_block(inputs=inputs, filters=filters, opts_conv=opts_conv)
    else:
        # Previously an unknown type fell through and crashed with an
        # UnboundLocalError on `x`; fail with an explicit message instead.
        raise ValueError(
            "block must be one of 'convolutional', 'inception' or 'residual', "
            f"but is {block!r}."
        )
    return squeeze_block(x=x)

def __encoder(inputs, filters, block, l2, dropout):
    """Run one encoder stage: block, spatial dropout, then 2x2 max pooling.

    Args:
        inputs: Input tensor of the stage.
        filters: Number of filters for the block.
        block: Block type forwarded to ``__block``.
        l2: L2 regularization factor forwarded to ``__block``.
        dropout: Rate passed to ``SpatialDropout2D``.

    Returns:
        Tuple ``(pooled, skip)`` where ``skip`` is the block output after
        spatial dropout and ``pooled`` is ``skip`` max-pooled with a 2x2 window.
    """
    features = __block(inputs, filters, block, l2)
    spatial_dropout = tf.keras.layers.SpatialDropout2D(dropout)
    skip = spatial_dropout(features)
    pooling = tf.keras.layers.MaxPool2D(pool_size=(2, 2))
    pooled = pooling(skip)
    return pooled, skip

def __decoder(inputs, skip, filters, block, l2):
    """Run one decoder stage: a block followed by ``upconv_block`` with ``skip``.

    Args:
        inputs: Input tensor of the stage.
        skip: Skip-connection tensor handed to ``upconv_block``.
        filters: Number of filters for the block.
        block: Block type forwarded to ``__block``.
        l2: L2 regularization factor forwarded to ``__block``.

    Returns:
        Output tensor of ``upconv_block`` (presumably upsampled and merged
        with ``skip`` -- confirm in ``_networks``).
    """
    return upconv_block(inputs=__block(inputs, filters, block, l2), skip=skip)

def unet(
    dropout: float = 0.2,
    cell_size: int = 4,
    filters: int = 5,
    ndown: int = 2,
    l2: float = 1e-6,
    block: str = "convolutional",
) -> tf.keras.models.Model:
    """Unet model with second, cell size dependent encoder.

    Note that "convolutional" is the currently best block.

    Arguments:
        dropout: Fraction (0-1) of feature maps dropped by spatial dropout
            before each MaxPooling step.
        cell_size: Size of one cell in the prediction matrix. Must be a
            positive power of 2.
        filters: Log_2 number of filters in the first block.
        ndown: Downsampling steps in the first encoder / decoder.
        l2: L2 value for kernel and bias regularization.
        block: Type of block in each layer.
            [options: convolutional, inception, residual]

    Returns:
        Keras model taking a single-channel image of arbitrary height/width
        and producing a 3-channel sigmoid output.

    Raises:
        ValueError: If ``cell_size`` is not a power of 2 or ``block`` is not
            one of the supported block types.
    """
    # Exact integer test instead of math.log(cell_size, 2).is_integer():
    # float rounding in log can land just off an integer for valid powers
    # of two, and non-positive values raised a bare "math domain error".
    is_power_of_two = (
        cell_size >= 1
        and int(cell_size) == cell_size
        and (int(cell_size) & (int(cell_size) - 1)) == 0
    )
    if not is_power_of_two:
        raise ValueError(f"cell_size must be a power of 2, but is {cell_size}.")
    # For a power of two, bit_length() - 1 is exactly log2(cell_size).
    ndown_cell = int(cell_size).bit_length() - 1

    # Fail before any layers are built rather than deep inside the graph.
    if block not in ("convolutional", "inception", "residual"):
        raise ValueError(
            "block must be one of 'convolutional', 'inception' or 'residual', "
            f"but is {block!r}."
        )

    # Input
    inputs = tf.keras.layers.Input(shape=(None, None, 1))
    x = inputs

    # Encoder v1: keep each stage's dropped-out, pre-pooling output as a skip.
    skip_layers = []
    for n in range(ndown):
        x, skip = __encoder(x, 2 ** (filters + n), block, l2, dropout)
        skip_layers.append(skip)
    skip_bottom = x

    # Decoder: consume skips deepest-first.
    for n, skip in enumerate(reversed(skip_layers)):
        x = __decoder(x, skip, 2 ** (filters + (ndown - n)), block, l2)

    # Encoder v2: log2(cell_size) pooling stages.
    for n in range(ndown_cell):
        x, _ = __encoder(x, 2 ** (filters + n), block, l2, dropout)

    # Logit
    # In this configuration both skip_bottom and x have passed through two
    # 2x2 max-pooling stages. NOTE(review): ndown == ndown_cell looks like the
    # general condition, but changing it would alter existing architectures.
    if ndown == 2 and cell_size == 4:
        x = tf.keras.layers.Concatenate()([skip_bottom, x])
    x = __block(x, 2 ** (filters + ndown_cell), block, l2)
    x = tf.keras.layers.Conv2D(filters=3, kernel_size=1, strides=1)(x)
    x = tf.keras.layers.Activation("sigmoid")(x)

    model = tf.keras.Model(inputs=inputs, outputs=x)
    return model