3 - Image Segmentation: FusionNet & UNet

UNet and FusionNet are classics of medical image segmentation. Medical image segmentation differs from general-purpose segmentation: the images typically come from medical research (for example, segmenting particular classes of cells), labeled data is scarce, image resolution is high, and because the results feed into real clinical and research decisions, accuracy requirements are strict. The field is gradually moving from manual or semi-automatic segmentation toward fully automatic methods. It also has its own challenge, ISBI; UNet won the ISBI 2015 challenge and has since become a cornerstone of medical image segmentation.

Traditional segmentation methods train on sliding windows, predicting each pixel from a patch of its surroundings. This gives the network local perception, and it yields more training samples than raw images (since each image is cut into many patches). But this approach has drawbacks. First, it is slow: the image must be predicted patch by patch, and neighboring patches overlap heavily, so computation is repeated. Second, local and global information cannot both be captured: large patches require more pooling and reduce localization accuracy, while small patches see too little context.

UNet paper summary:

  • Main contribution: the paper proposes a network and a training strategy that relies heavily on data augmentation to make more efficient use of the available labeled samples.
  • Architecture: a fully symmetric U-shaped network in two parts: a contracting path that captures context, and a symmetric expanding path that enables precise localization.
  • Effect: the network can be trained end to end from very few images, and inference is fast.
  • Results: won the ISBI 2015 cell tracking challenge by a large margin.
  • UNet also augments the input by mirroring the pixels along the image border before feeding it to the network; this enriches the information available to the outermost pixels and makes segmentation there more accurate.
  • It uses a weighted loss function.

 This can be understood as multiplying the per-pixel loss by a weight map: pixels on cell borders, which are what separate touching cells, get large weights, while less critical pixels such as background get small weights.

  • UNet favors large patches combined with a small batch size.
  • UNet also suggests a weight-initialization scheme: for an architecture like UNet, draw initial weights from a Gaussian with standard deviation sqrt(2/N), where N is the number of incoming connections of a unit (this can be compared against plain random initialization).
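The weight-map idea above can be sketched as a per-pixel weighted cross-entropy. This is a minimal sketch: the weight values here are illustrative, not the paper's distance-based formula.

```python
import torch
import torch.nn.functional as F

# Fake logits for a 2-class segmentation of a 4x4 image
logits = torch.randn(1, 2, 4, 4)          # (batch, classes, H, W)
target = torch.randint(0, 2, (1, 4, 4))   # ground-truth class per pixel

# Illustrative weight map: pretend the 2x2 center pixels are cell borders
weight_map = torch.ones(1, 4, 4)
weight_map[:, 1:3, 1:3] = 10.0            # border pixels weighted 10x

# Per-pixel cross-entropy, then weight and average
per_pixel = F.cross_entropy(logits, target, reduction='none')  # shape (1, 4, 4)
loss = (weight_map * per_pixel).sum() / weight_map.sum()
print(loss)
```

Compared with a plain mean over pixels, errors on the (hypothetical) border pixels now dominate the gradient.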

 FusionNet paper summary:

  • Main contribution: the paper proposes a new deep neural network, FusionNet, for automatic segmentation of neuronal structures in connectomics data (roughly, mapping the network of neurons in order to better study how the brain works).
  • Main method: it introduces summation-based skip connections, which allow a deeper network structure and more accurate segmentation.
  • Results: comparison with state-of-the-art methods on the ISBI EM segmentation challenge demonstrates the method's performance. Segmentation results are also shown for two other tasks: segmentation of cell membranes and cell bodies, and statistical analysis of cell morphology.
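The summation-based skip connection is the key difference from UNet's concatenation: the upsampled decoder map is added element-wise to the encoder map of the same shape (the code below averages the two), so the channel count does not grow. A minimal sketch of the two variants:

```python
import torch

encoder_feat = torch.randn(1, 64, 32, 32)   # feature map saved on the way down
decoder_feat = torch.randn(1, 64, 32, 32)   # upsampled feature map on the way up

# UNet-style skip: concatenate along channels -> channel count doubles
concat_skip = torch.cat([encoder_feat, decoder_feat], dim=1)
print(concat_skip.shape)   # torch.Size([1, 128, 32, 32])

# FusionNet-style skip: element-wise sum -> channel count unchanged,
# so the decoder block can keep the same channel width
sum_skip = encoder_feat + decoder_feat
print(sum_skip.shape)      # torch.Size([1, 64, 32, 32])
```

Keeping the channel count fixed is what lets FusionNet stack deeper residual blocks without the decoder widening at every skip.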

Padding: 

Convolution has two problems:

  • the feature map shrinks after every convolution;
  • border information is lost, i.e. pixels at the corners and edges of the image contribute to fewer outputs.

Hence padding is needed.

  • Padding adds cells around the image block so that its size is unchanged after convolution; this lets the pixels at the edge be used as often as interior ones, so edge features are better preserved.
  • Padding means adding zeros before the convolution;
  • zeros are added on all four sides: with padding=1, one row/column of zeros is added at the top, bottom, left, and right of the input.
  • Uses of padding: preserving border information; equalizing images of different sizes so the network receives a consistent input size.
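The effect on output size is easy to check in PyTorch (a minimal sketch):

```python
import torch
import torch.nn as nn

x = torch.randn(1, 3, 32, 32)

no_pad = nn.Conv2d(3, 8, kernel_size=3, padding=0)    # "valid": output shrinks by 2
same_pad = nn.Conv2d(3, 8, kernel_size=3, padding=1)  # zeros on all four sides

print(no_pad(x).shape)    # torch.Size([1, 8, 30, 30])
print(same_pad(x).shape)  # torch.Size([1, 8, 32, 32])
```

The UNet code below uses padding=0 (so feature maps shrink and skips must be cropped), while the FusionNet code uses padding=1 throughout (so sizes are preserved).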

 Three forms of convolution (by amount of padding): valid (no padding), same (output size equals input size), and full (maximal padding).

PyTorch implementation of UNet:

# -*- coding:utf-8 -*-
from PIL import Image
import torch
from torch import nn
import torch.nn.functional as F
# from IPython.core.debugger import set_trace

def contracting_block(in_channels, out_channels):  # contracting-path block: two (3x3 conv -> ReLU -> BN); pooling is applied outside
    block = torch.nn.Sequential(
        nn.Conv2d(kernel_size=(3, 3), in_channels=in_channels, out_channels=out_channels),
        nn.ReLU(),
        nn.BatchNorm2d(out_channels),
        nn.Conv2d(kernel_size=(3, 3), in_channels=out_channels, out_channels=out_channels),
        nn.ReLU(),
        nn.BatchNorm2d(out_channels)
    )
    return block


class expansive_block(nn.Module):
    def __init__(self, in_channels, mid_channels, out_channels):
        super(expansive_block, self).__init__()

        self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=(3, 3), stride=2, padding=1,
                                     output_padding=1)  # 56*56*512

        self.block = nn.Sequential(  # after concatenation: 56*56*1024
            nn.Conv2d(kernel_size=(3, 3), in_channels=in_channels, out_channels=mid_channels),  # 54*54*512
            nn.ReLU(),
            nn.BatchNorm2d(mid_channels),
            nn.Conv2d(kernel_size=(3, 3), in_channels=mid_channels, out_channels=out_channels),  # 52*52*256
            nn.ReLU(),
            nn.BatchNorm2d(out_channels)
        )

    def forward(self, e, d):
        d = self.up(d)
        # center-crop the encoder map e to the decoder map d's size (layout: batch, channel, h, w)
        diffY = e.size()[2] - d.size()[2]
        diffX = e.size()[3] - d.size()[3]
        e = e[:, :, diffY // 2:diffY // 2 + d.size()[2], diffX // 2:diffX // 2 + d.size()[3]]
        cat = torch.cat([e, d], dim=1)
        out = self.block(cat)
        return out


def final_block(in_channels, out_channels):
    block = nn.Sequential(
        nn.Conv2d(kernel_size=(1, 1), in_channels=in_channels, out_channels=out_channels),
        nn.ReLU(),
        nn.BatchNorm2d(out_channels)
    )
    return block


class UNet(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(UNet, self).__init__()
        # Encode
        self.conv_encode1 = contracting_block(in_channels=in_channel, out_channels=64)
        self.conv_pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv_encode2 = contracting_block(in_channels=64, out_channels=128)
        self.conv_pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv_encode3 = contracting_block(in_channels=128, out_channels=256)
        self.conv_pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv_encode4 = contracting_block(in_channels=256, out_channels=512)
        self.conv_pool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Bottleneck
        self.bottleneck = nn.Sequential(
            nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=1024),
            nn.ReLU(),
            nn.BatchNorm2d(1024),
            nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024),
            nn.ReLU(),
            nn.BatchNorm2d(1024)
        )

        # Decode
        self.conv_decode4 = expansive_block(1024, 512, 512)
        self.conv_decode3 = expansive_block(512, 256, 256)
        self.conv_decode2 = expansive_block(256, 128, 128)
        self.conv_decode1 = expansive_block(128, 64, 64)

        self.final_layer = final_block(64, out_channel)

    def forward(self, x):
        # Encode
        encode_block1 = self.conv_encode1(x)
        print('encode_block1:', encode_block1.size())
        encode_pool1 = self.conv_pool1(encode_block1)
        print('encode_pool1:', encode_pool1.size())
        encode_block2 = self.conv_encode2(encode_pool1)
        print('encode_block2:', encode_block2.size())
        encode_pool2 = self.conv_pool2(encode_block2)
        print('encode_pool2:', encode_pool2.size())
        encode_block3 = self.conv_encode3(encode_pool2)
        print('encode_block3:', encode_block3.size())
        encode_pool3 = self.conv_pool3(encode_block3)
        print('encode_pool3:', encode_pool3.size())
        encode_block4 = self.conv_encode4(encode_pool3)
        print('encode_block4:', encode_block4.size())
        encode_pool4 = self.conv_pool4(encode_block4)
        print('encode_pool4:', encode_pool4.size())

        # Bottleneck
        bottleneck = self.bottleneck(encode_pool4)
        print('bottleneck:', bottleneck.size())

        # Decode
        decode_block4 = self.conv_decode4(encode_block4, bottleneck)
        print('decode_block4:', decode_block4.size())
        decode_block3 = self.conv_decode3(encode_block3, decode_block4)
        print('decode_block3:', decode_block3.size())
        decode_block2 = self.conv_decode2(encode_block2, decode_block3)
        print('decode_block2:', decode_block2.size())
        decode_block1 = self.conv_decode1(encode_block1, decode_block2)
        print('decode_block1:', decode_block1.size())

        final_layer = self.final_layer(decode_block1)

        return final_layer


# test
image1 = torch.rand((1, 3, 572, 572))
unet = UNet(in_channel=3, out_channel=1)
mask = unet(image1)

Code output:
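The printed sizes follow simple valid-convolution arithmetic: each contracting block costs 4 pixels (two 3x3 valid convs) and each pool halves the size; each expanding block doubles the size and then costs 4 pixels. A sketch of that bookkeeping for the 572x572 input:

```python
def down(n):          # two 3x3 valid convs (-4), then 2x2 max-pool (//2)
    return (n - 4) // 2

def up(n):            # 2x transposed conv, then two 3x3 valid convs (-4)
    return 2 * n - 4

n = 572
for _ in range(4):    # four contracting blocks
    n = down(n)       # 284, 140, 68, 32
n -= 4                # bottleneck: two more valid convs -> 28
for _ in range(4):    # four expanding blocks
    n = up(n)         # 52, 100, 196, 388
print(n)              # 388
```

The final 388 matches the 572 -> 388 input/output mapping of the original UNet paper.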

On the order of normalization and activation:

The usual order is Conv-BN-ReLU.
Sigmoid: if BN is applied before Sigmoid, the data after BN has mean near 0 and variance near 1, which lands in the near-linear region of the sigmoid and weakens its nonlinearity; in this case Sigmoid followed by BN is preferable.

ReLU: if ReLU is applied before BN, some activations have already been zeroed out, and those dead activations distort the statistics BN computes; in this case BN followed by ReLU is preferable.
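The sigmoid point can be checked numerically: after BN the activations cluster around 0, which is exactly where the sigmoid is nearly linear. A minimal sketch:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(8, 16, 10, 10) * 5 + 3  # activations with large mean and variance

bn = nn.BatchNorm2d(16)                 # training mode: normalize with batch stats
y = bn(x)

print(y.mean().item())                  # close to 0
# Near 0 the sigmoid is approximately linear (sigmoid(t) ~ 0.5 + t/4), so the
# fraction of BN outputs in |y| < 2 measures how much of the batch falls into
# the sigmoid's near-linear region:
print((y.abs() < 2).float().mean().item())  # ~0.95
```

With roughly 95% of values inside the near-linear region, a sigmoid placed directly after BN contributes little nonlinearity.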
 

PyTorch implementation of FusionNet:

# -*- coding:utf-8 -*-

import torch.nn as nn
import torch


def conv_block(in_dim, out_dim, act_fn, stride=1):  # wrap a single conv layer (conv -> BN -> activation) for reuse
    model = nn.Sequential(
        nn.Conv2d(in_dim, out_dim, kernel_size=3, stride=stride, padding=1),
        nn.BatchNorm2d(out_dim),
        act_fn
    )
    return model


def conv_trans_block(in_dim, out_dim, act_fn):  # likewise wrap a transposed-conv (upsampling) layer
    model = nn.Sequential(
        nn.ConvTranspose2d(in_dim, out_dim, kernel_size=3, stride=2, padding=1, output_padding=1),
        nn.BatchNorm2d(out_dim),
        act_fn
    )
    return model


def conv_block_3(in_dim, out_dim, act_fn):
    model = nn.Sequential(
        conv_block(in_dim, out_dim, act_fn),
        conv_block(out_dim, out_dim, act_fn),
        nn.Conv2d(out_dim, out_dim, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(out_dim)
    )
    return model


class Conv_residual_conv(nn.Module):  # residual "sandwich": conv -> 3-conv body with skip -> conv
    def __init__(self, in_dim, out_dim, act_fn):
        super(Conv_residual_conv, self).__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim

        self.conv_1 = conv_block(self.in_dim, self.out_dim, act_fn)
        self.conv_2 = conv_block_3(self.out_dim, self.out_dim, act_fn)
        self.conv_3 = conv_block(self.out_dim, self.out_dim, act_fn)

    def forward(self, input):
        conv_1 = self.conv_1(input)
        conv_2 = self.conv_2(conv_1)
        res = conv_1 + conv_2
        conv_3 = self.conv_3(res)

        return conv_3


class Fusionnet(nn.Module):
    def __init__(self, input_nc, output_nc, ngf, out_clamp=None):
        super(Fusionnet, self).__init__()

        self.out_clamp = out_clamp
        self.in_dim = input_nc
        self.out_dim = ngf
        self.final_out_dim = output_nc

        act_fn = nn.LeakyReLU(0.2, inplace=True)
        act_fn_2 = nn.ELU(inplace=True)

        # encoder
        self.down_1 = Conv_residual_conv(self.in_dim, self.out_dim, act_fn)
        self.pool_1 = conv_block(self.out_dim, self.out_dim, act_fn, 2)

        self.down_2 = Conv_residual_conv(self.out_dim, self.out_dim * 2, act_fn)
        self.pool_2 = conv_block(self.out_dim * 2, self.out_dim * 2, act_fn, 2)

        self.down_3 = Conv_residual_conv(self.out_dim * 2, self.out_dim * 4, act_fn)
        self.pool_3 = conv_block(self.out_dim * 4, self.out_dim * 4, act_fn, 2)

        self.down_4 = Conv_residual_conv(self.out_dim * 4, self.out_dim * 8, act_fn)
        self.pool_4 = conv_block(self.out_dim * 8, self.out_dim * 8, act_fn, 2)

        # bridge
        self.bridge = Conv_residual_conv(self.out_dim * 8, self.out_dim * 16, act_fn)

        # decoder
        self.deconv_4 = conv_trans_block(self.out_dim * 16, self.out_dim * 8, act_fn_2)
        self.up_4 = Conv_residual_conv(self.out_dim * 8, self.out_dim * 8, act_fn_2)

        self.deconv_3 = conv_trans_block(self.out_dim * 8, self.out_dim * 4, act_fn_2)
        self.up_3 = Conv_residual_conv(self.out_dim * 4, self.out_dim * 4, act_fn_2)

        self.deconv_2 = conv_trans_block(self.out_dim * 4, self.out_dim * 2, act_fn_2)
        self.up_2 = Conv_residual_conv(self.out_dim * 2, self.out_dim * 2, act_fn_2)

        self.deconv_1 = conv_trans_block(self.out_dim * 2, self.out_dim, act_fn_2)
        self.up_1 = Conv_residual_conv(self.out_dim, self.out_dim, act_fn_2)

        # output
        self.out = nn.Conv2d(self.out_dim, self.final_out_dim, kernel_size=3, stride=1, padding=1)

    def forward(self, input):
        down_1 = self.down_1(input)
        print('down_1:', down_1.size())
        pool_1 = self.pool_1(down_1)
        print('pool_1:', pool_1.size())

        down_2 = self.down_2(pool_1)
        print('down_2:', down_2.size())
        pool_2 = self.pool_2(down_2)
        print('pool_2:', pool_2.size())

        down_3 = self.down_3(pool_2)
        print('down_3:', down_3.size())
        pool_3 = self.pool_3(down_3)
        print('pool_3:', pool_3.size())

        down_4 = self.down_4(pool_3)
        print('down_4:', down_4.size())
        pool_4 = self.pool_4(down_4)
        print('pool_4:', pool_4.size())

        bridge = self.bridge(pool_4)
        print('bridge:', bridge.size())

        deconv_4 = self.deconv_4(bridge)
        skip_4 = (deconv_4 + down_4) / 2  # summation-based skip, averaged
        up_4 = self.up_4(skip_4)
        print('up_4:', up_4.size())

        deconv_3 = self.deconv_3(up_4)
        skip_3 = (deconv_3 + down_3) / 2
        up_3 = self.up_3(skip_3)
        print('up_3:', up_3.size())

        deconv_2 = self.deconv_2(up_3)
        skip_2 = (deconv_2 + down_2) / 2
        up_2 = self.up_2(skip_2)
        print('up_2:', up_2.size())

        deconv_1 = self.deconv_1(up_2)
        skip_1 = (deconv_1 + down_1) / 2
        up_1 = self.up_1(skip_1)
        print('up_1:', up_1.size())

        out = self.out(up_1)
        print('out:', out.size())

        return out


# test
image = torch.rand((1, 3, 352, 480))
fusion_net = Fusionnet(3, 12, 64)
mask = fusion_net(image)

Code output:
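Because every conv in FusionNet uses padding=1, spatial size changes only at the stride-2 pooling convs (halving) and the transposed convs (doubling), so the 352x480 input round-trips exactly. A sketch of that arithmetic:

```python
def pool(n):    # stride-2 conv with padding=1: n -> n // 2 (for even n)
    return n // 2

def deconv(n):  # stride-2 transposed conv with output_padding=1: n -> 2 * n
    return 2 * n

h, w = 352, 480
for _ in range(4):            # four encoder stages
    h, w = pool(h), pool(w)
print(h, w)                   # 22 30 at the bridge
for _ in range(4):            # four decoder stages
    h, w = deconv(h), deconv(w)
print(h, w)                   # 352 480: output mask matches input size
```

This is why FusionNet needs no cropping at its skip connections, unlike the UNet code above.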

References:

深度之眼 (Deep Eye) course on Bilibili

https://blog.csdn.net/fuzizhu1/article/details/116273339
