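Building U-Net by Hand
I've finally started building my first network, and I decided to begin with U-Net.
The paper describes the architecture as follows: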
Network Architecture
The network architecture is illustrated in Figure 1. It consists of a contracting path (left side) and an expansive path (right side). The contracting path follows the typical architecture of a convolutional network. It consists of the repeated application of two 3x3 convolutions (unpadded convolutions), each followed by a rectified linear unit (ReLU) and a 2x2 max pooling operation with stride 2 for downsampling. At each downsampling step we double the number of feature channels. Every step in the expansive path consists of an upsampling of the feature map followed by a 2x2 convolution (“up-convolution”) that halves the number of feature channels, a concatenation with the correspondingly cropped feature map from the contracting path, and two 3x3 convolutions, each followed by a ReLU. The cropping is necessary due to the loss of border pixels in every convolution. At the final layer a 1x1 convolution is used to map each 64-component feature vector to the desired number of classes. In total the network has 23 convolutional layers.
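To see why the paper needs cropping, here is a small pure-Python sketch that traces the feature-map size through the unpadded network and counts its convolutions. It assumes a 572×572 input (the tile size shown in the paper's Figure 1) and four pooling steps; the function names are my own.

```python
def unet_sizes(size, depth=4):
    """Trace H (= W) through an unpadded U-Net; return output size and crop pairs."""
    skips = []
    for _ in range(depth):
        size -= 4                   # two unpadded 3x3 convs each lose 2 pixels
        skips.append(size)          # this map is kept for the skip connection
        if size % 2:
            raise ValueError("2x2 max pooling needs an even size here")
        size //= 2                  # 2x2 max pooling, stride 2
    size -= 4                       # two convs at the bottom of the "U"
    crops = []
    for skip in reversed(skips):
        size *= 2                   # 2x2 up-convolution doubles H and W
        crops.append((skip, size))  # skip map must be cropped from skip to size
        size -= 4                   # two unpadded 3x3 convs after concatenation
    return size, crops


def count_conv_layers(depth=4):
    # 2 per contracting level + 2 at the bottom + (1 up-conv + 2 convs) per
    # expansive level + the final 1x1 conv
    return 2 * depth + 2 + 3 * depth + 1


out, crops = unet_sizes(572)
print(out)                  # 388
print(crops)                # [(64, 56), (136, 104), (280, 200), (568, 392)]
print(count_conv_layers())  # 23
```

Each pair shows how far a skip map has to be cut down before it can be concatenated, which is exactly the border loss the paper mentions.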
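The architecture in Figure 1 splits symmetrically into two halves: on the left (the contracting path) the number of channels grows while the feature maps shrink; on the right (the expansive path) the channels shrink while the feature maps grow back.
Here is a simple first implementation: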
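The channel bookkeeping on both sides can be checked with a few lines of plain Python. This hypothetical helper assumes a base width of 64 and four levels, as in the paper; the pairs it returns are the (input, output) channel counts each decoder double-convolution block needs.

```python
def unet_channels(base=64, depth=4):
    # Contracting path: channels after each level's two convs double every step
    down = [base * 2 ** i for i in range(depth + 1)]  # [64, 128, 256, 512, 1024]
    up = []
    ch = down[-1]
    for skip in reversed(down[:-1]):
        ch //= 2              # up-convolution halves the channels
        cat_ch = ch + skip    # concatenation with the skip connection adds them
        ch = skip             # the two 3x3 convs bring it back to the skip's width
        up.append((cat_ch, ch))
    return down, up


down, up = unet_channels()
print(down)  # [64, 128, 256, 512, 1024]
print(up)    # [(1024, 512), (512, 256), (256, 128), (128, 64)]
```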
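```python
# 2020-2-15
# Author: lee
# U-Net
# Learning while building
import torch
from torch import nn


# The step used at every level: two 3x3 convolutions, each followed by BatchNorm and ReLU.
# padding=1 keeps H and W unchanged, so no cropping is needed before torch.cat.
# (The paper uses unpadded convolutions plus cropping, and has no BatchNorm.)
class ConvLayer(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(ConvLayer, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.conv(x)


# The U-Net itself; input H and W must be divisible by 16 (four 2x2 poolings).
class UNet(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(UNet, self).__init__()
        # These attributes hold layer objects (callable modules), not their outputs;
        # the layers are applied to data later, in forward().
        # Contracting path: channels double at each level, pooling halves H and W.
        self.conv1 = ConvLayer(in_ch, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = ConvLayer(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = ConvLayer(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.conv4 = ConvLayer(256, 512)
        self.pool4 = nn.MaxPool2d(2)
        self.conv5 = ConvLayer(512, 1024)
        # Expansive path: each up-convolution doubles H and W and halves the channels;
        # concatenating the skip connection doubles the channels again.
        self.up6 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.conv6 = ConvLayer(1024, 512)
        self.up7 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.conv7 = ConvLayer(512, 256)
        self.up8 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.conv8 = ConvLayer(256, 128)
        self.up9 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv9 = ConvLayer(128, 64)
        # Final 1x1 convolution maps each 64-channel feature vector to out_ch classes.
        self.conv10 = nn.Conv2d(64, out_ch, 1)

    def forward(self, x):
        c1 = self.conv1(x)  # kept for the later torch.cat (skip connection)
        p1 = self.pool1(c1)
        c2 = self.conv2(p1)
        p2 = self.pool2(c2)
        c3 = self.conv3(p2)
        p3 = self.pool3(c3)
        c4 = self.conv4(p3)
        p4 = self.pool4(c4)
        c5 = self.conv5(p4)
        u6 = self.up6(c5)
        m6 = torch.cat([c4, u6], dim=1)
        c6 = self.conv6(m6)
        u7 = self.up7(c6)
        m7 = torch.cat([c3, u7], dim=1)
        c7 = self.conv7(m7)
        u8 = self.up8(c7)
        m8 = torch.cat([c2, u8], dim=1)
        c8 = self.conv8(m8)
        u9 = self.up9(c8)
        m9 = torch.cat([c1, u9], dim=1)
        c9 = self.conv9(m9)
        return self.conv10(c9)
```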
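As a cross-check, here is a compact sketch of a padded U-Net written with loops instead of one attribute per layer. The names `double_conv` and `LoopUNet` are my own; it assumes padded 3x3 convolutions with BatchNorm rather than the paper's unpadded ones, so the input H and W must be divisible by 2**depth (16 by default).

```python
import torch
from torch import nn


def double_conv(in_ch, out_ch):
    # Two padded 3x3 convs with BatchNorm + ReLU; padding=1 keeps H and W
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, 3, padding=1),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_ch, out_ch, 3, padding=1),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True),
    )


class LoopUNet(nn.Module):
    def __init__(self, in_ch, out_ch, base=64, depth=4):
        super().__init__()
        widths = [base * 2 ** i for i in range(depth + 1)]  # 64, 128, ..., 1024
        self.down = nn.ModuleList()
        prev = in_ch
        for w in widths:
            self.down.append(double_conv(prev, w))
            prev = w
        self.pool = nn.MaxPool2d(2)
        self.ups = nn.ModuleList()
        self.up_convs = nn.ModuleList()
        for w in reversed(widths[:-1]):
            self.ups.append(nn.ConvTranspose2d(w * 2, w, 2, stride=2))
            self.up_convs.append(double_conv(w * 2, w))  # w (skip) + w (upsampled)
        self.head = nn.Conv2d(base, out_ch, 1)  # final 1x1 convolution

    def forward(self, x):
        skips = []
        for i, block in enumerate(self.down):
            x = block(x)
            if i < len(self.down) - 1:  # no pooling after the bottom block
                skips.append(x)
                x = self.pool(x)
        for up, conv, skip in zip(self.ups, self.up_convs, reversed(skips)):
            x = up(x)
            x = torch.cat([skip, x], dim=1)  # join along the channel dimension
            x = conv(x)
        return self.head(x)


model = LoopUNet(in_ch=1, out_ch=2).eval()
with torch.no_grad():
    y = model(torch.randn(1, 1, 64, 64))
print(y.shape)  # torch.Size([1, 2, 64, 64])
```

With padding the output has the same H and W as the input, unlike the paper's 572 → 388 tiles.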
A few of the functions used here:
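ConvTranspose2d: transposed convolution (the paper's "up-convolution"), the right-side counterpart of max pooling on the left. With kernel size 2 and stride 2, max pooling halves the height and width and leaves the channel count unchanged; nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2) doubles the height and width, and here it also halves the channels (for example 1024 → 512).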
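A quick shape check makes the contrast concrete. The 28×28 bottleneck size is a hypothetical example (it matches the unpadded setting in the paper); the transposed-convolution output follows H_out = (H_in - 1) * stride - 2 * padding + (kernel_size - 1) + output_padding + 1, which gives 27 * 2 + 1 + 1 = 56 here.

```python
import torch
from torch import nn

x = torch.randn(1, 1024, 28, 28)  # hypothetical bottleneck feature map (N, C, H, W)

pool = nn.MaxPool2d(2)                           # kernel 2, stride 2
up = nn.ConvTranspose2d(1024, 512, 2, stride=2)  # "up-convolution"

pooled = pool(x)
upsampled = up(x)
print(pooled.shape)     # torch.Size([1, 1024, 14, 14]): H, W halved, channels kept
print(upsampled.shape)  # torch.Size([1, 512, 56, 56]): H, W doubled, channels halved
```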
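cat: torch.cat joins the upsampled map with the skip-connection map from the left side. It does not change the height or width; the two maps must already have the same spatial size. As for dim=1, which puzzled me: PyTorch image tensors are laid out as (N, C, H, W), so dim=1 is the channel dimension and the channel counts add up (512 + 512 = 1024, which is why conv6 takes 1024 input channels).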
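A small demonstration of dim=1, with hypothetical sizes from the paper's unpadded setting at the deepest level (a 64×64 skip map and a 56×56 upsampled map). The `center_crop` helper is my own and is just one simple way to do the cropping the paper mentions.

```python
import torch

# Tiny example: concatenating along dim=1 adds the channel counts
a = torch.zeros(2, 3, 4, 4)
b = torch.ones(2, 5, 4, 4)
print(torch.cat([a, b], dim=1).shape)  # torch.Size([2, 8, 4, 4])

skip = torch.randn(1, 512, 64, 64)  # from the contracting path (unpadded case)
up = torch.randn(1, 512, 56, 56)    # after the up-convolution


def center_crop(t, h, w):
    # Crop the last two dims of an (N, C, H, W) tensor to (h, w) around the center
    top = (t.shape[2] - h) // 2
    left = (t.shape[3] - w) // 2
    return t[:, :, top:top + h, left:left + w]


try:
    torch.cat([skip, up], dim=1)  # spatial sizes differ, so this fails
except RuntimeError as err:
    print("without cropping:", err)

merged = torch.cat([center_crop(skip, 56, 56), up], dim=1)
print(merged.shape)  # torch.Size([1, 1024, 56, 56])
```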
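I haven't read the rest of the paper closely yet, and the thinking behind this design is what I'm most curious about. I plan to use this network on one or two concrete tasks first, then come back and read the paper carefully.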