pytorch i2i 图像分割测试集_Unet图像分割在PyTorch上的实现

v2-989e178103744fc26291ecde11f5cfe2_1440w.jpg?source=172ae18b

Unet是一个最近比较火的网络结构。它的理论已经有很多大佬在讨论了。本文主要从实际操作的层面,讲解如何使用pytorch实现unet图像分割。

通常我会在粗略了解某种方法之后,就进行实际操作。在操作过程中,也许会遇到一些疑问,再回过头去仔细研究某个理论。这样的学习方法,是我比较喜欢的方式。这也是fast.ai推崇的自上而下的学习方式。

本文将先简单介绍Unet的理论基础,然后使用pytorch一步一步地实现Unet图像分割。因为主要目的是提供一个baseline模型给大家,所以代码主要关注在如何构造Unet的网络结构。

当你学会了如何用代码实现Unet,我相信你对Unet的理解已经比较深刻了。

本文完整的代码:https://github.com/Qiuyan918/Unet_Implementation_PyTorch/blob/master/Unet_Implementation_PyTorch.ipynb


Unet

v2-f908198e154883523ae9ee1305f1d8c7_b.jpg
图1: Unet的网络结构

Unet主要用于图像分割问题。图1是Unet论文中的网络结构图。可以看出Unet是一个对称的结构,左半边是Encoder,右半边是Decoder。图像会先经过Encoder处理,再经过Decoder处理,最终实现图像分割。它们分别的作用如下:

  • Encoder:使得模型理解了图像的内容,但是丢弃了图像的位置信息。
  • Decoder:使模型结合Encoder对图像内容的理解,恢复图像的位置信息。

Encoder的部分和传统的网络结构类似,可以选择图中的结构,也可以选择VGG,ResNet等。随着卷积层的加深,特征图的长宽减小,通道增加。虽然Encoder提取了图像的高级特征,但是丢弃了图像的位置信息。所以在图像识别问题中,模型只需要Encoder的部分。因为图像识别不需要位置信息,只需要提取图像的内容信息。

Decoder的部分是Unet的重点。Decoder中涉及upconvolution这个概念。关于upconvolution,这里不做详细介绍,简单来说就是convolution的反向运算。Decoder的每一层都通过upconvolution(图中绿色箭头),并且和Encoder相对应的初级特征结合(图中的灰色箭头),逐渐恢复图像的位置信息。在Decoder中,随着卷积层的加深,特征图的长宽增大,通道减少。


数据:

v2-b67935b4410d3469ec7522f90ecd5363_b.jpg
图2: Kaggle盐体分割比赛

本文用到的数据来源于Kaggle盐体分割比赛。这次比赛的问题是一个非常典型的图像分割问题。比赛中的大佬们基本上都用的Unet。

我们的目标就是将图片中的盐体找出来。盐体有一些我不太懂的经济价值,反正是很有意义的。

以下是一些图片样例:

v2-6fc1e68a801ddd022930b456a34012ce_b.jpg
图3: 图片样例

PyTorch实现

Unet

本文定义的Unet网络结构和论文中的略有不同,但本质都采用的是Encoder和Decoder的结构。主要的不同点是:

  • Encoder的backbone基于ResNet18
  • 输入和输出图像大小一致

以下是Unet网络结构的pytorch代码,代码后附了详细的解释。

class Decoder(nn.Module):
  def __init__(self, in_channels, middle_channels, out_channels):
    super(Decoder, self).__init__()
    self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
    self.conv_relu = nn.Sequential(
        nn.Conv2d(middle_channels, out_channels, kernel_size=3, padding=1),
        nn.ReLU(inplace=True)
        )
  def forward(self, x1, x2):
    x1 = self.up(x1)
    x1 = torch.cat((x1, x2), dim=1)
    x1 = self.conv_relu(x1)
    return x1

class Unet(nn.Module):
    def __init__(self, n_class):
        super().__init__()

        self.base_model = torchvision.models.resnet18(True)
        self.base_layers = list(self.base_model.children())
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False),
            self.base_layers[1],
            self.base_layers[2])
        self.layer2 = nn.Sequential(*self.base_layers[3:5])
        self.layer3 = self.base_layers[5]
        self.layer4 = self.base_layers[6]
        self.layer5 = self.base_layers[7]
        self.decode4 = Decoder(512, 256+256, 256)
        self.decode3 = Decoder(256, 256+128, 256)
        self.decode2 = Decoder(256, 128+64, 128)
        self.decode1 = Decoder(128, 64+64, 64)
        self.decode0 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(64, 32, kernel_size=3, padding=1, bias=False),
            nn.Conv2d(32, 64, kernel_size=3, padding=1, bias=False)
            )
        self.conv_last = nn.Conv2d(64, n_class, 1)

    def forward(self, input):
        e1 = self.layer1(input) # 64,128,128
        e2 = self.layer2(e1) # 64,64,64
        e3 = self.layer3(e2) # 128,32,32
        e4 = self.layer4(e3) # 256,16,16
        f = self.layer5(e4) # 512,8,8
        d4 = self.decode4(f, e4) # 256,16,16
        d3 = self.decode3(d4, e3) # 256,32,32
        d2 = self.decode2(d3, e2) # 128,64,64
        d1 = self.decode1(d2, e1) # 64,128,128
        d0 = self.decode0(d1) # 64,256,256
        out = self.conv_last(d0) # 1,256,256
        return out
  • 这里定义了两个class:DecoderUnetUnet是整个模型的结构,Decoder则是模型Decoder中的单层。
  • 使用pytorch构造模型时,需要基于nn.Module定义类。forward函数定义前向传播的逻辑。
  • Decoder中的up运算定义为nn.ConvTranspose2d,也就是upconvolution;conv_relu则定义为nn.Conv2dnn.ReLU的组合。pytorch中需要用到nn.Sequential将多个运算组合在一起。
  • Decoderforward函数定义了其前向传播的逻辑:1. 对特征图x1做upconvolution。2. 将x1和x2(encoder中对应的特征图)组合(concatenate)。3. 对组合后的特征图做卷积和relu。
  • 因为Unet基于resnet18,所以定义运算时从torchvision.models.resnet18取出来就可以。因为resnet18默认的是适用于RGB图片,而比赛中的图片是灰的,只有一个通道,所以layer1中的卷基层需要自己定义。
  • layer1layer5属于encoder,encode4encode0属于decoder,呈对称结构。
  • 下表是经过各层的处理后,特征图的长/宽和通道数:

v2-4758c50c31781162079b1b2659bc3267_b.jpg

Dataset

如果你了解keras,那么就会发现pytorch中的Dataset和keras中的DataGenerator类似。不同的是pytorch定义的Dataset只返回1个样本,再通过DataLoader定义batch_size

Dataset的逻辑很简单,分为以下几步:

  • 读取图片
  • 预处理(resize, pad, 数据增强等)
  • 返回图片和Mask

Pytorch代码如下:

class SaltDataset(Dataset):
  def __init__(self, image_list, mode, mask_list=None, fine_size=202, pad_left=0, pad_right=0):
    self.imagelist = image_list
    self.mode = mode
    self.masklist = mask_list
    self.fine_size = fine_size
    self.pad_left = pad_left
    self.pad_right = pad_right

  def __len__(self):
    return len(self.imagelist)

  def __getitem__(self, idx):
    image = deepcopy(self.imagelist[idx])

    if self.mode == 'train':
      mask = deepcopy(self.masklist[idx])
      label = np.where(mask.sum() == 0, 1.0, 0.0).astype(np.float32)

      if self.fine_size != image.shape[0]:
        image, mask = do_resize2(image, mask, self.fine_size, self.fine_size)

      if self.pad_left != 0:
        image, mask = do_center_pad2(image, mask, self.pad_left, self.pad_right)

      image = image.reshape(1, image.shape[0], image.shape[1])
      mask = mask.reshape(1, mask.shape[0], mask.shape[1])

      return image, mask, label

    elif self.mode == 'val':
      mask = deepcopy(self.masklist[idx])

      if self.fine_size != image.shape[0]:
        image, mask = do_resize2(image, mask, self.fine_size, self.fine_size)

      if self.pad_left != 0:
        image = do_center_pad(image, self.pad_left, self.pad_right)

      image = image.reshape(1, image.shape[0], image.shape[1])
      mask = mask.reshape(1, mask.shape[0], mask.shape[1])

      return image, mask

    elif self.mode == 'test':
      if self.fine_size != image.shape[0]:
        image = cv2.resize(image, dsize=(self.fine_size, self.fine_size))

      if self.pad_left != 0:
        image = do_center_pad(image, self.pad_left, self.pad_right)

      image = image.reshape(1, image.shape[0], image.shape[1])

      return image

Optimizer

optimizer采用的是SGD,同时用到了余弦退火学习率和快照集成来提升模型效果。

scheduler_step = epoch // snapshot
optimizer = torch.optim.SGD(salt.parameters(), lr=max_lr, momentum=momentum, weight_decay=weight_decay)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, scheduler_step, min_lr)

结论

更完整的代码在我的github。在没有数据增强和TTA等其他手段的情况下,本文的代码能够拿到0.76的成绩,是一个不错的baseline模型。


参考文献:

[1] https://arxiv.org/abs/1505.04597

[2] https://github.com/ybabakhin/kaggle_salt_bes_phalanx

  • 0
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值