深度学习论文阅读——UNet


在这里插入图片描述

关键点

  • 编码-解码 结构
    • UNet经过4次下采样和4次上采样后再做出预测,而不是直接在编码器最后一个阶段直接还原回原输入图尺寸。4次上采样逐步还原与细化编码器输出的 拥有高级语义特征信息的特征图,能增强边缘等细节信息的预测
  • 跳跃链接-长链接
    • 有效结合浅层信息与深层信息,补充上采样时信息不足。经过多层卷积和池化后,难免损失了些空间信息和低级特征

常见的一些改动

  • 对编码器的改动
    • 显然UNet当中用作于特征提取的编码设计过于简单,可以使用一些流行分类网络嵌入作为编码器如ResNet,DenseNet 等等
  • 多尺度策略
    • 输入不同尺度的图像到网络中进行训练
    • 结合不同大小的卷积核的卷积操作
    • 结合不同空洞率的空洞卷积操作
  • 注意力策略
    • 空间注意力
    • 通道注意力
    • 空间-通道注意力
    • 基于RNN, LSTM 等思想的注意力
  • 深度监督

关于UNet的一些细节研究

可以到CSDN 和 知乎上搜UNet++(知乎有作者解析) 和 UNet3+ 等文章解读;
这两篇文章在我看来改动的大体方向差不多,评论区也说到UNet++ 这类结构在Kaggle挺吃香,
用于实际应用提点是一个选择,但想基于此来再进一步发paper有点难。
这两篇文章的实验可以更好的认识到一些基本的改动对网络的影响。 

UNet代码实现-Pytorch

import torch
import torch.nn as nn
import torchsummaryX

def conv3x3(ch_in, ch_out):
    return nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1)

class Conv3x3_BN_PReLU(nn.Module):
    def __init__(self, ch_in, ch_out):
        super(Conv3x3_BN_PReLU, self).__init__()
        self.conv = nn.Sequential(
            conv3x3(ch_in, ch_out),
            nn.BatchNorm2d(ch_out),
            nn.PReLU(ch_out)
            )
    
    def forward(self, x):
        out = self.conv(x)
        return out

class UNetConv(nn.Module):
    def __init__(self, ch_in, ch_out):
        super(UNetConv, self).__init__()
        self.conv = nn.Sequential(
            Conv3x3_BN_PReLU(ch_in, ch_out),
            Conv3x3_BN_PReLU(ch_out, ch_out)
        )

    def forward(self, x):
        out = self.conv(x)
        return out

class UNetDownsample(nn.Module):
    def __init__(self, ch_in, ch_out):
        super(UNetDownsample, self).__init__()
        self.pool = nn.MaxPool2d(2)
        self.conv = UNetConv(ch_in, ch_out)
    
    def forward(self, x):
        x_pool = self.pool(x)
        out = self.conv(x_pool)
        return out

class UNetUpsample(nn.Module):
    def __init__(self, ch_in, ch_out, is_trans=False):
        super(UNetUpsample, self).__init__()
        self.is_trans = is_trans

        if self.is_trans is True:
            self.up = nn.ConvTranspose2d(ch_in, ch_out, kernel_size=2, stride=2)
        else:
            self.up = nn.Sequential(
                nn.UpsamplingBilinear2d(scale_factor=2),
                Conv3x3_BN_PReLU(ch_in, ch_out)
            )

        self.conv = UNetConv(ch_in, ch_out)
    
    def forward(self, x, en):
        x_up = self.up(x)
        concat = torch.cat([en, x_up], 1)
        out = self.conv(concat)
        return out

class UNet(nn.Module):
    def __init__(self, img_in, num_classes):
        super(UNet, self).__init__()
        self.Input_conv = UNetConv(img_in, 64)
        self.en1 = UNetDownsample(64, 128)
        self.en2 = UNetDownsample(128, 256)
        self.en3 = UNetDownsample(256, 512)
        self.en4 = UNetDownsample(512, 1024)

        self.de4 = UNetUpsample(1024, 512)
        self.de3 = UNetUpsample(512, 256)
        self.de2 = UNetUpsample(256, 128)
        self.de1 = UNetUpsample(128, 64)

        self.final = nn.Conv2d(64, num_classes, kernel_size=1, stride=1)

    def forward(self, x):
        in_conv = self.Input_conv(x)

        en1 = self.en1(in_conv)
        en2 = self.en2(en1)
        en3 = self.en3(en2)
        en4 = self.en4(en3)

        de4 = self.de4(en4, en3)
        de3 = self.de3(de4, en2)
        de2 = self.de2(de3, en1)
        de1 = self.de1(de2, in_conv)

        out = self.final(de1)
        return out

if __name__ == "__main__":
    input = torch.rand(1, 3, 256, 256)
    model = UNet(3, 1)
    torchsummaryX.summary(model.cuda(), input.cuda())

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值