Paper: http://www.arxiv.org/pdf/1505.04597.pdf
Mirror in China: http://xxx.itp.ac.cn/pdf/1505.04597.pdf
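The name U-Net comes from the U shape of the network, whose left and right halves are roughly symmetric.
The left half downsamples and the right half upsamples correspondingly; the structure is simple.
Here is a direct summary of the points I found enlightening: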
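1. Feature fusion
U-Net fuses features by concatenating them along the channel dimension, which produces a thicker feature map, whereas FCN adds the corresponding points element-wise.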
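The difference between the two fusion styles shows up directly in the tensor shapes. A minimal sketch, assuming two feature maps of identical shape; the tensor names and sizes are illustrative and not taken from the paper:

```python
import torch

# Two hypothetical feature maps: batch 1, 64 channels, 32x32 spatial size.
skip = torch.randn(1, 64, 32, 32)       # encoder feature (skip connection)
upsampled = torch.randn(1, 64, 32, 32)  # decoder feature after upsampling

# U-Net style: concatenate along the channel dimension, so channels add up.
fused_concat = torch.cat((skip, upsampled), dim=1)

# FCN style: element-wise sum, so the channel count stays the same.
fused_sum = skip + upsampled

print(fused_concat.shape)  # torch.Size([1, 128, 32, 32])
print(fused_sum.shape)     # torch.Size([1, 64, 32, 32])
```

Concatenation leaves it to the following convolution to decide how to mix the two sources, at the cost of more input channels.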
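2. Weighted loss
To separate borders more cleanly, the authors increase the loss weight on border pixels, which forces the network to pay more attention to classifying them.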
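A weighted loss of this kind can be sketched as a per-pixel cross-entropy multiplied by a weight map. This is only an illustration: the weight map below is a hand-made placeholder and `border_weight` is an arbitrary value I picked, whereas the paper derives its map from the distances to nearby cell borders.

```python
import torch
import torch.nn.functional as F

def weighted_cross_entropy(logits, target, weight_map):
    """Per-pixel cross-entropy scaled by a per-pixel weight map.

    logits:     (N, C, H, W) raw scores
    target:     (N, H, W) integer class labels
    weight_map: (N, H, W) non-negative weights
    """
    per_pixel = F.cross_entropy(logits, target, reduction="none")  # (N, H, W)
    # Normalize by the total weight so the loss scale stays comparable.
    return (per_pixel * weight_map).sum() / weight_map.sum()

torch.manual_seed(0)
logits = torch.randn(1, 2, 4, 4)
target = torch.zeros(1, 4, 4, dtype=torch.long)
target[:, :, 2:] = 1  # right half of the image is class 1

# Placeholder weight map: 1 everywhere, larger on the two columns
# next to the class boundary (hypothetical values, not the paper's formula).
border_weight = 5.0
weight_map = torch.ones(1, 4, 4)
weight_map[:, :, 1:3] = border_weight

loss = weighted_cross_entropy(logits, target, weight_map)
print(loss.item())
```

With a weight map of all ones, this reduces to the ordinary mean cross-entropy.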
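3. Overlap-tile strategy
Large images are predicted tile by tile. Each tile is fed to the network together with extra surrounding context, and the context missing at the image border is filled in by mirroring the image.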
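A minimal sketch of the idea, in which each tile is cut out together with a margin of context and the context beyond the image border comes from mirror padding. The function name `extract_tiles` and the `tile` and `margin` values are hypothetical choices for illustration; in the paper the required margin follows from the unpadded convolutions of the network.

```python
import torch
import torch.nn.functional as F

def extract_tiles(image, tile, margin):
    """Split `image` (N, C, H, W) into input tiles of size tile + 2 * margin.

    Each input tile covers one tile x tile output region plus `margin`
    pixels of context on every side; context beyond the image border is
    filled in by mirroring. Assumes H and W are multiples of `tile`.
    """
    padded = F.pad(image, (margin, margin, margin, margin), mode="reflect")
    _, _, h, w = image.shape
    size = tile + 2 * margin
    tiles = []
    for top in range(0, h, tile):
        for left in range(0, w, tile):
            tiles.append(padded[:, :, top:top + size, left:left + size])
    return tiles

image = torch.arange(64, dtype=torch.float32).reshape(1, 1, 8, 8)
tiles = extract_tiles(image, tile=4, margin=2)
print(len(tiles), tiles[0].shape)  # 4 torch.Size([1, 1, 8, 8])
```

Cropping `margin` pixels from each side of a tile recovers exactly the original image region that tile is responsible for.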
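Improving on U-Net led to U-Net++. Its author explained the thinking behind it on Zhihu; it is excellent, and this is the attitude research should be done with. I admire it.
U-Net++: https://zhuanlan.zhihu.com/p/44958351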
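PyTorch practice:
Only the main body of the network is shown here. I plan to put together a code collection later; for now this just counts as accumulation.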
Define the encoder block:
import torch
import torch.nn as nn

class EncoderBlock(nn.Module):
    """Two 3x3 conv + BN + ReLU layers, optional dropout, then 2x2 max pooling."""

    def __init__(self, in_channels, out_channels, dropout=False):
        super(EncoderBlock, self).__init__()
        # padding=1 keeps the spatial size unchanged through each 3x3 conv.
        self.encode = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
        if dropout:
            self.encode.add_module('dropout', nn.Dropout())
        # The pooling at the end halves the height and width.
        self.encode.add_module('maxpool', nn.MaxPool2d(2, stride=2))

    def forward(self, x):
        return self.encode(x)
Define the decoder block:
class DecoderBlock(nn.Module):
    """Two 3x3 conv + BN + ReLU layers, then a 2x upsampling transposed conv."""

    def __init__(self, in_channels, middle_channels, out_channels):
        super(DecoderBlock, self).__init__()
        self.decode = nn.Sequential(
            nn.Conv2d(in_channels, middle_channels, 3, padding=1),
            nn.BatchNorm2d(middle_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(middle_channels, middle_channels, 3, padding=1),
            nn.BatchNorm2d(middle_channels),
            nn.ReLU(inplace=True),
            # Kernel size 2 with stride 2 doubles the height and width.
            nn.ConvTranspose2d(middle_channels, out_channels, 2, stride=2),
        )

    def forward(self, x):
        return self.decode(x)
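The resolution changes in the two blocks come from their last layers: the 2x2 max pooling halves the spatial size and the stride-2 transposed convolution doubles it. A quick shape check, with channel counts and input size chosen arbitrarily:

```python
import torch
import torch.nn as nn

pool = nn.MaxPool2d(2, stride=2)
up = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)

x = torch.randn(1, 512, 16, 16)
pooled = pool(x)   # spatial size halved, channels unchanged
upsampled = up(x)  # spatial size doubled, channels 512 -> 256

print(pooled.shape)     # torch.Size([1, 512, 8, 8])
print(upsampled.shape)  # torch.Size([1, 256, 32, 32])
```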
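Define U-Net:
import torch.nn.functional as F

class UNet(nn.Module):
    def __init__(self, num_classes):
        super(UNet, self).__init__()
        # Encoder: channels 3 -> 64 -> 128 -> 256 -> 512, each block halves the resolution.
        self.encoder1 = EncoderBlock(3, 64)
        self.encoder2 = EncoderBlock(64, 128)
        self.encoder3 = EncoderBlock(128, 256)
        self.encoder4 = EncoderBlock(256, 512)
        self.center = DecoderBlock(512, 1024, 512)
        # Decoder: input channels are skip channels plus upsampled channels.
        self.decoder4 = DecoderBlock(1024, 512, 256)
        self.decoder3 = DecoderBlock(512, 256, 128)
        self.decoder2 = DecoderBlock(256, 128, 64)
        self.decoder1 = nn.Sequential(
            nn.Conv2d(128, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, num_classes, 1),  # 1x1 conv maps to class scores
        )

    @staticmethod
    def _fuse(skip, x):
        # Each encoder block ends with max pooling, so its output is half the
        # resolution of the decoder feature it joins. Resize the skip feature
        # to match before concatenating along the channel dimension.
        # (The paper instead takes the skip feature before pooling.)
        skip = F.interpolate(skip, size=x.shape[2:], mode='bilinear', align_corners=False)
        return torch.cat((skip, x), dim=1)

    def forward(self, x):
        # Output matches the input size when H and W are multiples of 16.
        enc1 = self.encoder1(x)                 # 64 channels, 1/2 resolution
        enc2 = self.encoder2(enc1)              # 128 channels, 1/4
        enc3 = self.encoder3(enc2)              # 256 channels, 1/8
        enc4 = self.encoder4(enc3)              # 512 channels, 1/16
        x = self.center(enc4)                   # 512 channels, 1/8
        x = self.decoder4(self._fuse(enc4, x))  # 256 channels, 1/4
        x = self.decoder3(self._fuse(enc3, x))  # 128 channels, 1/2
        x = self.decoder2(self._fuse(enc2, x))  # 64 channels, full resolution
        x = self.decoder1(self._fuse(enc1, x))  # num_classes channels, full resolution
        return x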