图像语义分割网络系列博文索引
FCN与SegNet | U-Net与V-Net | DeepLab系列 | DenseNet、PSPNet、与DenseASPP | Mask R-CNN |
---|
(后四个在计划中,敬请期待)
图像语义网络分割之U-Net与V-Net
U-Net
U-Net是医学图像处理领域最常用的一种网络结构,很多医学图像处理的网络结构都由U-Net改进而来。U-Net可以被看作是基于FCN和SegNet的一种改进方法,采用了FCN的全卷积、反卷积上采样、越级连接的方法,采用了SegNet的Encoder-Decoder结构。原文链接:U-Net: Convolutional Networks for Biomedical Image Segmentation
为了应对小样本学习问题,U-Net提出的网络结构和训练策略使用了充分利用了数据增强方法尽可能提高了对样本和标注的利用率。该体系结构由捕获上下文的跃迁路径和支持精确定位的对称展开路径组成。实验表明,该网络可以从很少的图像中进行端到端训练,并且性能优于先验最佳方法。此外,网络速度很快,在最新的GPU上,512x512图像的分割时间不到一秒。
U-Net主要贡献
- 与FCN不同的是U-net越级层融合方式采用的是concat方式,是对其通道数进行拼接,使特征图变厚,FCN采取的是直接加和的方式。此外U-Net的越级层融合次数增加,FCN只在最后一层进行了融合,Unet有4次融合,实现了多尺度的特征融合,充分的利用了上下文(context)信息,一定解决了感受野大小与分割精度之间的矛盾。
- 提出了overlap策略,该策略能够无缝分割任意大的图像。为了预测图像边界区域的像素,通过镜像输入图像来外推缺失的上下文。该策略使得网络在应对普遍像素大小比较大的医学图像具有了优势,否则分辨率将受到GPU内存的限制。
- 为了应对训练样本少的问题,U-Net采用了随机的弹性形变进行数据增强。
- Unet的优化方法为带动量项的SGD,能量函数为加权的交叉熵形式,离边界越近的像素点权重越大,使得网络对边界像素有更好的训练效果。
U-Net网络结构
其网络结构图如下所示,每一个蓝色方框对应一个多通道的特征图,通道数标注在框的上方。蓝色箭头代表3x3卷积和ReLu
激励,灰色箭头代表复制和裁剪,红色剪头代表2x2最大池化,绿色剪头代表2x2的反卷积上采样,青色箭头为1x1卷积。
Unet由一条收缩路径(左侧)和一条扩张路径(右侧)组成。收缩路径
和卷积网络的典型结构一致。它由两个3x3卷积(未填充卷积)的重复应用
组成,在每个卷积后跟ReLU
激励和一个2x2最大池化操作,步长为2,以实
现下采样。在每个下采样步骤中,将特征通道的数量增加一倍。扩展路径中的每个步骤都包括对特征图进行上采样,然后是将特征通道数量减半的2x2卷积(向上卷积),与来自收缩路径的相应裁剪的特征图的串联以及两个3x3卷积,后跟一个ReLU
。由于每次卷积中都会丢失边界像素,因此有必要进行裁剪。在最后一层,使用1x1卷积将每个64分量特征向量映射
到所需的类数。网络总共有23个卷积层。
Pytorch框架下U-Net实现
这里使用的代码是在Github中guanfuchen/semseg原代码的基础上做了一些顺序上的调整及更多的标注帮助理解。
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision import models
class unetDown(nn.Module):
def __init__(self, in_channels, out_channels):
super(unetDown, self).__init__()
self.conv1 = nn.Sequential(
nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=0),
nn.BatchNorm2d(num_features=out_channels),
nn.ReLU(inplace=True),
)
self.conv2 = nn.Sequential(
nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=0),
nn.BatchNorm2d(num_features=out_channels),
nn.ReLU(inplace=True),
)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
return x
class unetUp(nn.Module):
def __init__(self, in_channels, out_channels):
super(unetUp, self).__init__()
self.upConv = nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=2, stride=2)
self.conv1 = nn.Sequential(
nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=0),
nn.ReLU(inplace=True),
)
self.conv2 = nn.Sequential(
nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=0),
nn.ReLU(inplace=True),
)
def forward(self, x_cur, x_prev):
x = self.upConv(x_cur)
x = torch.cat([F.upsample_bilinear(x_prev, size=x.size()[2:]), x], 1)
x = self.conv1(x)
x = self.conv2(x)
return x
def cross_entropy2d(input, target, weight=None, size_average=True):
n, c, h, w = input.size()
nt, ht, wt = target