参考:
[1] B站霹导
1 UNet模型架构
整个框架可以分为以下几个模块:
-
DoubleConv :
接收的参数包括(in_c, out_c, mid_c=None);首先需要判断是否传入了mid_c,如果没有,则令mid_c=out_c。一般来说,下采样部分不传mid_c,而上采样部分会传mid_c。
主要包括的层有:
(1)nn.Conv2d(in_c, mid_c,kernel_size=3,padding=1,bias=False)
(2)nn.BatchNorm2d(mid_c)
(3)nn.ReLU()
(4)nn.Conv2d(mid_c,out_c,kernel_size=3,padding=1,bias=False)
(5)nn.BatchNorm2d(out_c)
(6)nn.ReLU()
-
Down :
主要包括一个nn.MaxPool2d(kernel_size=2,stride=2)(高宽缩小一半,深度不变),其后接一个DoubleConv(in_c,out_c)完成通道数的变换;
-
Up :
使用双线性插值代替原论文中的转置卷积,并且DoubleConv
的参数为(in_c,out_c,in_c//2)
,上采样之后先进行concat操作,然后再通过doubleconv,定义和前向过程如下:
简洁版....
def __init__(self,in_c,out_c):
self.up = nn.Upsample(scale_factor=2, mode='bilinear',align_corners=True)
self.conv = DoubleConv(in_c,out_c,in_c//2)
def forward(self,x_1,x_2):
x_1 = self.up(x_1)
x_1 = torch.cat([x_1,x_2],dim=1) # 深度维拼接
x = self.conv(x_1)
return x
- Out :
一个1x1的卷积,输入维为in_c
,输出维为类别数
2 手敲UNet代码
# This Python file uses the following encoding: utf-8
import torch
import torch. nn as nn
class DoubleConv(nn.Sequential):
    """Two consecutive 3x3 convolution -> batch-norm -> ReLU stages.

    Both convolutions use ``padding=1`` so height and width are preserved;
    only the channel count changes (``in_c -> mid_c -> out_c``). The
    convolutions have no bias because each is followed directly by a
    batch-norm layer, which has its own learnable shift.

    Args:
        in_c: Number of channels of the input feature map.
        out_c: Number of channels of the output feature map.
        mid_c: Number of channels between the two convolutions. Defaults
            to ``out_c`` when omitted.
    """

    def __init__(self, in_c, out_c, mid_c=None):
        if mid_c is None:
            mid_c = out_c
        super().__init__(
            nn.Conv2d(in_c, mid_c, kernel_size=3, padding=1, bias=False),
            # torch.nn has no ``BatchNorm`` class; 4-D feature maps need BatchNorm2d.
            nn.BatchNorm2d(mid_c),
            nn.ReLU(),
            nn.Conv2d(mid_c, out_c, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_c),
            nn.ReLU(),
        )
class Down(nn.Sequential):
    """Encoder stage: 2x2 max-pooling followed by a ``DoubleConv``.

    The pooling layer (kernel 2, stride 2) halves height and width and
    leaves the channel count untouched; the ``DoubleConv`` that follows
    maps ``in_c`` channels to ``out_c`` channels.

    Args:
        in_c: Number of channels of the incoming feature map.
        out_c: Number of channels produced by this stage.
    """

    def __init__(self, in_c, out_c):
        super().__init__(
            nn.MaxPool2d(kernel_size=2, stride=2),
            DoubleConv(in_c, out_c),
        )
class Up(nn.Module):
    """Decoder stage: bilinear 2x upsampling, skip concatenation, ``DoubleConv``.

    Bilinear interpolation is used in place of a transposed convolution, so
    upsampling does not change the channel count. ``in_c`` is therefore the
    channel count *after* concatenating the upsampled map with the skip
    connection, and the ``DoubleConv`` reduces it as
    ``in_c -> in_c // 2 -> out_c``.

    Args:
        in_c: Channels of the concatenated tensor fed to the ``DoubleConv``.
        out_c: Channels produced by this stage.
    """

    def __init__(self, in_c, out_c):
        super().__init__()
        # The class is ``nn.Upsample`` and the keyword is ``align_corners``.
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv = DoubleConv(in_c, out_c, in_c // 2)

    def forward(self, x1, x2):
        """Upsample ``x1``, concatenate it with ``x2`` on dim 1, and convolve.

        Args:
            x1: Feature map from the previous (deeper) stage.
            x2: Skip-connection feature map from the encoder.

        Returns:
            The output of the ``DoubleConv`` applied to the concatenation.
        """
        x1 = self.up(x1)

        # Max-pooling floors odd sizes, so the upsampled map can come out
        # smaller than the skip map. Pad it symmetrically to the skip map's
        # height and width so the channel-wise concatenation is well defined.
        diff_h = x2.size(2) - x1.size(2)
        diff_w = x2.size(3) - x1.size(3)
        if diff_h or diff_w:
            x1 = nn.functional.pad(
                x1,
                [diff_w // 2, diff_w - diff_w // 2,
                 diff_h // 2, diff_h - diff_h // 2],
            )

        x1 = torch.cat([x1, x2], dim=1)  # concatenate along the channel dim
        x = self.conv(x1)
        return x
class Out(nn.Module):
    """Output head: a single 1x1 convolution.

    Maps ``in_c`` feature channels to ``num_c`` output channels without
    changing height or width.

    Args:
        in_c: Number of channels of the incoming feature map.
        num_c: Number of output channels (the number of classes).
    """

    def __init__(self, in_c, num_c):
        super().__init__()
        self.out = nn.Conv2d(in_c, num_c, kernel_size=1)

    def forward(self, x):
        """Apply the 1x1 convolution to ``x`` and return the result."""
        return self.out(x)
class UNet(nn.Module):
    """U-Net built from ``DoubleConv``, ``Down``, ``Up`` and ``Out`` stages.

    The encoder halves height and width four times; the decoder doubles
    them four times, concatenating the matching encoder output at every
    step. Inputs whose height and width are multiples of 16 therefore come
    back at their original spatial size.

    Args:
        in_c: Number of channels of the input image.
        num_c: Number of output channels (classes).
        base_c: Channel count after the first ``DoubleConv``; every other
            stage width is a multiple of it.
    """

    def __init__(self, in_c, num_c, base_c=64):
        # nn.Module must be initialised before any submodule is assigned,
        # otherwise attribute assignment raises.
        super().__init__()

        # Channel counts in the comments below assume base_c=64.
        self.in_conv = DoubleConv(in_c, base_c)         # in_c -> 64
        self.down1 = Down(base_c, base_c * 2)           # 64 -> 128
        self.down2 = Down(base_c * 2, base_c * 4)       # 128 -> 256
        self.down3 = Down(base_c * 4, base_c * 8)       # 256 -> 512
        # Stays at 512: bilinear upsampling does not halve the channels, so
        # the bottleneck must match the skip connection it is concatenated with.
        self.down4 = Down(base_c * 8, base_c * 8)       # 512 -> 512

        # The first argument of Up is the channel count after concatenation.
        self.up1 = Up(base_c * 16, base_c * 4)          # 1024 -> 512 -> 256
        self.up2 = Up(base_c * 8, base_c * 2)           # 512 -> 256 -> 128
        self.up3 = Up(base_c * 4, base_c * 1)           # 256 -> 128 -> 64
        self.up4 = Up(base_c * 2, base_c)               # 128 -> 64 -> 64
        self.out = Out(base_c, num_c)

    def forward(self, x):
        """Run the encoder, then the decoder with skip connections.

        Args:
            x: Input batch with ``in_c`` channels on dim 1.

        Returns:
            The output of the ``Out`` head, with ``num_c`` channels on dim 1.
        """
        # Encoder: keep every intermediate map for the skip connections.
        x1 = self.in_conv(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)

        # Decoder: each stage consumes the previous output plus one skip map.
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        x = self.out(x)
        return x
# Smoke test: build a 3-channel-in / 5-class-out network, push one random
# batch of sixteen 480x480 images through it and print the output shape.
model = UNet(3, 5, base_c=64)
x = torch.rand((16, 3, 480, 480))
out = model(x)
print(out.shape)