Improving 3D/2D U-Net: Adding Deep Supervision [Keras]

Date: 2024-10-23 07:42:34

Introduction

Deep supervision (also called intermediate supervision) simply adds extra losses at intermediate layers of the network. It differs from multi-task learning: multi-task learning computes different losses against different ground truths (GTs), whereas in deep supervision every loss is computed against the same GT, and the losses from the different positions are summed with weighting coefficients.
The purpose of deep supervision is to let the shallow layers receive more adequate training and to avoid vanishing gradients (note: techniques such as ReLU and BatchNorm seem to have largely solved vanishing gradients, so "avoiding vanishing gradients" is debatable, but deep supervision does help training).
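As a minimal sketch (all names below are hypothetical, not from any of the referenced code), the total loss is a weighted sum of per-branch losses, each computed against the same GT:

# Hypothetical sketch: every auxiliary head is compared against the SAME ground
# truth, and the branch losses are combined with fixed coefficients.
def deep_supervision_loss(gt, predictions, weights, base_loss):
    # predictions: list of side outputs; weights: one coefficient per output
    return sum(w * base_loss(gt, p) for w, p in zip(weights, predictions))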
CPM (Convolutional Pose Machines) [2] is the classic example of intermediate (deep) supervision. CPM addresses human pose estimation in 4 stages; each stage is supervised during training, so that the keypoints of the final pose estimate are refined as far as possible.
[Figure: CPM network architecture]
The figure below, taken from the paper, marks with arrows the maps optimized at each stage. Note: the same GT is used to supervise the maps of every stage.
[Figure: per-stage supervised maps in CPM, from [2]]

3D U-Net with deep supervision

The network architecture diagram, taken from [1], is shown below:
[Figure: 3D U-Net with multi-level deep supervision, from [1]]
The red boxes mark the two intermediate supervision branches. The network downsamples three times and upsamples three times, with intermediate supervision applied during upsampling.
The way the network code in [3] implements deep supervision is illustrated below:
[Figure: deep-supervision structure implemented in the code of [3]]
The implementation code (TensorFlow) from [3]:

def unet3d(inputs):
    depth = config.DEPTH
    filters = []
    down_list = []
    deep_supervision = None
    layer = tf.layers.conv3d(inputs=inputs, 
                   filters=BASE_FILTER,
                   kernel_size=(3,3,3),
                   strides=1,
                   padding=PADDING,
                   activation=lambda x, name=None: BN_Relu(x),
                   data_format=DATA_FORMAT,
                   name="init_conv")  
    for d in range(depth):
        if config.FILTER_GROW:
            num_filters = BASE_FILTER * (2**d)
        else:
            num_filters = BASE_FILTER
        filters.append(num_filters)
        layer = Unet3dBlock('down{}'.format(d), layer, kernels=(3,3,3), n_feat=num_filters, s=1)
        down_list.append(layer)
        if d != depth - 1:
            layer = tf.layers.conv3d(inputs=layer, 
                                    filters=num_filters*2,
                                    kernel_size=(3,3,3),
                                    strides=(2,2,2),
                                    padding=PADDING,
                                    activation=lambda x, name=None: BN_Relu(x),
                                    data_format=DATA_FORMAT,
                                    name="stride2conv{}".format(d))
        print("1 layer", layer.shape)
    for d in range(depth-2, -1, -1):
        layer = UnetUpsample(d, layer, filters[d])
        if DATA_FORMAT == 'channels_first':
            layer = tf.concat([layer, down_list[d]], axis=1)
        else:
            layer = tf.concat([layer, down_list[d]], axis=-1)
        #layer = Unet3dBlock('up{}'.format(d), layer, kernels=(3,3,3), n_feat=filters[d], s=1)
        layer = tf.layers.conv3d(inputs=layer, 
                                filters=filters[d],
                                kernel_size=(3,3,3),
                                strides=1,
                                padding=PADDING,
                                activation=lambda x, name=None: BN_Relu(x),
                                data_format=DATA_FORMAT,
                                name="lo_conv0_{}".format(d))
        layer = tf.layers.conv3d(inputs=layer, 
                                filters=filters[d],
                                kernel_size=(1,1,1),
                                strides=1,
                                padding=PADDING,
                                activation=lambda x, name=None: BN_Relu(x),
                                data_format=DATA_FORMAT,
                                name="lo_conv1_{}".format(d))
        if config.DEEP_SUPERVISION:
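            # side outputs at the intermediate decoder levels (0 < d < 3):
            # accumulate them and upsample the running sum by 2x at each level,
            # all supervised by the same GT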
            if d < 3 and d > 0:
                pred = tf.layers.conv3d(inputs=layer, 
                                    filters=config.NUM_CLASS,
                                    kernel_size=(1,1,1),
                                    strides=1,
                                    padding=PADDING,
                                    activation=tf.identity,
                                    data_format=DATA_FORMAT,
                                    name="deep_super_{}".format(d))
                if deep_supervision is None:
                    deep_supervision = pred
                else:
                    deep_supervision = deep_supervision + pred
                deep_supervision = Upsample3D(d, deep_supervision)
                
    layer = tf.layers.conv3d(layer, 
                            filters=config.NUM_CLASS,
                            kernel_size=(1,1,1),
                            padding="SAME",
                            activation=tf.identity,
                            data_format=DATA_FORMAT,
                            name="final")
    if config.DEEP_SUPERVISION:
        # fuse the accumulated side outputs into the final prediction
        layer = layer + deep_supervision
    if DATA_FORMAT == 'channels_first':
        layer = tf.transpose(layer, [0, 2, 3, 4, 1]) # to channels-last
    print("final", layer.shape) # [batch, d, h, w, num_class]
    return layer

def Upsample3D(prefix, l, scale=2):
    # use the scale argument instead of a hard-coded (2,2,2)
    l = tf.keras.layers.UpSampling3D(size=(scale,) * 3, data_format=DATA_FORMAT)(l)
    return l

def UnetUpsample(prefix, l, num_filters):
    l = Upsample3D('', l)
    l = tf.layers.conv3d(inputs=l, 
                        filters=num_filters,
                        kernel_size=(3,3,3),
                        strides=1,
                        padding=PADDING,
                        activation=lambda x, name=None: BN_Relu(x),
                        data_format=DATA_FORMAT,
                        name="up_conv1_{}".format(prefix))
    return l

def BN_Relu(x):
    if config.INSTANCE_NORM:
        l = InstanceNorm5d('ins_norm', x, data_format=DATA_FORMAT)
    else:
        l = BatchNorm3d('bn', x, axis=1 if DATA_FORMAT == 'channels_first' else -1)
    l = tf.nn.relu(l)
    return l

def Unet3dBlock(prefix, l, kernels, n_feat, s):
    if config.RESIDUAL:
        l_in = l

    for i in range(2):
        l = tf.layers.conv3d(inputs=l, 
                   filters=n_feat,
                   kernel_size=kernels,
                   strides=1,
                   padding=PADDING,
                   activation=lambda x, name=None: BN_Relu(x),
                   data_format=DATA_FORMAT,
                   name="{}_conv_{}".format(prefix, i))

    return l_in + l if config.RESIDUAL else l

Code with TensorFlow

For the 3D code, see [3].
For the 2D code, see [4].

Thoughts

There seem to be two forms of deep supervision at present:
The first form is the one shown in the first figure above. The second form is slightly different, as shown below:
[Figure: second form of deep supervision (hypercolumn with a single loss)]
The maps of the 4 stages are concatenated, then a convolution produces a single map, and the loss is computed between that map and the GT; in other words, there is only one loss in the end. The code is fairly simple. Taking 2D as an example, it is adapted from [4] with minor modifications and uses Keras.

from keras.models import Model
from keras.layers import (Input, Conv2D, Conv2DTranspose, MaxPooling2D, Dropout,
                          BatchNormalization, Activation, Add, Lambda, concatenate)
from keras import optimizers
from keras.backend import tf as ktf  # TF backend, used for resize_images
def BatchActivate(x):
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    return x
def convolution_block(x, filters, size, strides=(1,1), padding='same', activation=True):
    x = Conv2D(filters, size, strides=strides, padding=padding)(x)
    if activation == True:
        x = BatchActivate(x)
    return x
def residual_block(blockInput, num_filters=16, batch_activate = False):
    x = BatchActivate(blockInput)
    x = convolution_block(x, num_filters, (3,3) )
    x = convolution_block(x, num_filters, (3,3), activation=False)
    x = Add()([x, blockInput])
    if batch_activate:
        x = BatchActivate(x)
    return x
# Build model
def build_model(input_layer, lr, start_neurons, DropoutRatio=0.5):
    # 101 -> 50
    conv1 = Conv2D(start_neurons * 1, (3, 3), activation=None, padding="same")(input_layer)
    conv1 = residual_block(conv1,start_neurons * 1)
    conv1 = residual_block(conv1,start_neurons * 1, True)
    pool1 = MaxPooling2D((2, 2))(conv1)
    pool1 = Dropout(DropoutRatio/2)(pool1)
    # 50 -> 25
    conv2 = Conv2D(start_neurons * 2, (3, 3), activation=None, padding="same")(pool1)
    conv2 = residual_block(conv2,start_neurons * 2)
    conv2 = residual_block(conv2,start_neurons * 2, True)
    pool2 = MaxPooling2D((2, 2))(conv2)
    pool2 = Dropout(DropoutRatio)(pool2)
    # 25 -> 12
    conv3 = Conv2D(start_neurons * 4, (3, 3), activation=None, padding="same")(pool2)
    conv3 = residual_block(conv3,start_neurons * 4)
    conv3 = residual_block(conv3,start_neurons * 4, True)
    pool3 = MaxPooling2D((2, 2))(conv3)
    pool3 = Dropout(DropoutRatio)(pool3)
    # 12 -> 6
    conv4 = Conv2D(start_neurons * 8, (3, 3), activation=None, padding="same")(pool3)
    conv4 = residual_block(conv4,start_neurons * 8)
    conv4 = residual_block(conv4,start_neurons * 8, True)
    pool4 = MaxPooling2D((2, 2))(conv4)
    pool4 = Dropout(DropoutRatio)(pool4)
    # Middle
    convm = Conv2D(start_neurons * 16, (3, 3), activation=None, padding="same")(pool4)
    convm = residual_block(convm,start_neurons * 16)
    convm = residual_block(convm,start_neurons * 16, True)
    # 6 -> 12
    deconv4 = Conv2DTranspose(start_neurons * 8, (3, 3), strides=(2, 2), padding="same")(convm)
    uconv4 = concatenate([deconv4, conv4])
    uconv4 = Dropout(DropoutRatio)(uconv4)    
    uconv4 = Conv2D(start_neurons * 8, (3, 3), activation=None, padding="same")(uconv4)
    uconv4 = residual_block(uconv4,start_neurons * 8)
    uconv4 = residual_block(uconv4,start_neurons * 8, True)    
    # 12 -> 25
    deconv3 = Conv2DTranspose(start_neurons * 4, (3, 3), strides=(2, 2), padding="same")(uconv4)
    uconv3 = concatenate([deconv3, conv3])    
    uconv3 = Dropout(DropoutRatio)(uconv3)    
    uconv3 = Conv2D(start_neurons * 4, (3, 3), activation=None, padding="same")(uconv3)
    uconv3 = residual_block(uconv3,start_neurons * 4)
    uconv3 = residual_block(uconv3,start_neurons * 4, True)
    # 25 -> 50
    deconv2 = Conv2DTranspose(start_neurons * 2, (3, 3), strides=(2, 2), padding="same")(uconv3)
    uconv2 = concatenate([deconv2, conv2])        
    uconv2 = Dropout(DropoutRatio)(uconv2)
    uconv2 = Conv2D(start_neurons * 2, (3, 3), activation=None, padding="same")(uconv2)
    uconv2 = residual_block(uconv2,start_neurons * 2)
    uconv2 = residual_block(uconv2,start_neurons * 2, True)    
    # 50 -> 101
    deconv1 = Conv2DTranspose(start_neurons * 1, (3, 3), strides=(2, 2), padding="same")(uconv2)
    uconv1 = concatenate([deconv1, conv1])    
    uconv1 = Dropout(DropoutRatio)(uconv1)
    uconv1 = Conv2D(start_neurons * 1, (3, 3), activation=None, padding="same")(uconv1)
    uconv1 = residual_block(uconv1,start_neurons * 1)
    uconv1 = residual_block(uconv1,start_neurons * 1, True)
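    # hypercolumn: resize the decoder feature maps of all four levels to full
    # resolution and concatenate them; a single map (and thus a single loss) is
    # derived from the concatenation (img_size_target is a global from [4])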
    hypercolumn = concatenate(
        [
            uconv1,
            Lambda(lambda image: ktf.image.resize_images(image, (img_size_target, img_size_target)))(uconv2),
            Lambda(lambda image: ktf.image.resize_images(image, (img_size_target, img_size_target)))(uconv3),
            Lambda(lambda image: ktf.image.resize_images(image, (img_size_target, img_size_target)))(uconv4)
        ]
    )
    hypercolumn = Dropout(0.5)(hypercolumn)
    hypercolumn = Conv2D(start_neurons * 1, (3, 3), padding="same", activation='relu')(hypercolumn)
    output_layer_noActi = Conv2D(1, (1,1), padding="same", activation=None)(hypercolumn)
    output_layer =  Activation('sigmoid', name='seg_output')(output_layer_noActi)  
    # a single segmentation output -> a single loss
    model = Model(inputs=input_layer, outputs=output_layer)
    c = optimizers.Adam(lr=lr)
    # bce_dice_loss and my_iou_metric are user-defined, see [4]
    model.compile(loss=bce_dice_loss, optimizer=c, metrics=[my_iou_metric])
    return model

Training this model is the same as training an ordinary segmentation network; the inputs (images and masks) need no modification.
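A minimal usage sketch, assuming single-channel inputs; img_size_target, x_train, and y_train are hypothetical placeholders for the target resolution and the training images/masks:

# Hypothetical data: x_train are images, y_train the corresponding masks
input_layer = Input((img_size_target, img_size_target, 1))
model = build_model(input_layer, lr=1e-3, start_neurons=16)
model.fit(x_train, y_train, batch_size=32, epochs=50, validation_split=0.1)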

For the first form, again taking 2D as an example: 4 new losses are added, and each loss can be given a different weight. With 4 losses there must be 4 outputs, so the training code also has to be modified; a sketch follows below.
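A minimal sketch of this multi-output variant, reusing uconv1-uconv4 from build_model above; the side-output names and the loss weights are illustrative assumptions, not from the original code:

# Four side outputs, one per decoder level; each is reduced to a single channel,
# resized to full resolution, and supervised with the SAME ground-truth mask.
side_outputs = []
for i, feat in enumerate([uconv1, uconv2, uconv3, uconv4], start=1):
    s = Conv2D(1, (1, 1), padding="same", activation=None)(feat)
    s = Lambda(lambda img: ktf.image.resize_images(img, (img_size_target, img_size_target)))(s)
    s = Activation('sigmoid', name='side_output_{}'.format(i))(s)
    side_outputs.append(s)

model = Model(inputs=input_layer, outputs=side_outputs)
# One loss per output, all computed against the same mask; the weights below are
# illustrative and can be tuned (deeper outputs usually get smaller weights).
model.compile(loss=bce_dice_loss,
              optimizer=optimizers.Adam(lr=1e-3),
              loss_weights=[1.0, 0.5, 0.3, 0.2])
# Training passes the same mask once per output:
# model.fit(x_train, [y_train] * 4, ...)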

References

  1. 3D U-net with Multi-level Deep Supervision: Fully Automatic Segmentation of Proximal Femur in 3D MR Images
  2. Convolutional Pose Machines
  3. 3DUnet-Tensorflow-Brats18
  4. unet-resnetblock-hypercolumn-deep-supervision-fold (Kaggle kernel)