Res-Unet: Combining the Residual Structure with U-net, with a Code Implementation

This post mainly provides a template for readers who want a working code implementation. The underlying theory has already been covered in plenty of articles, so I won't ramble here and will only give a brief introduction!

1. The residual structure

1.1 The residual unit
Compared with the plain serial structure of an ordinary network, a residual unit adds a skip mapping that adds the input directly to the output, restoring feature information lost during convolution. This is somewhat similar to U-net's skip connections, but there is an essential difference: the residual skip performs an Add operation, while U-net's skip connection performs a Concatenate operation. The residual unit is shown below:
(figure: residual unit)
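As a minimal sketch of the Add vs. Concatenate difference (plain numpy, independent of the network code below): Add keeps the channel count unchanged, while Concatenate stacks channels.

```python
import numpy as np

# Toy feature maps of shape (height, width, channels)
x = np.ones((4, 4, 3))        # input to the skip connection
f = np.full((4, 4, 3), 2.0)   # output of the convolutional branch

# Residual (ResNet) skip: element-wise Add -- channel count unchanged
res_out = x + f

# U-net skip: Concatenate along the channel axis -- channels stack up
unet_out = np.concatenate([x, f], axis=-1)

print(res_out.shape)   # (4, 4, 3)
print(unet_out.shape)  # (4, 4, 6)
```

Note that Add requires the two tensors to have identical shapes, which is why the Conv_block below projects the shortcut with a 1x1 convolution when channels change.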
1.2 The residual family
Depending on the number of layers, ResNet comes in the following family of variants:
(figure: ResNet variants, 18/34/50/101/152 layers)
As the table shows, the residual units of ResNet18 and ResNet34 have two layers, while those of ResNet50, ResNet101, and ResNet152 have three. This post uses ResNet50, so the discussion focuses on it.
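To see why the deeper variants switch to the three-layer "bottleneck" unit, here is a rough back-of-the-envelope comparison of convolution weight counts (my own illustrative helper functions; BN parameters and biases are ignored):

```python
def basic_block_params(c):
    # Two-layer unit (ResNet18/34): two 3x3 convs, c -> c channels each
    return 2 * (3 * 3 * c * c)

def bottleneck_params(c_in, c_mid):
    # Three-layer unit (ResNet50+): 1x1 reduce to c_mid,
    # 3x3 at c_mid, then 1x1 expand back to 4 * c_mid
    return (1 * 1 * c_in * c_mid) + (3 * 3 * c_mid * c_mid) + (1 * 1 * c_mid * 4 * c_mid)

# At 256 input channels the bottleneck is far cheaper:
print(basic_block_params(256))     # 1179648
print(bottleneck_params(256, 64))  # 69632
```

The 1x1 convolutions squeeze and re-expand the channels, so most of the 3x3 work happens at a quarter of the width; this is what makes 50+ layer networks affordable.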
1.3 ResNet50
ResNet50 is built from two basic blocks, Conv_block and Identity_block, shown below:
(figures: Identity_block and Conv_block)
Conv_block sits at the start of each stage to adjust the channel count and feature-map size, while Identity_block is used to deepen the network.

2. U-net

I'll skip this one, okay?
Go look it up yourself…

3. Res-Unet

The Res-Unet described here simply replaces the plain convolutions in U-net's feature-extraction (encoder) path with ResNet50, in order to deepen the network. Straight to the code!

import keras
from keras.models import *
from keras.layers import *
from keras import layers
import keras.backend as K

IMAGE_ORDERING = 'channels_last'


def identity_block(input_tensor, kernel_size, filters, stage, block):
    filters1, filters2, filters3 = filters

    if IMAGE_ORDERING == 'channels_last':
        bn_axis = 3
    else:
        bn_axis = 1

    conv_name_base = 'res' + str(stage) + block + '_branch'
    bn_name_base = 'bn' + str(stage) + block + '_branch'
    x = Conv2D(filters1, (1, 1), data_format=IMAGE_ORDERING, name=conv_name_base + '2a')(input_tensor)
    x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')(x)
    x = Activation('relu')(x)
    x = Conv2D(filters2, kernel_size, data_format=IMAGE_ORDERING,
               padding='same', name=conv_name_base + '2b')(x)
    x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')(x)
    x = Activation('relu')(x)
    x = Conv2D(filters3, (1, 1), data_format=IMAGE_ORDERING, name=conv_name_base + '2c')(x)
    x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')(x)
    x = layers.add([x, input_tensor])
    x = Activation('relu')(x)
    return x

def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2)):
    filters1, filters2, filters3 = filters

    if IMAGE_ORDERING == 'channels_last':
        bn_axis = 3
    else:
        bn_axis = 1

    conv_name_base = 'res' + str(stage) + block + '_branch'
    bn_name_base = 'bn' + str(stage) + block + '_branch'
    x = Conv2D(filters1, (1, 1), data_format=IMAGE_ORDERING, strides=strides,
               name=conv_name_base + '2a')(input_tensor)
    x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')(x)
    x = Activation('relu')(x)
    x = Conv2D(filters2, kernel_size, data_format=IMAGE_ORDERING, padding='same',
               name=conv_name_base + '2b')(x)
    x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')(x)
    x = Activation('relu')(x)
    x = Conv2D(filters3, (1, 1), data_format=IMAGE_ORDERING, name=conv_name_base + '2c')(x)
    x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')(x)
    shortcut = Conv2D(filters3, (1, 1), data_format=IMAGE_ORDERING, strides=strides,
                      name=conv_name_base + '1')(input_tensor)
    shortcut = BatchNormalization(axis=bn_axis, name=bn_name_base + '1')(shortcut)
    x = layers.add([x, shortcut])
    x = Activation('relu')(x)
    return x


def get_resnet50_encoder(input_height=256, input_width=256):
    # Height and width must be divisible by 32 (five 2x downsamplings).
    assert input_height % 32 == 0
    assert input_width % 32 == 0

    # Note: the original post uses a 6-channel input; change to 3 for RGB images.
    if IMAGE_ORDERING == 'channels_first':
        img_input = Input(shape=(6, input_height, input_width))
    else:
        img_input = Input(shape=(input_height, input_width, 6))

    if IMAGE_ORDERING == 'channels_last':
        bn_axis = 3
    else:
        bn_axis = 1

    x = Conv2D(64, (3, 3), activation='relu', padding='same', name='block1_conv1')(img_input)
    x = Conv2D(64, (3, 3), activation='relu', padding='same', name='block1_conv2')(x)
    f1 = x
    x = BatchNormalization(axis=bn_axis, name='bn_conv1')(x)
    x = Activation('relu')(x)
    x = MaxPooling2D((2, 2), data_format=IMAGE_ORDERING, strides=(2, 2))(x)
    x = conv_block(x, 3, [32, 32, 128], stage=2, block='a', strides=(1, 1))
    x = identity_block(x, 3, [32, 32, 128], stage=2, block='b')
    x = identity_block(x, 3, [32, 32, 128], stage=2, block='c')
    f2 = x

    x = conv_block(x, 3, [64, 64, 256], stage=3, block='a')
    x = identity_block(x, 3, [64, 64, 256], stage=3, block='b')
    x = identity_block(x, 3, [64, 64, 256], stage=3, block='c')
    x = identity_block(x, 3, [64, 64, 256], stage=3, block='d')
    f3 = x

    x = conv_block(x, 3, [128, 128, 512], stage=4, block='a')
    x = identity_block(x, 3, [128, 128, 512], stage=4, block='b')
    x = identity_block(x, 3, [128, 128, 512], stage=4, block='c')
    x = identity_block(x, 3, [128, 128, 512], stage=4, block='d')
    x = identity_block(x, 3, [128, 128, 512], stage=4, block='e')
    x = identity_block(x, 3, [128, 128, 512], stage=4, block='f')
    f4 = x

    x = conv_block(x, 3, [256, 256, 1024], stage=5, block='a')
    x = identity_block(x, 3, [256, 256, 1024], stage=5, block='b')
    x = identity_block(x, 3, [256, 256, 1024], stage=5, block='c')
    f5 = x


    return img_input, [f1, f2, f3, f4, f5]

def _unet(n_classes, encoder, input_height=256, input_width=256):

    img_input, levels = encoder(input_height=input_height, input_width=input_width)
    [f1, f2, f3, f4, f5] = levels

    o = f5
    o = UpSampling2D((2, 2), data_format=IMAGE_ORDERING)(o)
    # o = ZeroPadding2D((1, 1) , data_format=IMAGE_ORDERING)(o)
    o = Conv2D(512, (2, 2), padding='same', data_format=IMAGE_ORDERING)(o)
    o = Activation('relu')(o)
    o = BatchNormalization()(o)
    o = concatenate([o,f4],axis=3)
    o = ZeroPadding2D((1, 1) , data_format=IMAGE_ORDERING)(o)
    o = Conv2D(512, (3, 3), padding='valid', data_format=IMAGE_ORDERING)(o)
    o = Activation('relu')(o)
    o = BatchNormalization()(o)
    o = ZeroPadding2D((1, 1) , data_format=IMAGE_ORDERING)(o)
    o = Conv2D(512, (3, 3), padding='valid', data_format=IMAGE_ORDERING)(o)
    o = Activation('relu')(o)
    o = BatchNormalization()(o)

    o = UpSampling2D((2, 2), data_format=IMAGE_ORDERING)(o)
    # o = ZeroPadding2D((1, 1) , data_format=IMAGE_ORDERING)(o)
    o = Conv2D(256, (2, 2), padding='same', data_format=IMAGE_ORDERING)(o)
    o = Activation('relu')(o)
    o = BatchNormalization()(o)
    o = concatenate([o,f3],axis=3)
    o = ZeroPadding2D((1, 1) , data_format=IMAGE_ORDERING)(o)
    o = Conv2D(256, (3, 3), padding='valid', data_format=IMAGE_ORDERING)(o)
    o = Activation('relu')(o)
    o = BatchNormalization()(o)
    o = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(o)
    o = Conv2D(256, (3, 3), padding='valid', data_format=IMAGE_ORDERING)(o)
    o = Activation('relu')(o)
    o = BatchNormalization()(o)

    o = UpSampling2D((2, 2), data_format=IMAGE_ORDERING)(o)
    # o = ZeroPadding2D((1, 1) , data_format=IMAGE_ORDERING)(o)
    o = Conv2D(128, (2, 2), padding='same', data_format=IMAGE_ORDERING)(o)
    o = Activation('relu')(o)
    o = BatchNormalization()(o)
    o = concatenate([o,f2],axis=3)
    o = ZeroPadding2D((1, 1) , data_format=IMAGE_ORDERING)(o)
    o = Conv2D(128, (3, 3), padding='valid', data_format=IMAGE_ORDERING)(o)
    o = Activation('relu')(o)
    o = BatchNormalization()(o)
    o = ZeroPadding2D((1, 1) , data_format=IMAGE_ORDERING)(o)
    o = Conv2D(128, (3, 3), padding='valid', data_format=IMAGE_ORDERING)(o)
    o = Activation('relu')(o)
    o = BatchNormalization()(o)

    o = UpSampling2D((2, 2), data_format=IMAGE_ORDERING)(o)
    # o = ZeroPadding2D((1, 1) , data_format=IMAGE_ORDERING)(o)
    o = Conv2D(64, (2, 2), padding='same', data_format=IMAGE_ORDERING)(o)
    o = Activation('relu')(o)
    o = BatchNormalization()(o)
    o = concatenate([o,f1],axis=3)
    o = ZeroPadding2D((1, 1) , data_format=IMAGE_ORDERING)(o)
    o = Conv2D(64, (3, 3), padding='valid', data_format=IMAGE_ORDERING)(o)
    o = Activation('relu')(o)
    o = BatchNormalization()(o)
    o = ZeroPadding2D((1, 1) , data_format=IMAGE_ORDERING)(o)
    o = Conv2D(64, (3, 3), padding='valid', data_format=IMAGE_ORDERING)(o)
    o = Activation('relu')(o)
    o = BatchNormalization()(o)

    o = Conv2D(n_classes, (3, 3), padding='same', data_format=IMAGE_ORDERING)(o)
    # Flatten the spatial dims so softmax is applied per pixel over the classes.
    o = Reshape((int(input_height) * int(input_width), -1))(o)
    o = Softmax()(o)
    model = Model(img_input,o)

    return model

def conv_unet(n_classes, input_height=256, input_width=256):
    model = _unet(n_classes, get_resnet50_encoder, input_height=input_height, input_width=input_width)
    model.model_name = 'conv_unet'
    return model

model = conv_unet(2)
model.summary()



What? You're not using Keras?
What? You're using PyTorch?
