【干货教学】unet进阶,如何在unet中加入resnet(残差连接)


U-Net进阶教程:如何在U-Net中加入ResNet的残差连接

在本教程中,我们将探讨如何在经典的U-Net架构中融入ResNet的残差连接。这种结合了U-Net在图像分割领域的优势和ResNet的残差连接的混合模型,我们称之为ResUnet,旨在通过残差学习改善模型的训练效率和性能。

1.什么是残差连接?

残差连接是一种允许数据直接从网络的较低层传递到较高层的结构。这种方式可以帮助解决深度神经网络训练过程中的梯度消失问题,使得网络能够学习到更加复杂的功能。

2.ResUnet架构

2.1 代码实现

在Res1Unet中,我们在每个下采样(编码)和上采样(解码)步骤中都加入了残差连接,本质上是通过一个核为1的卷积操作来实现维度匹配。以下是Python中的实现代码和相应的解释。

class ResUnet(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(Res1Unet, self).__init__()

        # down sampling
        # 假如输入 224*224*1 的图像
        # H = ((224 - 3 + 1 + 2 - 1) / 1) + 1 = 224  unet的卷积不会改变特征图的大小
        self.conv1 = DoubleConv(in_ch, 64)
        # to increase the dimensions
        self.w1 = nn.Conv2d(in_ch, 64, kernel_size=1, padding=0, stride=1)
        self.pool1 = nn.MaxPool2d(2)  # 224 -> 112

        self.conv2 = DoubleConv(64, 128)  # 不变
        # to increase the dimensions
        self.w2 = nn.Conv2d(64, 128, kernel_size=1, padding=0, stride=1)
        self.pool2 = nn.MaxPool2d(2)  # 56

        self.conv3 = DoubleConv(128, 256)
        # to increase the dimensions
        self.w3 = nn.Conv2d(128, 256, kernel_size=1, padding=0, stride=1)
        self.pool3 = nn.MaxPool2d(2)  # 28

        self.conv4 = DoubleConv(256, 512)
        # to increase the dimensions
        self.w4 = nn.Conv2d(256, 512, kernel_size=1, padding=0, stride=1)
        self.pool4 = nn.MaxPool2d(2)  # 14

        self.conv5 = DoubleConv(512, 1024)
        # to increase the dimensions
        self.w5 = nn.Conv2d(512, 1024, kernel_size=1, padding=0, stride=1)

        # up sampling
        # H_out = (14 - 1) * 2 + 2 = 28 往上反卷积
        self.up6 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.conv6 = DoubleConv(1024, 512)   # 28

        self.up7 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.conv7 = DoubleConv(512, 256)

        self.up8 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.conv8 = DoubleConv(256, 128)

        self.up9 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv9 = DoubleConv(128, 64)

        self.conv10 = nn.Conv2d(64, out_ch, 1)

        # 训练时尝试让神经元失活,加大泛化性,仅在训练时使用,pytorch自动补偿参数
        self.dropout = nn.Dropout(p=0.3)

    def forward(self, x):
        # 下采样部分
        down0_res = self.w1(x)  # residual block,残差连接
        down0 = self.conv1(x) + down0_res
        down1 = self.pool1(down0)

        down1_res = self.w2(down1)  # residual block
        down1 = self.conv2(down1) + down1_res
        down2 = self.pool2(down1)

        down2_res = self.w3(down2)
        down2 = self.conv3(down2) + down2_res
        down3 = self.pool3(down2)

        down3_res = self.w4(down3)
        down3 = self.conv4(down3) + down3_res
        down4 = self.pool4(down3)

        down4_res = self.w5(down4)
        # 5 , 连接上采样部分前,双卷积卷积操作    [14, 14, 1024]
        down5 = self.conv5(down4) + down4_res

        # 上采样部分
        up_6 = self.up6(down5)   # [28, 28, 512]
        merge6 = torch.cat([up_6, down3], dim=1)    # cat之后又变为[28, 28, 1024]
        c6 = self.conv6(merge6)   # 重新双卷积变为[28, 28, 512]

        up_7 = self.up7(c6)   # [56, 56, 256]
        merge7 = torch.cat([up_7, down2], dim=1)
        c7 = self.conv7(merge7) # [56, 56, 256]

        up_8 = self.up8(c7)   # [112, 112, 128]
        merge8 = torch.cat([up_8, down1], dim=1)
        c8 = self.conv8(merge8) # [112, 112, 128]

        up_9 = self.up9(c8)   # [224, 224, 64]
        merge9 = torch.cat([up_9, down0], dim=1)
        c9 = self.conv9(merge9)  # [224, 224, 64]

        c10 = self.conv10(c9)  # 卷积输出最终图像   [224, 224, t]

        return c10

在这个例子中,w1w2w3w4w5是为了匹配维度而设置的1x1卷积,它们允许我们将原始输入或下采样后的特征添加到特征图中,这样就实现了残差连接。在U-Net的每个编码阶段之后,我们都会加上一个这样的残差连接。

2.2 图示

2.2.1 unet

在这里插入图片描述

2.2.2 加入残差连接改进

在这里插入图片描述

3.为什么要使用ResUnet?

3.1优势

  1. 改善梯度流通:通过加入残差连接,梯度可以直接流经较短的路径,减少训练过程中的梯度消失问题。
  2. 加速收敛:残差连接有助于网络更快地收敛,提高训练效率。
  3. 提高性能:Res1Unet可以更好地捕捉到图像的细节和上下文信息,提高分割的准确性。

3.2缺点

  1. 增加计算负担:虽然残差连接有很多优点,但它们也会稍微增加前向和后向传播时的计算负担。
  2. 可能导致过拟合:在一些小数据集上,过于复杂的模型可能会导致过拟合。

4.结论

Res1Unet是一个强大的网络架构,它结合了U-Net的优秀特性和ResNet的强大能力。虽然这可能会带来一些额外的计算成本,但在许多情况下,这种额外的成本是值得的,因为它可以显著提升模型性能。
希望本教程能够帮助你理解如何在U-Net中加入残差连接,并鼓励你尝试将这种方法应用到你自己的项目中。


往期精彩干货
基于mmdetection3d的单目3D目标检测模型,效果远超CenterNet3D
SSH?Termius?一篇文章教你使用远程服务器训练
Jetson nano开机自启动python程序
【代码实践】focal loss损失函数及其变形原理详细讲解和图像分割实践(含源码)

  • 30
    点赞
  • 38
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

WanHeng WyattVan

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值