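To prepare for autumn campus recruiting, I have recently been reviewing the deep learning algorithms I studied earlier. I had already gone over the basics and network architectures of several other algorithms without thinking to write any of it down. Today it occurred to me that a blog post would be a good way to keep these notes: I can use it for my own review later, and it may also help readers who are new to this algorithm.
This post mainly presents a U-Net implementation that I wrote by combining my own code with other people's. For a detailed explanation, see the reference blogs listed at the end.
The summary here is also based on notes shared online by others. If you spot a problem, please point it out so we can learn from each other!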
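1. U-Net Basics
The whole network extracts features using convolutions only. It has two parts: the left half (contracting path) and the right half (expansive path). The contracting path builds the deep part of the network and extracts high-level semantic information; it is also called the encoder of U-Net. It consists of four blocks, each made of two convolutions followed by one max pooling. Max pooling is the downsampling operation: each time it is applied, the width and height of the feature map are halved. The expansive path also has four blocks, each made of one upsampling step and two convolutions. Because downsampling discards some fine detail, each upsampling step uses a copy-and-concatenate operation that crops the shallower encoder feature map and concatenates it with the upsampled one, which preserves more information, in particular the spatial detail lost to downsampling. (The two feature maps must have the same spatial size before they can be concatenated, so the encoder feature map has to be cropped first.)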
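To make the size bookkeeping concrete, here is a small pure-Python sketch that traces the feature map width through the two paths. It assumes square inputs, unpadded 3x3 convolutions, 2x2 max pooling and 2x upsampling, as in the code later in this post; the function name `unet_sizes` is my own and is not part of any library.

```python
def unet_sizes(size, depth=4, kernel=3, padding=0):
    """Trace the spatial size (width == height) through a U-Net.

    Returns (skip_sizes, bottom_size, upsampled_sizes, output_size).
    """
    # Pixels lost by two stacked convolutions with the given kernel/padding.
    shrink = 2 * (kernel - 1 - 2 * padding)

    skips = []
    for _ in range(depth):
        size -= shrink        # two convolutions
        skips.append(size)    # this feature map feeds the skip connection
        size //= 2            # 2x2 max pooling with stride 2

    size -= shrink            # two convolutions at the bottom of the U
    bottom = size

    ups = []
    for _ in range(depth):
        size *= 2             # upsampling doubles width and height
        ups.append(size)      # size at which the concatenation happens
        size -= shrink        # two convolutions after the concatenation
    return skips, bottom, ups, size


skips, bottom, ups, out = unet_sizes(572)
print(skips)   # [568, 280, 136, 64]
print(bottom)  # 28
print(ups)     # [56, 104, 200, 392]
print(out)     # 388
```

Reading the result from the deepest level upward, the 64-wide encoder map has to be cropped to 56, the 136-wide one to 104, and so on, which is why the cropping step is needed before each concatenation.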
2. Main Code
import torch
import torch.nn as nn
from torchvision import transforms


class Double_convs(nn.Module):
    """Two unpadded 3x3 convolutions, each followed by BatchNorm and ReLU."""

    def __init__(self, in_channels, out_channels) -> None:
        super(Double_convs, self).__init__()
        # padding=0 means every convolution shrinks width and height by 2.
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=0),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=0),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        x = self.block(x)
        return x


class Block_DownSample(nn.Module):
    """Contracting path (encoder): five double-conv blocks with max pooling in between."""

    def __init__(self) -> None:
        super(Block_DownSample, self).__init__()
        self.block_1 = Double_convs(1, 64)
        self.block_2 = Double_convs(64, 128)
        self.block_3 = Double_convs(128, 256)
        self.block_4 = Double_convs(256, 512)
        self.block_5 = Double_convs(512, 1024)
        self.maxpooling = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)

    def forward(self, x):
        # Keep the output of every block; out_1..out_4 are the skip connections.
        out_1 = self.block_1(x)
        x = self.maxpooling(out_1)
        out_2 = self.block_2(x)
        x = self.maxpooling(out_2)
        out_3 = self.block_3(x)
        x = self.maxpooling(out_3)
        out_4 = self.block_4(x)
        x = self.maxpooling(out_4)
        out_5 = self.block_5(x)
        return out_1, out_2, out_3, out_4, out_5


class Block_UpSample(nn.Module):
    """One expansive-path block: upsample, crop and concatenate the skip, then double conv."""

    def __init__(self, in_channels, out_channels) -> None:
        super(Block_UpSample, self).__init__()
        # Bilinear upsampling doubles the size; the 1x1 convolution then
        # reduces the channels from in_channels to out_channels.
        self.upsample_1 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0),
            nn.ReLU(inplace=True),
        )
        self.block_1 = Double_convs(in_channels, out_channels)

    def forward(self, x1, x2, wh):
        # x1: feature map from the deeper level, x2: encoder skip connection,
        # wh: width/height that x2 is cropped to (must equal the size of x1 after upsampling).
        x1 = self.upsample_1(x1)
        centercrop = transforms.CenterCrop((wh, wh))  # I use a center crop here; I am not sure whether this is correct
        x2 = centercrop(x2)
        x = torch.cat([x1, x2], dim=1)  # concatenate along the channel dimension
        x = self.block_1(x)
        return x


class U_net(nn.Module):
    def __init__(self) -> None:
        super(U_net, self).__init__()
        self.down = Block_DownSample()
        self.up_1 = Block_UpSample(1024, 512)
        self.up_2 = Block_UpSample(512, 256)
        self.up_3 = Block_UpSample(256, 128)
        self.up_4 = Block_UpSample(128, 64)
        self.conv = nn.Sequential(nn.Conv2d(64, 2, kernel_size=1))  # 1x1 conv to 2 output classes

    def forward(self, x):
        out_1, out_2, out_3, out_4, out_5 = self.down(x)
        # The crop sizes below are hard-coded for a 572x572 input.
        x = self.up_1(out_5, out_4, 56)
        x = self.up_2(x, out_3, 104)
        x = self.up_3(x, out_2, 200)
        x = self.up_4(x, out_1, 392)
        x = self.conv(x)
        return x
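The crop sizes 56, 104, 200 and 392 passed to the upsampling blocks only fit a 572x572 input. As a rough sketch of how the crop could be derived instead of hard-coded, the helper below computes the slice bounds of a center crop from the two sizes. The name `center_crop_bounds` is hypothetical, and the integer division is my own choice; as far as I know it agrees with `transforms.CenterCrop` whenever the size difference is even, which is the case for all four stages here.

```python
def center_crop_bounds(src, dst):
    """Return (start, stop) so that [start:stop] is the centered dst-wide
    window of a src-wide axis."""
    if dst > src:
        raise ValueError("crop size must not exceed the source size")
    start = (src - dst) // 2
    return start, start + dst


# Encoder skip size -> size of the upsampled decoder map, for a 572x572 input.
for src, dst in [(64, 56), (136, 104), (280, 200), (568, 392)]:
    print(src, dst, center_crop_bounds(src, dst))
# 64 56 (4, 60)
# 136 104 (16, 120)
# 280 200 (40, 240)
# 568 392 (88, 480)

# Inside a forward pass, with x1 already upsampled and both tensors shaped
# (N, C, H, W), the same bounds could be applied by plain slicing, e.g.:
#   start, stop = center_crop_bounds(x2.shape[-1], x1.shape[-1])
#   x2 = x2[:, :, start:stop, start:stop]
```

Reading the target size from `x1.shape[-1]` would make the `wh` argument unnecessary, at the cost of assuming square feature maps as the code above already does. Cropping around the center seems a reasonable choice, since the unpadded convolutions remove the same number of pixels from every border.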
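Other blog posts mention that setting padding to 1 in the encoder's convolutions (so that each 3x3 convolution keeps the width and height unchanged) makes it possible to concatenate directly, without any cropping. Interested readers can give it a try.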
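A quick way to check this idea without building the network is to trace the sizes by hand. The sketch below assumes that every 3x3 convolution in both paths uses padding=1 (in the code above, `Double_convs` is shared by the encoder and the decoder, so changing it there affects both), that pooling floors odd sizes as `nn.MaxPool2d` does by default, and that upsampling exactly doubles the size. The function names are my own.

```python
def concat_pairs(size, depth=4):
    """Return (upsampled_size, skip_size) for each decoder stage,
    assuming size-preserving convolutions (3x3 kernel, padding=1)."""
    skips = []
    for _ in range(depth):
        skips.append(size)   # convolutions leave the size unchanged
        size //= 2           # 2x2 max pooling, odd sizes are floored

    pairs = []
    for skip in reversed(skips):
        size *= 2            # upsampling with scale_factor=2
        pairs.append((size, skip))
    return pairs


def no_crop_needed(size, depth=4):
    """True if every upsampled map already matches its skip connection."""
    return all(up == skip for up, skip in concat_pairs(size, depth))


print(concat_pairs(512))    # [(64, 64), (128, 128), (256, 256), (512, 512)]
print(no_crop_needed(512))  # True
print(concat_pairs(572)[0]) # (70, 71)
print(no_crop_needed(572))  # False
```

So the trick appears to work only when the input width and height are divisible by 16 (2 to the power of the four pooling steps). A 572x572 input would still hit a size mismatch at the first concatenation, because 71 is floored to 35 by the last pooling and doubling gives 70.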
The output shape and parameter count of every layer, for a single-channel 572x572 input, are shown below:
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 570, 570]             640
       BatchNorm2d-2         [-1, 64, 570, 570]             128
              ReLU-3         [-1, 64, 570, 570]               0
            Conv2d-4         [-1, 64, 568, 568]          36,928
       BatchNorm2d-5         [-1, 64, 568, 568]             128
              ReLU-6         [-1, 64, 568, 568]               0
      Double_convs-7         [-1, 64, 568, 568]               0
         MaxPool2d-8         [-1, 64, 284, 284]               0
            Conv2d-9        [-1, 128, 282, 282]          73,856
      BatchNorm2d-10        [-1, 128, 282, 282]             256
             ReLU-11        [-1, 128, 282, 282]               0
           Conv2d-12        [-1, 128, 280, 280]         147,584
      BatchNorm2d-13        [-1, 128, 280, 280]             256
             ReLU-14        [-1, 128, 280, 280]               0
     Double_convs-15        [-1, 128, 280, 280]               0
        MaxPool2d-16        [-1, 128, 140, 140]               0
           Conv2d-17        [-1, 256, 138, 138]         295,168
      BatchNorm2d-18        [-1, 256, 138, 138]             512
             ReLU-19        [-1, 256, 138, 138]               0
           Conv2d-20        [-1, 256, 136, 136]         590,080
      BatchNorm2d-21        [-1, 256, 136, 136]             512
             ReLU-22        [-1, 256, 136, 136]               0
     Double_convs-23        [-1, 256, 136, 136]               0
        MaxPool2d-24          [-1, 256, 68, 68]               0
           Conv2d-25          [-1, 512, 66, 66]       1,180,160
      BatchNorm2d-26          [-1, 512, 66, 66]           1,024
             ReLU-27          [-1, 512, 66, 66]               0
           Conv2d-28          [-1, 512, 64, 64]       2,359,808
      BatchNorm2d-29          [-1, 512, 64, 64]           1,024
             ReLU-30          [-1, 512, 64, 64]               0
     Double_convs-31          [-1, 512, 64, 64]               0
        MaxPool2d-32          [-1, 512, 32, 32]               0
           Conv2d-33         [-1, 1024, 30, 30]       4,719,616
      BatchNorm2d-34         [-1, 1024, 30, 30]           2,048
             ReLU-35         [-1, 1024, 30, 30]               0
           Conv2d-36         [-1, 1024, 28, 28]       9,438,208
      BatchNorm2d-37         [-1, 1024, 28, 28]           2,048
             ReLU-38         [-1, 1024, 28, 28]               0
     Double_convs-39         [-1, 1024, 28, 28]               0
 Block_DownSample-40  [[-1, 64, 568, 568], [-1, 128, 280, 280], [-1, 256, 136, 136], [-1, 512, 64, 64], [-1, 1024, 28, 28]]               0
         Upsample-41         [-1, 1024, 56, 56]               0
           Conv2d-42          [-1, 512, 56, 56]         524,800
             ReLU-43          [-1, 512, 56, 56]               0
           Conv2d-44          [-1, 512, 54, 54]       4,719,104
      BatchNorm2d-45          [-1, 512, 54, 54]           1,024
             ReLU-46          [-1, 512, 54, 54]               0
           Conv2d-47          [-1, 512, 52, 52]       2,359,808
      BatchNorm2d-48          [-1, 512, 52, 52]           1,024
             ReLU-49          [-1, 512, 52, 52]               0
     Double_convs-50          [-1, 512, 52, 52]               0
   Block_UpSample-51          [-1, 512, 52, 52]               0
         Upsample-52        [-1, 512, 104, 104]               0
           Conv2d-53        [-1, 256, 104, 104]         131,328
             ReLU-54        [-1, 256, 104, 104]               0
           Conv2d-55        [-1, 256, 102, 102]       1,179,904
      BatchNorm2d-56        [-1, 256, 102, 102]             512
             ReLU-57        [-1, 256, 102, 102]               0
           Conv2d-58        [-1, 256, 100, 100]         590,080
      BatchNorm2d-59        [-1, 256, 100, 100]             512
             ReLU-60        [-1, 256, 100, 100]               0
     Double_convs-61        [-1, 256, 100, 100]               0
   Block_UpSample-62        [-1, 256, 100, 100]               0
         Upsample-63        [-1, 256, 200, 200]               0
           Conv2d-64        [-1, 128, 200, 200]          32,896
             ReLU-65        [-1, 128, 200, 200]               0
           Conv2d-66        [-1, 128, 198, 198]         295,040
      BatchNorm2d-67        [-1, 128, 198, 198]             256
             ReLU-68        [-1, 128, 198, 198]               0
           Conv2d-69        [-1, 128, 196, 196]         147,584
      BatchNorm2d-70        [-1, 128, 196, 196]             256
             ReLU-71        [-1, 128, 196, 196]               0
     Double_convs-72        [-1, 128, 196, 196]               0
   Block_UpSample-73        [-1, 128, 196, 196]               0
         Upsample-74        [-1, 128, 392, 392]               0
           Conv2d-75         [-1, 64, 392, 392]           8,256
             ReLU-76         [-1, 64, 392, 392]               0
           Conv2d-77         [-1, 64, 390, 390]          73,792
      BatchNorm2d-78         [-1, 64, 390, 390]             128
             ReLU-79         [-1, 64, 390, 390]               0
           Conv2d-80         [-1, 64, 388, 388]          36,928
      BatchNorm2d-81         [-1, 64, 388, 388]             128
             ReLU-82         [-1, 64, 388, 388]               0
     Double_convs-83         [-1, 64, 388, 388]               0
   Block_UpSample-84         [-1, 64, 388, 388]               0
           Conv2d-85          [-1, 2, 388, 388]             130
================================================================
Total params: 28,953,474
Trainable params: 28,953,474
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 1.25
Forward/backward pass size (MB): 3864.11
Params size (MB): 110.45
Estimated Total Size (MB): 3975.81
----------------------------------------------------------------
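As a sanity check on the summary, the parameter counts can be reproduced by hand. The sketch below uses the usual formulas (a convolution with bias has `c_in * c_out * k * k + c_out` parameters, and a BatchNorm layer has a learnable scale and shift per channel) together with the channel configuration from the code above. The helper names are my own, and the size estimate assumes 4-byte float32 parameters.

```python
def conv_params(c_in, c_out, k):
    """Weights plus one bias per output channel."""
    return c_in * c_out * k * k + c_out


def bn_params(channels):
    """Learnable scale (weight) and shift (bias) per channel."""
    return 2 * channels


def double_conv_params(c_in, c_out):
    """Two 3x3 convolutions, each followed by BatchNorm."""
    return (conv_params(c_in, c_out, 3) + bn_params(c_out)
            + conv_params(c_out, c_out, 3) + bn_params(c_out))


# Encoder: the five Double_convs blocks.
encoder = sum(double_conv_params(i, o)
              for i, o in [(1, 64), (64, 128), (128, 256), (256, 512), (512, 1024)])

# Decoder: each stage has a 1x1 convolution after upsampling plus a Double_convs.
decoder = sum(conv_params(i, o, 1) + double_conv_params(i, o)
              for i, o in [(1024, 512), (512, 256), (256, 128), (128, 64)])

head = conv_params(64, 2, 1)   # final 1x1 convolution to 2 classes

total = encoder + decoder + head
size_mb = total * 4 / 1024 ** 2   # 4 bytes per float32 parameter

print(total)              # 28953474
print(round(size_mb, 2))  # 110.45
```

Both numbers match the "Total params" and "Params size (MB)" lines of the summary, and more than half of the parameters sit in the two deepest blocks (512 and 1024 channels).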
If you have any questions, feel free to raise them and we can learn together!
Reference blogs: