关键点
- 编码-解码 结构
- UNet经过4次下采样和4次上采样后再做出预测,而不是直接在编码器最后一个阶段直接还原回原输入图尺寸。4次上采样逐步还原与细化编码器输出的 拥有高级语义特征信息的特征图,能增强边缘等细节信息的预测
- 跳跃链接-长链接
- 有效结合浅层信息与深层信息,补充上采样时信息不足。经过多层卷积和池化后,难免损失了些空间信息和低级特征
常见的一些改动
- 对编码器的改动
- 显然UNet当中用作于特征提取的编码设计过于简单,可以使用一些流行分类网络嵌入作为编码器如ResNet,DenseNet 等等
- 多尺度策略
- 输入不同尺度的图像到网络中进行训练
- 结合不同大小的卷积核的卷积操作
- 结合不同空洞率的空洞卷积操作
- 注意力策略
- 空间注意力
- 通道注意力
- 空间-通道注意力
- 基于RNN, LSTM 等思想的注意力
- 深度监督
关于UNet的一些细节研究
可以到CSDN 和 知乎上搜UNet++(知乎有作者解析) 和 UNet3+ 等文章解读;
这两篇文章在我看来改动的大体方向差不多,评论区也说到UNet++ 这类结构在Kaggle挺吃香,
用于实际应用提点是一个选择,但想基于此来再进一步发paper有点难。
这两篇文章的实验可以更好的认识到一些基本的改动对网络的影响。
UNet代码实现-Pytorch
import torch
import torch.nn as nn
import torchsummaryX
def conv3x3(ch_in, ch_out):
return nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1)
class Conv3x3_BN_PReLU(nn.Module):
def __init__(self, ch_in, ch_out):
super(Conv3x3_BN_PReLU, self).__init__()
self.conv = nn.Sequential(
conv3x3(ch_in, ch_out),
nn.BatchNorm2d(ch_out),
nn.PReLU(ch_out)
)
def forward(self, x):
out = self.conv(x)
return out
class UNetConv(nn.Module):
def __init__(self, ch_in, ch_out):
super(UNetConv, self).__init__()
self.conv = nn.Sequential(
Conv3x3_BN_PReLU(ch_in, ch_out),
Conv3x3_BN_PReLU(ch_out, ch_out)
)
def forward(self, x):
out = self.conv(x)
return out
class UNetDownsample(nn.Module):
def __init__(self, ch_in, ch_out):
super(UNetDownsample, self).__init__()
self.pool = nn.MaxPool2d(2)
self.conv = UNetConv(ch_in, ch_out)
def forward(self, x):
x_pool = self.pool(x)
out = self.conv(x_pool)
return out
class UNetUpsample(nn.Module):
def __init__(self, ch_in, ch_out, is_trans=False):
super(UNetUpsample, self).__init__()
self.is_trans = is_trans
if self.is_trans is True:
self.up = nn.ConvTranspose2d(ch_in, ch_out, kernel_size=2, stride=2)
else:
self.up = nn.Sequential(
nn.UpsamplingBilinear2d(scale_factor=2),
Conv3x3_BN_PReLU(ch_in, ch_out)
)
self.conv = UNetConv(ch_in, ch_out)
def forward(self, x, en):
x_up = self.up(x)
concat = torch.cat([en, x_up], 1)
out = self.conv(concat)
return out
class UNet(nn.Module):
def __init__(self, img_in, num_classes):
super(UNet, self).__init__()
self.Input_conv = UNetConv(img_in, 64)
self.en1 = UNetDownsample(64, 128)
self.en2 = UNetDownsample(128, 256)
self.en3 = UNetDownsample(256, 512)
self.en4 = UNetDownsample(512, 1024)
self.de4 = UNetUpsample(1024, 512)
self.de3 = UNetUpsample(512, 256)
self.de2 = UNetUpsample(256, 128)
self.de1 = UNetUpsample(128, 64)
self.final = nn.Conv2d(64, num_classes, kernel_size=1, stride=1)
def forward(self, x):
in_conv = self.Input_conv(x)
en1 = self.en1(in_conv)
en2 = self.en2(en1)
en3 = self.en3(en2)
en4 = self.en4(en3)
de4 = self.de4(en4, en3)
de3 = self.de3(de4, en2)
de2 = self.de2(de3, en1)
de1 = self.de1(de2, in_conv)
out = self.final(de1)
return out
if __name__ == "__main__":
input = torch.rand(1, 3, 256, 256)
model = UNet(3, 1)
torchsummaryX.summary(model.cuda(), input.cuda())