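U-Net was first proposed for medical image segmentation, was later extended to general-purpose segmentation, and has since come to play a major role in low-level vision as well.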
U-Net
Reference: U-Net: Convolutional Networks for Biomedical Image Segmentation
How the network works
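The U-Net architecture is actually quite simple and resembles the pyramid structures used in traditional image processing. The input goes through several conv+ReLU layers for feature extraction, followed by max pooling for downsampling, which enlarges the receptive field while shrinking the feature map. After repeating this several times we reach the lowest-resolution feature map, the bottom row of the U in the architecture diagram. From there, each decoder step performs upsample+conv+concat, where the concat joins the upsampled features with the encoder features of the same resolution, and then applies conv+ReLU to the result. Repeating this the same number of times as on the way down yields the final output.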
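As a rough illustration of the down/up pattern described above, here is a minimal sketch of a single U-shaped level in PyTorch. The class name `TinyULevel`, the attribute names and the channel widths are assumptions made for this example only. Bilinear `nn.Upsample` followed by a conv is used as one possible reading of "upsample+conv".

```python
import torch
import torch.nn as nn


class TinyULevel(nn.Module):
    """One encoder/decoder level of a U-shaped network (illustrative only)."""

    def __init__(self, in_ch=1, ch=16):
        super().__init__()
        # Encoder: conv+ReLU at full resolution
        self.enc = nn.Sequential(nn.Conv2d(in_ch, ch, 3, 1, 1), nn.ReLU())
        # Downsample by 2: smaller feature map, larger receptive field
        self.down = nn.MaxPool2d(2, 2)
        # Features at the bottom of the "U"
        self.mid = nn.Sequential(nn.Conv2d(ch, ch * 2, 3, 1, 1), nn.ReLU())
        # Upsample by 2, then a conv that halves the channel count
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
            nn.Conv2d(ch * 2, ch, 3, 1, 1),
        )
        # Decoder: conv+ReLU on the concatenated (skip + upsampled) features
        self.dec = nn.Sequential(nn.Conv2d(ch * 2, ch, 3, 1, 1), nn.ReLU())

    def forward(self, x):
        skip = self.enc(x)                # (N, ch, H, W)
        low = self.mid(self.down(skip))   # (N, 2*ch, H/2, W/2)
        up = self.up(low)                 # (N, ch, H, W)
        return self.dec(torch.cat((skip, up), dim=1))


x = torch.rand(1, 1, 64, 64)
y = TinyULevel()(x)
print(y.shape)  # torch.Size([1, 16, 64, 64])
```

A full U-Net stacks several such levels, so each decoder stage receives the encoder features of matching resolution through the concat.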
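Today U-Net is used well beyond segmentation. It is also widely applied in low-level vision, the field the author works in, for tasks such as super-resolution and denoising. What these tasks have in common is that they are image-to-image tasks whose input and output usually have the same size.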
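One practical consequence of the same-size requirement: with stride-2 pooling, the skip connections only line up when the height and width are divisible by 2 raised to the number of pooling stages. A common workaround is to pad the input and crop the output. The sketch below is only an assumption about how that could be done. `pad_to_multiple` and `crop_to_size` are hypothetical helpers, not part of PyTorch, and the default multiple of 16 assumes four pooling stages.

```python
import torch
import torch.nn.functional as F


def pad_to_multiple(x, multiple=16):
    """Reflect-pad H and W on the bottom/right up to the next multiple."""
    h, w = x.shape[-2:]
    pad_h = (multiple - h % multiple) % multiple
    pad_w = (multiple - w % multiple) % multiple
    # F.pad order for the last two dims: (left, right, top, bottom)
    x_pad = F.pad(x, (0, pad_w, 0, pad_h), mode="reflect")
    return x_pad, (h, w)


def crop_to_size(y, size):
    """Crop the network output back to the original H and W."""
    h, w = size
    return y[..., :h, :w]


x = torch.rand(1, 1, 250, 243)
x_pad, orig_size = pad_to_multiple(x, 16)
# In practice this would be crop_to_size(net(x_pad), orig_size);
# the padded tensor stands in for the network output here.
y = crop_to_size(x_pad, orig_size)
print(x_pad.shape, y.shape)
# torch.Size([1, 1, 256, 256]) torch.Size([1, 1, 250, 243])
```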
PyTorch code
import torch
import torch.nn as nn


def double_conv_relu(n_in, n_out):
    """Two 3x3 conv + ReLU layers; padding=1 keeps the spatial size."""
    block = nn.Sequential(
        nn.Conv2d(n_in, n_out, 3, 1, 1),
        nn.ReLU(),
        nn.Conv2d(n_out, n_out, 3, 1, 1),
        nn.ReLU(),
    )
    return block


class UNet(nn.Module):
    def __init__(self) -> None:
        super(UNet, self).__init__()
        # Encoder: channels double at every level, max pooling halves H and W
        self.conv1 = double_conv_relu(1, 64)
        self.down1 = nn.MaxPool2d(2, 2)
        self.conv2 = double_conv_relu(64, 128)
        self.down2 = nn.MaxPool2d(2, 2)
        self.conv3 = double_conv_relu(128, 256)
        self.down3 = nn.MaxPool2d(2, 2)
        self.conv4 = double_conv_relu(256, 512)
        self.down4 = nn.MaxPool2d(2, 2)
        # Bottom of the "U"
        self.conv5 = double_conv_relu(512, 1024)
        # Decoder: ConvTranspose2d(kernel=4, stride=2, padding=1) doubles H and W.
        # Each double conv takes 2x the channels because of the skip concat.
        self.up1 = nn.ConvTranspose2d(1024, 512, 4, 2, 1)
        self.conv6 = double_conv_relu(1024, 512)
        self.up2 = nn.ConvTranspose2d(512, 256, 4, 2, 1)
        self.conv7 = double_conv_relu(512, 256)
        self.up3 = nn.ConvTranspose2d(256, 128, 4, 2, 1)
        self.conv8 = double_conv_relu(256, 128)
        self.up4 = nn.ConvTranspose2d(128, 64, 4, 2, 1)
        self.conv9 = double_conv_relu(128, 64)
        # Output head: 64 channels -> 3 output channels
        self.conv_last = nn.Conv2d(64, 3, 3, 1, 1)

    def forward(self, x):
        # Encoder path; feat1..feat4 are kept for the skip connections
        feat1 = self.conv1(x)
        feat2 = self.conv2(self.down1(feat1))
        feat3 = self.conv3(self.down2(feat2))
        feat4 = self.conv4(self.down3(feat3))
        feat5 = self.conv5(self.down4(feat4))
        # Decoder path: upsample, concat with the matching encoder feature, conv
        feat6 = self.conv6(torch.cat((feat4, self.up1(feat5)), dim=1))
        feat7 = self.conv7(torch.cat((feat3, self.up2(feat6)), dim=1))
        feat8 = self.conv8(torch.cat((feat2, self.up3(feat7)), dim=1))
        feat9 = self.conv9(torch.cat((feat1, self.up4(feat8)), dim=1))
        out = self.conv_last(feat9)
        return out


if __name__ == "__main__":
    # H and W must be divisible by 16 (four 2x poolings) for the concats to match
    x = torch.rand(1, 1, 256, 256)
    # The constructor takes no arguments: 1 input channel and 3 output channels are fixed
    net = UNet()
    y = net(x)
    print(y.shape)  # torch.Size([1, 3, 256, 256])
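I wrote this myself and it is fairly rough; the point is mainly to show the network structure. U-Net was proposed back in 2015, which is a long time ago, so in later applications the conv+ReLU blocks in the code above are often replaced with residual blocks, which can give better results.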
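As a minimal sketch of that substitution, a residual block could be used wherever the double conv+ReLU block appears. The class name `ResBlock` and the 1x1 conv on the shortcut are assumptions for this example. Many residual block variants exist, and this is not taken from any specific paper's implementation.

```python
import torch
import torch.nn as nn


class ResBlock(nn.Module):
    """Residual block with the same (n_in, n_out) interface as a double conv."""

    def __init__(self, n_in, n_out):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(n_in, n_out, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(n_out, n_out, 3, 1, 1),
        )
        # A 1x1 conv matches channel counts on the shortcut when they differ
        self.skip = nn.Identity() if n_in == n_out else nn.Conv2d(n_in, n_out, 1)
        self.act = nn.ReLU()

    def forward(self, x):
        # Output = activation(residual branch + shortcut)
        return self.act(self.body(x) + self.skip(x))


block = ResBlock(64, 128)
out = block(torch.rand(1, 64, 32, 32))
print(out.shape)  # torch.Size([1, 128, 32, 32])
```

Because the block keeps the spatial size and takes the same `(n_in, n_out)` arguments, it can be swapped in at each level without touching the pooling or upsampling layers.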
UNet++
Reference: UNet++: A Nested U-Net Architecture for Medical Image Segmentation
nnUNet
Reference: nnU-Net: Self-adapting Framework for U-Net-Based Medical Image Segmentation