UNet论文
UNet的简介
- UNet是一个
对称
的网络结构,左侧为下采样,右侧为上采样
; - 下采样为encoder,上采样为decoder;
- 四条灰色的平行线,就是在上采样的过程中,融合下采样过程的特征图的通道,Concat
原理
就是:一本大小为10cm10cm的书,厚度为3cm的书本(10103)的A书,和一本大小为10cm10cm,厚度为4cm的B书(10103)- 将A书和B书,边缘对齐的摞在一起,这样就可以得到一个大小10107的一摞书了
- 所以对feature map,一个大小为
256*256*64
的feature map(w为256,h为256,c为64),和一个大小为256*256*32
的feature map进行Concat融合,你就会得到一个大小为256*256*96
的feature map - 在实际使用中,Concat融合的两个feature map的大小不一定相同,例如25625664的feature map和24024032的feature map进行Concat
- 两种方法
- 1.将大的
256*256*64
的feature map进行裁剪,裁剪为240*240*64
的feature map,比如上下左右,各舍弃8 pixel,裁剪后再进行Concat
,得到24024096的feature map。 - 2.将小的
240*240*32
的feature map进行padding操作
,padding为256*256*32
的feature map,比如上下左右,各补8 pixel,padding后再进行Concat
,得到25625696的feature map。
- 1.将大的
- UNet采用的Concat方案就是
第二种
,将小的feature map进行padding,padding的方式是补0
,一种常规
的常量填充。(详细看代码Up)
- 两种方法
代码解读
-
组成U-Net的模型块主要有如下几个部分:
-
1)每个子块内部的两次卷积(Double Convolution)
-
2)左侧模型块之间的下采样连接,即最大池化(Max pooling)
-
3)右侧模型块之间的上采样连接(Up sampling)
-
4)输出层的处理(OutConv)
-
DoubleConv模块
- 两次卷积操作:
class DoubleConv(nn.Module):
# mid_channel是第一次conv的out和第二次conv的输入
def __init__(self, in_channels, out_channels, mid_channels=None):
super().__init__()
if not mid_channels:
mid_channels = out_channels
self.double_conv = nn.Sequential(
# 大小 (高宽 + 2*padding - kernel_size)/stride + 1
# (572 + 2*0 - 3 )/1 +1 = 570
# 通道1->64
nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=0, bias=False),
# 帮助网络训练, 对输入数据做规范化,称为Covariate shift
# BatchNorm后是不改变输入的shape
# num_features: 输入维度,也就是数据的特征维度;
# eps: 是在分母上加的一个值,是为了防止分母为0的情况,让其能正常计算;
# affine: 是仿射变化,将,分别初始化为1和0;
# nn.BatchNorm2d是对channel做归一化处理,也就是对批次内的特征进行归一化
# 加快收敛,防止梯度爆炸和消失
nn.BatchNorm2d(mid_channels),
# inplace = True 时,会修改输入对象的值,所以打印出对象存储地址相同,类似于C语言的址传递
# inplace = False 时,不会修改输入对象的值,而是返回一个新创建的对象,所以打印出对象存储地址不同,类似于C语言的值传递
nn.ReLU(inplace=True),
# (570 + 2*0 - 3 )/1 +1 = 568
# 设置padding=1(原始的是设置为0,会改变)经过卷积后不会改变特征层的大小,这也是现在主流的实现方式
nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=0, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
def forward(self, x):
x = self.double_conv(x)
return x
- 注意代码里面的卷积是如何计算的
( (高宽 + 2*padding - kernel_size)/stride + 1),如U图中的初始:(572 + 2*0 - 3 )/1 +1 = 570
Down模块
class Down(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.maxpool_conv = nn.Sequential(
nn.MaxPool2d(2),
# 扩大通道64->128
DoubleConv(in_channels, out_channels)
)
def forward(self, x):
return self.maxpool_conv(x)
Up模块
class Up(nn.Module):
def __init__(self, in_channels, out_channels, bilinear=True):
super().__init__()
# 上采样需要插值,双
if bilinear:
# 两倍
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
else:
# 反卷积
self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
self.conv = DoubleConv(in_channels, out_channels)
def forward(self, x1, x2):
x1 = self.up(x1)
# input is CHW
# 算出相差多少
diffY = x2.size()[2] - x1.size()[2]
diffX = x2.size()[3] - x1.size()[3]
# F.pad是pytorch内置的tensor扩充函数,便于对数据集图像或中间层特征进行维度扩充
# 最后一维padding,第一个元素代表左边padding的个数,第二个元素代表右边padding的个数
# input:需要扩充的tensor,可以是图像数据,抑或是特征矩阵数据
# pad:扩充维度,用于预先定义出某维度上的扩充参数
# mode:扩充方法,’constant‘, ‘reflect’ or ‘replicate’三种模式,分别表示常量,反射,复制
# value:扩充时指定补充值,但是value只在mode='constant’有效,
# 即使用value填充在扩充出的新维度位置,而在’reflect’和’replicate’模式下,value不可赋值
# https://www.modb.pro/db/227153
x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])
x = torch.cat([x2, x1], dim=1)
return self.conv(x)
__init__
初始化函数里定义的上采样方法
以及卷积采用DoubleConv
。-
上采样,定义了两种方法:
Upsample
和ConvTranspose2d
,也就是双线性插值
和反卷积
。-
双线性插值:
-
简单地讲:已知Q11、Q12、Q21、Q22四个点坐标,通过Q11和Q21求R1,再通过Q12和Q22求R2,最后通过R1和R2求P,这个过程就是双线性插值。
-
对于一个feature map而言,其实就是在像素点中间补点,补的点的值是多少,是由相邻像素点的值决定的。
-
-
反卷积:
-
就是反着卷积
-
下面的蓝色为原始图片,周围白色的虚线方块为padding结果,通常为0,上面绿色为卷积后的图片。
-
这个示意图,就是一个从
2*2的feature map->4*4的feature map过程
。 -
在
forward前向传播函数中
,x1
接收的是上采样的数据
,x2
接收的是特征融合的数据
。特征融合方法就是,上文提到的,先对小的feature map进行padding,再进行concat
。
-
-
OutConv模块
class OutConv(nn.Module):
def __init__(self, in_channels, out_channels):
super(OutConv, self).__init__() # 和super().__init__()一样
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
def forward(self, x):
return self.conv(x)
整个UNet
class UNet(nn.Module):
def __init__(self, n_channels, n_classes, bilinear=True):
super(UNet, self).__init__()
self.n_channels = n_channels
self.n_classes = n_classes
self.bilinear = bilinear
# 最左边
self.inc = DoubleConv(n_channels, 64)
# 下采样(里面包括下采样完的两次卷积)
self.down1 = Down(64, 128)
self.down2 = Down(128, 256)
self.down3 = Down(256, 512)
# 插值方式
factor = 2 if bilinear else 1
self.down4 = Down(512, 1024 // factor)
self.up1 = Up(1024, 512 // factor, bilinear)
self.up2 = Up(512, 512 // factor, bilinear)
self.up3 = Up(256, 256 // factor, bilinear)
self.up4 = Up(128, 64, bilinear)
self.outc = OutConv(64, n_classes)
def forward(self, x):
x1 = self.inc(x)
x2 = self.down1(x1)
x3 = self.down2(x2)
x4 = self.down3(x3)
x5 = self.down4(x4)
x = self.up1(x5, x4)
x = self.up2(x, x3)
x = self.up3(x, x2)
x = self.up4(x, x1)
logits =self.outc(x)
return logits