目录
参考资料
论文:
U-Net: Convolutional Networks for Biomedical Image Segmentation
博客:
第1章 前言
深度网络通常需要大量的数据进行训练,当样本量较少的情况下,深度网络可能表现没那么好。对于这个问题,本文提出了新的网络架构和图像增强策略。网络架构包括encoder和decoder(论文中称为浓缩路径(contracting path)和扩展路径(expanding path),实际上就是encoder和decoder),encoder可以有效地捕捉上下文信息,而decoder可以较好地预测位置信息。
- Encoder:使得模型理解了图像的内容,但是丢弃了图像的位置信息。
- Decoder:使模型结合Encoder对图像内容的理解,恢复图像的位置信息。
网络的浓缩路径,图像分辨率逐渐降低,上下文信息会逐渐增强。在扩展路径中,通过上采样的方式,让特征图的分辨率逐渐增大。同时,为了结合低层feature map的强位置信息,将浓缩路径中的相应部分结合到扩展路径中。这种架构可以较好地进行位置定位。
U-Net做的修改有:
- 在上采样部分,Feature Map的通道数非常大,作者认为这样可以将上下文信息传递到分辨率更高的层当中。这样做的一个结果就是它基本上和浓缩路径对称了,因此看上去像一个U形的结构。
- 为了预测图像边界区域的像素点,采用
overlap-tile
策略补全缺失的context。 - 由于训练数据太少,采用大量
弹性形变
的方式增强数据。这可以让模型更好学习形变不变性。这种增强方式对于医学图像来说很重要。 - 在细胞分割任务中的另一个挑战是,如何将同类别的相互接触的目标分开。本文提出了使用一种
带权重的损失(weighted loss)
。在损失函数中,分割相互接触的细胞像素获得了更大的权重。
第2章 U-Net网络结构
U-Net 跟 FCN 都是 Encoder-Decoder
结构,结构简单但很有效。Encoder 负责特征提取,你可以将自己熟悉的各种特征提取网络放在这个位置。由于在医学方面,样本收集较为困难,作者为了解决这个问题,应用了图像增强的方法,在数据集有限的情况下获得了不错的精度。
如上图,U-Net 网络结构是对称的,形似英文字母 U 所以被称为 U-Net。整张图都是由蓝/白色框与各种颜色的箭头组成,其中:
- 蓝/白色框表示
feature map
; - 蓝色箭头表示
3x3
卷积,用于特征提取; - 灰色箭头表示
skip-connection
,用于特征融合; - 红色箭头表示池化
pooling
,用于降低维度; - 绿色箭头表示上采样
upsample
,用于恢复维度; - 青色箭头表示
1x1 卷积
,用于输出结果。 - 其中灰色箭头
copy and crop
中的copy
就是维度上的concatenate
,而crop
是裁剪为了让两者的长宽一致;
可能你会问为啥是 5 层而不是 4 层或者 6 层,emmm,这应该去问作者本人,可能对于当时作者拿到的数据集来说,这个层数的表现更好,但不代表所有的数据集这个结构都适合。我们该多关注这种 Encoder-Decoder 的设计思想,具体实现则应该因数据集而异。
2.1 Encoder
Encoder 由卷积操作和下采样操作组成,文中所用的卷积结构统一为 3x3 的卷积核,padding 为 0 ,striding 为 1。没有 padding 所以每次卷积之后 feature map 的 H 和 W 变小了,在 skip-connection 时要注意 feature map 的维度(其实也可以将 padding 设置为 1 避免维度不对应问题),pytorch 代码:
nn.Sequential(nn.Conv2d(in_channels, out_channels, 3),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True))
上述的两次卷积之后是一个 stride 为 2 的 max pooling,输出大小变为 1 2 × ( H , W ) \frac{1}{2} \times(H, W) 21×(H,W):
pytorch 代码:
nn.MaxPool2d(kernel_size=2, stride=2)
上面的步骤重复 5 次,最后一次没有 max-pooling,直接将得到的 feature map 送入 Decoder。
2.2 Decoder
feature map 经过 Decoder 恢复原始分辨率,该过程除了卷积之外比较关键的步骤就是 upsampling
与 skip-connection
。
(1)Upsampling
Upsampling 上采样常用的方式有两种:1.FCN 中介绍的反卷积;2. 插值。
这里介绍文中使用的插值方式。在插值实现方式中,bilinear(双线性插值)
的综合表现较好也较为常见 。
双线性插值的计算过程没有需要学习的参数,实际就是套公式,这里举个例子方便大家理解(例子介绍的是参数 align_corners
为 Fasle
的情况)。
pytorch 里使用 bilinear 插值:
nn.Upsample(scale_factor=2, mode='bilinear')
(2)Skip-Connection
CNN 网络要想获得好效果,skip-connection 基本必不可少。U-Net 中这一关键步骤融合了底层信息的位置信息与深层特征的语义信息,pytorch 代码:
torch.cat([low_layer_features, deep_layer_features], dim=1)
这里需要注意,FCN 中深层信息与浅层信息融合是通过对应像素相加的方式,而 U-Net 是通过拼接的方式。
那么这两者有什么区别呢,其实 在 ResNet 与 DenseNet 中也有一样的区别,Resnet 使用了对应值相加,DenseNet 使用了拼接。
个人理解在相加的方式下,feature map 的维度没有变化,但每个维度都包含了更多特征,对于普通的分类任务这种不需要从 feature map 复原到原始分辨率的任务来说,这是一个高效的选择;而拼接则保留了更多的维度/位置 信息,这使得后面的 layer 可以在浅层特征与深层特征自由选择,这对语义分割任务来说更有优势。
2.3 损失函数
(1)损失函数计算
ISBI数据集的一个非常严峻的挑战是紧密相邻的物体之间的分割问题。如图3所示,(a)是输入数据,(b)是Ground Truth,©是基于Ground Truth生成的分割掩码,(d)是U-Net使用的用于分离边界的损失权值。
网络输出的是pixel-wise的softmax。表达式如下:
其中, x x x 为二维平面 Ω Ω Ω 上的像素位置, a k ( x ) a_k(x) ak(x) 表示网络最后输出层中pixel x x x 对应的第 k k k 个通道的值, K K K 是类别总数。 p k ( x ) p_k(x) pk(x) 表示像素 x x x 属于 k k k 类的概率。
损失函数使用 negative cross entropy
。cross entropy的数学表达式如下:
其中 p l ( x ) p_l(x) pl(x) 表示 x x x 在真实label所在通道上的输出概率。特别注意的是cross entropy中还添加一个权重项 w ( x ) w(x) w(x) ,这是因为考虑到物体间的边界需要更多的关注,所对应的损失权重需要更大。
(2)像素损失权重计算
我们得到一张图片的ground truth是一个二值的mask,本文首先采用形态学方法去计算出物体的边界。然后通过以下的表达式去计算权重图。
其中 w c ( x ) w_c(x) wc(x) 是类别权重,需要根据训练数据集中的各类别出现的频率来进行统计,类别出现的频率越高,应该给的权重越低,频率越低则给的权重越高(文章没有详细说是怎么计算的)。
d 1 ( x ) d_1(x) d1(x) 表示物体像素到最近cell的边界的距离, d 2 ( x ) d_2(x) d2(x) 表示物体像素到第二近的cell的边界的距离。在本文中,设置 w 0 = 10 , σ = 5 w_0=10,σ=5 w0=10,σ=5 。
2.4 数据扩充
由于训练集只有30张训练样本,作者使用了数据扩充的方法增加了样本数量。并且作者指出任意的弹性形变对训练非常有帮助。
第3章 总结
U-Net是比较早的使用多尺度特征进行语义分割任务的算法之一,基于 Encoder-Decoder 结构,通过拼接的方式实现特征融合,结构简明且稳定。其U形结构也启发了后面很多算法。但其也有几个缺点:
- 有效卷积增加了模型设计的难度和普适性;目前很多算法直接采用了same卷积,这样也可以免去Feature Map合并之前的裁边操作
- 其通过裁边的形式和Feature Map并不是对称的,个人感觉采用双线性插值的效果应该会更好。
第4章 Pytorch实现U-Net
参考:
class U_Net(nn.Module):
def __init__(self):
super().__init__()
# 首先定义左半部分网络
# left_conv_1 表示连续的两个(卷积+激活)
# 随后进行最大池化
self.left_conv_1 = ConvBlock(in_channels=3, middle_channels=64, out_channels=64)
self.pool_1 = nn.MaxPool2d(kernel_size=2, stride=2)
self.left_conv_2 = ConvBlock(in_channels=64, middle_channels=128, out_channels=128)
self.pool_2 = nn.MaxPool2d(kernel_size=2, stride=2)
self.left_conv_3 = ConvBlock(in_channels=128, middle_channels=256, out_channels=256)
self.pool_3 = nn.MaxPool2d(kernel_size=2, stride=2)
self.left_conv_4 = ConvBlock(in_channels=256, middle_channels=512, out_channels=512)
self.pool_4 = nn.MaxPool2d(kernel_size=2, stride=2)
self.left_conv_5 = ConvBlock(in_channels=512, middle_channels=1024, out_channels=1024)
# 定义右半部分网络
self.deconv_1 = nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=3, stride=2, padding=1, output_padding=1)
self.right_conv_1 = ConvBlock(in_channels=1024, middle_channels=512, out_channels=512)
self.deconv_2 = nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=3, padding=1, stride=2, output_padding=1)
self.right_conv_2 = ConvBlock(in_channels=512, middle_channels=256, out_channels=256)
self.deconv_3 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=3, padding=1, stride=2 ,output_padding=1)
self.right_conv_3 = ConvBlock(in_channels=256, middle_channels=128, out_channels=128)
self.deconv_4 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=3, stride=2, output_padding=1, padding=1)
self.right_conv_4 = ConvBlock(in_channels=128, middle_channels=64, out_channels=64)
# 最后是1x1的卷积,用于将通道数化为3
self.right_conv_5 = nn.Conv2d(in_channels=64, out_channels=3, kernel_size=1, stride=1, padding=0)
def forward(self, x):
# 1:进行编码过程
feature_1 = self.left_conv_1(x)
feature_1_pool = self.pool_1(feature_1)
feature_2 = self.left_conv_2(feature_1_pool)
feature_2_pool = self.pool_2(feature_2)
feature_3 = self.left_conv_3(feature_2_pool)
feature_3_pool = self.pool_3(feature_3)
feature_4 = self.left_conv_4(feature_3_pool)
feature_4_pool = self.pool_4(feature_4)
feature_5 = self.left_conv_5(feature_4_pool)
# 2:进行解码过程
de_feature_1 = self.deconv_1(feature_5)
# 特征拼接
temp = torch.cat((feature_4, de_feature_1), dim=1)
de_feature_1_conv = self.right_conv_1(temp)
de_feature_2 = self.deconv_2(de_feature_1_conv)
temp = torch.cat((feature_3, de_feature_2), dim=1)
de_feature_2_conv = self.right_conv_2(temp)
de_feature_3 = self.deconv_3(de_feature_2_conv)
temp = torch.cat((feature_2, de_feature_3), dim=1)
de_feature_3_conv = self.right_conv_3(temp)
de_feature_4 = self.deconv_4(de_feature_3_conv)
temp = torch.cat((feature_1, de_feature_4), dim=1)
de_feature_4_conv = self.right_conv_4(temp)
out = self.right_conv_5(de_feature_4_conv)
return out
测试网络输入和输出的尺寸是否一致:
if __name__ == "__main__":
x = torch.rand(size=(8, 3, 224, 224))
net = U_Net()
out = net(x)
print(out.size())
print("ok")