在编写 UNet 网络之前，需要先了解 VGG16 的网络架构，因为本文的 UNet 编码器直接复用了 VGG16 前面部分（卷积部分）的结构。
UNet 复用了 VGG16 中除全连接层（分类器）以外的全部卷积部分，即 `vgg16_bn().features` 中的所有层（包括最后一个最大池化层）。以下为 UNet 网络结构。
然后根据 UNet 的网络结构图，大致可以写出如下代码。为方便起见，卷积块和反卷积（转置卷积）块的结构都写得比较简单。
def conv(in_channels, out_channels):
    """Build a size-preserving 3x3 convolution block.

    The block is Conv2d(3x3, stride 1, padding 1) -> BatchNorm2d -> ReLU, so
    the spatial size of the input is unchanged and only the channel count
    goes from ``in_channels`` to ``out_channels``.
    """
    layers = [
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=1),
        # Normalize each output channel's feature maps over the mini-batch.
        nn.BatchNorm2d(out_channels),
        # In-place ReLU overwrites its input instead of allocating a new tensor.
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)
def up_conv(in_channels, out_channels):
    """Build a 2x upsampling block: transposed convolution followed by ReLU.

    ConvTranspose2d with kernel 2 and stride 2 doubles height and width and
    maps ``in_channels`` to ``out_channels``.
    """
    upsample = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
    activation = nn.ReLU(inplace=True)
    return nn.Sequential(upsample, activation)
from torchvision.models import vgg16_bn
class UNet(nn.Module):
    """U-Net segmentation network whose encoder reuses VGG16-BN layers.

    The encoder is ``vgg16_bn(...).features``, i.e. only VGG's convolutional
    part; its avgpool and fully connected classifier are not used. The
    decoder upsamples with transposed convolutions and, at every scale,
    concatenates the encoder feature map of the same resolution (skip
    connection) before fusing the result with a 3x3 conv block.

    Args:
        pretrained: if True, initialize the encoder with torchvision's
            ImageNet weights for VGG16-BN; otherwise it is randomly initialized.
        out_channels: number of channels produced by the final 1x1
            convolution, e.g. the number of segmentation classes.
    """

    def __init__(self, pretrained=True, out_channels=12):
        super().__init__()
        # --- Encoder -------------------------------------------------------
        try:
            # torchvision >= 0.13 deprecates `pretrained` in favour of `weights`;
            # IMAGENET1K_V1 is exactly what `pretrained=True` used to load.
            from torchvision.models import VGG16_BN_Weights
        except ImportError:  # older torchvision: only the `pretrained` flag exists
            vgg = vgg16_bn(pretrained=pretrained)
        else:
            vgg = vgg16_bn(weights=VGG16_BN_Weights.IMAGENET1K_V1 if pretrained else None)
        # vgg16_bn().features has 44 layers: every VGG conv is Conv2d+BatchNorm2d+ReLU
        # and every stage ends with MaxPool2d. The slices below do NOT follow
        # VGG's stage boundaries (blocks 2-5 start or end mid-stage), but for an
        # H x W input they produce these outputs:
        #   block1     : 64 ch,  H     (conv1_1, conv1_2 -- no pooling)
        #   block2     : 128 ch, H/2   (pool1, conv2_1, conv2_2)
        #   block3     : 256 ch, H/4   (pool2, conv3_1, conv3_2)
        #   block4     : 512 ch, H/8   (conv3_3, pool3, conv4_1)
        #   block5     : 512 ch, H/16  (conv4_2, conv4_3, pool4)
        #   bottleneck : 512 ch, H/32  (conv5_1..conv5_3, pool5)
        # The blocks hold the same layer objects as self.encoder (shared, not
        # copied); keeping self.encoder registered preserves the "encoder.*"
        # state_dict keys.
        self.encoder = vgg.features
        self.block1 = nn.Sequential(*self.encoder[:6])
        self.block2 = nn.Sequential(*self.encoder[6:13])
        self.block3 = nn.Sequential(*self.encoder[13:20])
        self.block4 = nn.Sequential(*self.encoder[20:27])
        self.block5 = nn.Sequential(*self.encoder[27:34])

        # --- Bridge between encoder and decoder ------------------------------
        self.bottleneck = nn.Sequential(*self.encoder[34:])
        self.conv_bottleneck = conv(512, 1024)  # widen 512 -> 1024 channels at H/32

        # --- Decoder: recover the input resolution ---------------------------
        # Each up_convN doubles H and W; each convN fuses the concatenation
        # [upsampled map, encoder skip map], hence the "a + b" input channels.
        self.up_conv6 = up_conv(1024, 512)
        self.conv6 = conv(512 + 512, 512)  # skip: block5 (512 ch, H/16)
        self.up_conv7 = up_conv(512, 256)
        self.conv7 = conv(256 + 512, 256)  # skip: block4 (512 ch, H/8)
        self.up_conv8 = up_conv(256, 128)
        self.conv8 = conv(128 + 256, 128)  # skip: block3 (256 ch, H/4)
        self.up_conv9 = up_conv(128, 64)
        self.conv9 = conv(64 + 128, 64)  # skip: block2 (128 ch, H/2)
        self.up_conv10 = up_conv(64, 32)
        self.conv10 = conv(32 + 64, 32)  # skip: block1 (64 ch, H)
        # Output layer: 1x1 conv mapping 32 features to out_channels per pixel.
        self.conv11 = nn.Conv2d(32, out_channels, kernel_size=1)
前向传播部分代码（`UNet` 类的 `forward` 方法）。
def forward(self, x):
    """Run the U-Net on a batch of images.

    Args:
        x: input batch of shape (N, C, H, W); C must be 3 for the VGG16
            encoder's first convolution.

    Returns:
        Tensor of shape (N, out_channels, H, W) with per-pixel scores.

    H and W do not have to be multiples of 32. Max-pooling floors odd sizes,
    so a 2x-upsampled decoder map can be one pixel smaller than its skip
    connection. In that case it is zero-padded (right/bottom) to the skip's
    size before concatenation. For sizes divisible by 32 no padding happens.
    """
    # Encoder (contracting path); every stage is kept for the skip connections.
    block1 = self.block1(x)
    block2 = self.block2(block1)
    block3 = self.block3(block2)
    block4 = self.block4(block3)
    block5 = self.block5(block4)
    bottleneck = self.bottleneck(block5)
    x = self.conv_bottleneck(bottleneck)

    # Decoder (expanding path): upsample, concatenate the encoder map of the
    # same scale along channels (skip connection that keeps high-resolution
    # detail), then fuse with a conv block. Order: deepest skip first.
    stages = (
        (self.up_conv6, block5, self.conv6),
        (self.up_conv7, block4, self.conv7),
        (self.up_conv8, block3, self.conv8),
        (self.up_conv9, block2, self.conv9),
        (self.up_conv10, block1, self.conv10),
    )
    for upsample, skip, fuse in stages:
        x = upsample(x)
        diff_h = skip.shape[-2] - x.shape[-2]
        diff_w = skip.shape[-1] - x.shape[-1]
        if diff_h or diff_w:
            # F.pad order for the last two dims: [left, right, top, bottom].
            x = torch.nn.functional.pad(
                x,
                [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2],
            )
        x = fuse(torch.cat([x, skip], dim=1))

    # Output layer (1x1 convolution).
    return self.conv11(x)