Res-Unet with Residual Blocks, and Its Code Implementation

This post walks through how to build a Res-Unet: ResNet50 replaces the plain convolutions in U-net's feature-extraction path, deepening the network and strengthening its feature-extraction ability. A complete code implementation is included.


This post is mainly a template for readers who want a working implementation. The underlying theory is already covered in plenty of articles, so I won't repeat it here; only a brief overview follows.

1. The residual structure

1.1 The residual unit
Compared with the plain serial structure of an ordinary network, a residual unit adds a skip mapping that adds the input directly to the output, replenishing feature information lost during convolution. This resembles U-net's skip connections, but with an essential difference: the residual skip performs an Add, while U-net's skip performs a Concatenate. The residual unit is shown below:
(figure: the residual unit)
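The Add-vs-Concatenate distinction shows up directly in the resulting tensor shapes. A minimal numpy sketch with toy feature maps (illustration only, not part of the network code below):

```python
import numpy as np

# Two toy (H, W, C) feature maps standing in for a block's input and output.
x = np.ones((8, 8, 64))        # block input
f = 2 * np.ones((8, 8, 64))    # block output before the skip

res_skip = x + f                              # ResNet skip: element-wise Add
unet_skip = np.concatenate([x, f], axis=-1)   # U-net skip: Concatenate on channels

print(res_skip.shape)    # (8, 8, 64)  -- channel count unchanged
print(unet_skip.shape)   # (8, 8, 128) -- channel count doubled
```

The Add requires both tensors to have identical shapes, which is exactly why ResNet's Conv_block projects the shortcut when channels change.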
1.2 The ResNet family
Depending on depth, ResNet comes in the following family of variants:
(figure: the ResNet variant table)
As the table shows, the residual units in ResNet18 and ResNet34 have two layers, while those in ResNet50, ResNet101, and ResNet152 have three. This post uses ResNet50, so the discussion focuses on it.
1.3 ResNet50
ResNet50 is built from two basic blocks: Conv_block and Identity_block, shown below:
(figures: Identity_block and Conv_block)
Conv_block sits at the first position of each stage to adjust the channel count and feature-map size; Identity_block deepens the network.
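Because the Conv_block at the head of a stage uses stride 2, each stage halves the feature map, so the spatial sizes through the encoder are easy to predict. A back-of-the-envelope sketch (`stage_sizes` is a hypothetical helper; it assumes one 2x2 pooling plus three stride-2 Conv_blocks, matching the encoder code later in this post):

```python
# Predict the side length of each encoder stage's output for a square input.
def stage_sizes(size, n_halvings=4):
    sizes = [size]               # f1: full resolution
    for _ in range(n_halvings):  # one MaxPooling2D + three stride-2 Conv_blocks
        size //= 2
        sizes.append(size)
    return sizes

print(stage_sizes(256))  # [256, 128, 64, 32, 16] -> f1, f2, f3, f4, f5
```

This is also why `get_resnet50_encoder` asserts that the input height and width are divisible by 32.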

2. U-net

I'll skip this one, okay?
Go look it up yourself…

3. Res-Unet

The Res-Unet in this post simply replaces the plain convolutions of U-net's feature-extraction path with ResNet50, deepening the network. Straight to the code!

import keras
from keras.models import *
from keras.layers import *
from keras import layers
import keras.backend as K

IMAGE_ORDERING = 'channels_last'


def identity_block(input_tensor, kernel_size, filters, stage, block):
    filters1, filters2, filters3 = filters

    if IMAGE_ORDERING == 'channels_last':
        bn_axis = 3
    else:
        bn_axis = 1

    conv_name_base = 'res' + str(stage) + block + '_branch'
    bn_name_base = 'bn' + str(stage) + block + '_branch'
    x = Conv2D(filters1, (1, 1), data_format=IMAGE_ORDERING, name=conv_name_base + '2a')(input_tensor)
    x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')(x)
    x = Activation('relu')(x)
    x = Conv2D(filters2, kernel_size, data_format=IMAGE_ORDERING,
               padding='same', name=conv_name_base + '2b')(x)
    x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')(x)
    x = Activation('relu')(x)
    x = Conv2D(filters3, (1, 1), data_format=IMAGE_ORDERING, name=conv_name_base + '2c')(x)
    x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')(x)
    x = layers.add([x, input_tensor])
    x = Activation('relu')(x)
    return x

def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2)):
    filters1, filters2, filters3 = filters

    if IMAGE_ORDERING == 'channels_last':
        bn_axis = 3
    else:
        bn_axis = 1

    conv_name_base = 'res' + str(stage) + block + '_branch'
    bn_name_base = 'bn' + str(stage) + block + '_branch'
    x = Conv2D(filters1, (1, 1), data_format=IMAGE_ORDERING, strides=strides,
               name=conv_name_base + '2a')(input_tensor)
    x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')(x)
    x = Activation('relu')(x)
    x = Conv2D(filters2, kernel_size, data_format=IMAGE_ORDERING, padding='same',
               name=conv_name_base + '2b')(x)
    x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')(x)
    x = Activation('relu')(x)
    x = Conv2D(filters3, (1, 1), data_format=IMAGE_ORDERING, name=conv_name_base + '2c')(x)
    x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')(x)
    shortcut = Conv2D(filters3, (1, 1), data_format=IMAGE_ORDERING, strides=strides,
                      name=conv_name_base + '1')(input_tensor)
    shortcut = BatchNormalization(axis=bn_axis, name=bn_name_base + '1')(shortcut)
    x = layers.add([x, shortcut])
    x = Activation('relu')(x)
    return x


def get_resnet50_encoder(input_height=256, input_width=256, pretrained='imagenet',
                         include_top=True, weights='imagenet',
                         input_tensor=None, input_shape=None,
                         pooling=None,
                         classes=1000):
    assert input_height % 32 == 0
    assert input_width % 32 == 0

    # Note: the author's data has 6 input channels; change 6 to 3 for RGB images.
    if IMAGE_ORDERING == 'channels_first':
        img_input = Input(shape=(6, input_height, input_width))
    elif IMAGE_ORDERING == 'channels_last':
        img_input = Input(shape=(input_height, input_width, 6))

    if IMAGE_ORDERING == 'channels_last':
        bn_axis = 3
    else:
        bn_axis = 1

    x = Conv2D(64, (3, 3), activation='relu', padding='same', name='block1_conv1')(img_input)
    x = Conv2D(64, (3, 3), activation='relu', padding='same', name='block1_conv2')(x)
    f1 = x
    x = BatchNormalization(axis=bn_axis, name='bn_conv1')(x)
    x = Activation('relu')(x)
    x = MaxPooling2D((2, 2), data_format=IMAGE_ORDERING, strides=(2, 2))(x)
    x = conv_block(x, 3, [32, 32, 128], stage=2, block='a', strides=(1, 1))
    x = identity_block(x, 3, [32, 32, 128], stage=2, block='b')
    x = identity_block(x, 3, [32, 32, 128], stage=2, block='c')
    f2 = x

    x = conv_block(x, 3, [64, 64, 256], stage=3, block='a')
    x = identity_block(x, 3, [64, 64, 256], stage=3, block='b')
    x = identity_block(x, 3, [64, 64, 256], stage=3, block='c')
    x = identity_block(x, 3, [64, 64, 256], stage=3, block='d')
    f3 = x

    x = conv_block(x, 3, [128, 128, 512], stage=4, block='a')
    x = identity_block(x, 3, [128, 128, 512], stage=4, block='b')
    x = identity_block(x, 3, [128, 128, 512], stage=4, block='c')
    x = identity_block(x, 3, [128, 128, 512], stage=4, block='d')
    x = identity_block(x, 3, [128, 128, 512], stage=4, block='e')
    x = identity_block(x, 3, [128, 128, 512], stage=4, block='f')
    f4 = x

    x = conv_block(x, 3, [256, 256, 1024], stage=5, block='a')
    x = identity_block(x, 3, [256, 256, 1024], stage=5, block='b')
    x = identity_block(x, 3, [256, 256, 1024], stage=5, block='c')
    f5 = x


    return img_input, [f1, f2, f3, f4, f5]

def _unet(n_classes, encoder, input_height=256, input_width=256):

    img_input, levels = encoder(input_height=input_height, input_width=input_width)
    [f1, f2, f3, f4, f5] = levels

    o = f5
    o = UpSampling2D((2, 2), data_format=IMAGE_ORDERING)(o)
    # o = ZeroPadding2D((1, 1) , data_format=IMAGE_ORDERING)(o)
    o = Conv2D(512, (2, 2), padding='same', data_format=IMAGE_ORDERING)(o)
    o = Activation('relu')(o)
    o = BatchNormalization()(o)
    o = concatenate([o,f4],axis=3)
    o = ZeroPadding2D((1, 1) , data_format=IMAGE_ORDERING)(o)
    o = Conv2D(512, (3, 3), padding='valid', data_format=IMAGE_ORDERING)(o)
    o = Activation('relu')(o)
    o = BatchNormalization()(o)
    o = ZeroPadding2D((1, 1) , data_format=IMAGE_ORDERING)(o)
    o = Conv2D(512, (3, 3), padding='valid', data_format=IMAGE_ORDERING)(o)
    o = Activation('relu')(o)
    o = BatchNormalization()(o)

    o = UpSampling2D((2, 2), data_format=IMAGE_ORDERING)(o)
    # o = ZeroPadding2D((1, 1) , data_format=IMAGE_ORDERING)(o)
    o = Conv2D(256, (2, 2), padding='same', data_format=IMAGE_ORDERING)(o)
    o = Activation('relu')(o)
    o = BatchNormalization()(o)
    o = concatenate([o,f3],axis=3)
    o = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(o)
    o = Conv2D(256, (3, 3), padding='valid', data_format=IMAGE_ORDERING)(o)
    o = Activation('relu')(o)
    o = BatchNormalization()(o)
    o = ZeroPadding2D((1, 1), data_format=IMAGE_ORDERING)(o)
    o = Conv2D(256, (3, 3), padding='valid', data_format=IMAGE_ORDERING)(o)
    o = Activation('relu')(o)
    o = BatchNormalization()(o)

    o = UpSampling2D((2, 2), data_format=IMAGE_ORDERING)(o)
    # o = ZeroPadding2D((1, 1) , data_format=IMAGE_ORDERING)(o)
    o = Conv2D(128, (2, 2), padding='same', data_format=IMAGE_ORDERING)(o)
    o = Activation('relu')(o)
    o = BatchNormalization()(o)
    o = concatenate([o,f2],axis=3)
    o = ZeroPadding2D((1, 1) , data_format=IMAGE_ORDERING)(o)
    o = Conv2D(128, (3, 3), padding='valid', data_format=IMAGE_ORDERING)(o)
    o = Activation('relu')(o)
    o = BatchNormalization()(o)
    o = ZeroPadding2D((1, 1) , data_format=IMAGE_ORDERING)(o)
    o = Conv2D(128, (3, 3), padding='valid', data_format=IMAGE_ORDERING)(o)
    o = Activation('relu')(o)
    o = BatchNormalization()(o)

    o = UpSampling2D((2, 2), data_format=IMAGE_ORDERING)(o)
    # o = ZeroPadding2D((1, 1) , data_format=IMAGE_ORDERING)(o)
    o = Conv2D(64, (2, 2), padding='same', data_format=IMAGE_ORDERING)(o)
    o = Activation('relu')(o)
    o = BatchNormalization()(o)
    o = concatenate([o,f1],axis=3)
    o = ZeroPadding2D((1, 1) , data_format=IMAGE_ORDERING)(o)
    o = Conv2D(64, (3, 3), padding='valid', data_format=IMAGE_ORDERING)(o)
    o = Activation('relu')(o)
    o = BatchNormalization()(o)
    o = ZeroPadding2D((1, 1) , data_format=IMAGE_ORDERING)(o)
    o = Conv2D(64, (3, 3), padding='valid', data_format=IMAGE_ORDERING)(o)
    o = Activation('relu')(o)
    o = BatchNormalization()(o)

    o = Conv2D(n_classes, (3, 3), padding='same', data_format=IMAGE_ORDERING)(o)
    o = Reshape((int(input_height) * int(input_width), -1))(o)
    o = Softmax()(o)
    model = Model(img_input,o)

    return model

def conv_unet(n_classes, input_height=256, input_width=256):

    model = _unet(n_classes, get_resnet50_encoder, input_height=input_height, input_width=input_width)
    model.model_name = 'conv_unet'
    return model

model = conv_unet(2)
model.summary()
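Because the model ends with a `Reshape` to `(H*W, n_classes)` followed by a softmax, the ground-truth masks have to be flattened the same way before training. A hedged numpy sketch (`mask_to_targets` is a hypothetical helper, not part of the code above; it assumes labels arrive as an (H, W) integer class mask):

```python
import numpy as np

def mask_to_targets(mask, n_classes):
    """One-hot encode an (H, W) integer mask and flatten it to (H*W, n_classes)."""
    return np.eye(n_classes)[mask.reshape(-1)]

mask = np.array([[0, 1],
                 [1, 0]])
targets = mask_to_targets(mask, n_classes=2)
print(targets.shape)  # (4, 2) -- matches the model's (H*W, n_classes) output
```

With targets in this shape, the model can be compiled with `categorical_crossentropy` as usual.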



What? You're not using Keras?
What? You're using PyTorch?

4. Res-Unet in PyTorch

4.1 The base U-Net structure
Res-UNet builds on the classic U-Net, which consists of an encoder path and a decoder path. The encoder extracts high-level features from the image, while the decoder progressively upsamples to recover spatial resolution and produce the final segmentation map.

4.2 The encoder
The encoder typically uses a pretrained CNN such as VGG or ResNet as its backbone. Each stage contains several convolutions plus a pooling layer for downsampling, which helps capture semantic information at multiple scales.

```python
import torch.nn as nn

class EncoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(EncoderBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.pool = nn.MaxPool2d(kernel_size=2)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        # Return the pooled output (for the next stage) and the
        # pre-pooling features (as the skip connection).
        return self.pool(x), x
```

4.3 The decoder
The decoder reconstructs the information from the encoder, gradually restoring the spatial dimensions until they match the input. Skip connections fuse low-level features with high-level abstractions to strengthen localization; some variants also insert residual blocks to ease gradient propagation.

```python
import torch
import torch.nn as nn

class DecoderBlock(nn.Module):
    def __init__(self, in_channels, mid_channels, out_channels):
        super(DecoderBlock, self).__init__()
        self.upconv = nn.ConvTranspose2d(in_channels, mid_channels, kernel_size=2, stride=2)
        self.conv1 = nn.Conv2d(mid_channels * 2, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x, skip_connection):
        x = self.upconv(x)
        x = torch.cat([skip_connection, x], dim=1)  # fuse skip features (Concatenate)
        x = self.conv1(x)
        x = self.bn1(x)
        return self.relu(x)
```

4.4 A weighted attention mechanism
To further improve expressiveness, especially against cluttered backgrounds, an attention gate can be added. It lets the network focus on the regions that matter most for classification: a weight is assigned to the response at every pixel position, highlighting the target object and its boundary.

```python
import torch
import torch.nn as nn

class AttentionGate(nn.Module):
    def __init__(self, F_g, F_l, F_int):
        super(AttentionGate, self).__init__()
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1),
            nn.BatchNorm2d(F_int))
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1),
            nn.BatchNorm2d(F_int))
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1),
            nn.BatchNorm2d(1),
            nn.Sigmoid())

    def forward(self, g, x):
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi = self.psi(torch.add(g1, x1))  # per-pixel attention weights in (0, 1)
        return x * psi
```

4.5 Hyperparameter tuning
The best configuration depends on the application and is usually found by experiment. For example, studies on the Synapse dataset show that the choice of upsampling method and optimizer can noticeably affect the final results.
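The core of the attention gate is the per-pixel sigmoid weighting of the skip features. A toy numpy sketch of just that gating step (hand-picked values, no learned convolutions):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

x = np.array([[1.0, 2.0],
              [3.0, 4.0]])              # skip features (toy 2x2 map)
psi = sigmoid(np.array([[5.0, -5.0],
                        [0.0, 5.0]]))   # attention weights in (0, 1)
gated = x * psi                         # positions with low psi are suppressed

print(np.round(gated, 3))
```

Positions with weights near 1 pass through almost unchanged, while positions with weights near 0 are nearly zeroed out, which is how the gate emphasizes the target region.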