MindSpore复现Attention U-Net

这个是我在mindspore的经典论文复现的活动中看到的,大家感兴趣的可以去看一下
【MindSpore开发者群英会】经典论文复现活动 · Issue #I6Q8R0 · MindSpore/community - Gitee.com

 然后我们先来看看论文:Attention U-Net: Learning Where to Look for the PancreasAttention U-Net: Learning Where to Look for the Pancreas | Papers With Code

然后看到这个题目很熟悉,因为这个是之前我写paper的时候也参考的文章,是将传统unet两个层之间的concat(拼接)操作进行了一个处理,也是“魔改”unet的一个方向吧哈哈哈哈哈(上采样、下采样、concat操作,太经典了只能说)

我们可以先来看看传统的unet长啥样

        只能说,非常简单,经过卷积、池化还有激活函数的操作(因为医学图像的像素本来就比较小,而且背景一般是偏黑或者白色,所以模型复杂反而学习的结果不太好)

        那我们这篇paper的作者就针对这个unet模型进行一定的修改,主要的修改内容在于横向拼接的concat操作,这个模型我们依照文章里面的称为:AG-Unet

        然后文章的作者在摘要里面写: 我们提出了一种用于医学成像的新型注意门(AG)模型,该模型可以自动学习聚焦于不同形状和大小的目标结构。用AG训练的模型隐含地学习抑制输入图像中的不相关区域,同时突出对特定任务有用的显著特征。这使我们能够消除使用级联卷积神经网络(CNNs)的显式外部组织/器官定位模块的必要性。AGs可以很容易地集成到标准CNN架构中,例如U-Net模型,具有最小的计算开销,同时提高了模型的灵敏度和预测精度。在两个用于多类图像分割的大型CT腹部数据集上对所提出的注意力U-Net架构进行了评估。实验结果表明,AGs在保持计算效率的同时,在不同的数据集和训练大小上持续提高了U-Net的预测性能。拟议架构的代码是公开的。

那么我简单的讲一下作者修改这个的意图:作者认为传统的直接拼接并没有让模型很好的学习到下采样中的前景特征,直接进行拼接反而会不好,没有对想要学习部分进行一个很好的关注,所以作者希望设计一个容易插入的模块,能够针对目标区域(抑制不相关的区域被注意),其实就是相当于做一个特征筛选或者关注的工作,针对于图像的背景。

 那么作者就设计了这么一个block,叫做Attention Gates(AG),他是由两个部分进行输入,g和xl,那么分别是同一层中下采样的结果和上采样的结果,先分别进行卷积conv和归一化bn,接着一个简答的相加然后送进激活函数relu里面,再进行一个卷积conv和归一化bn,然后送入sigmoid函数进行激活,然后将传入的x,就是同一层下采样的结果进行叉乘,最后得出结果。

class Attention_block(nn.Cell):
    def __init__(self,F_g,F_l,F_int):
        super(Attention_block,self).__init__()
        self.W_g = nn.SequentialCell(
            nn.Conv2d(F_g, F_int, kernel_size=3, has_bias=True, bias_init="zeros", pad_mode="same", weight_init="normal"),
            nn.BatchNorm2d(F_int)
            )
        
        self.W_x = nn.SequentialCell(
            nn.Conv2d(F_l, F_int, kernel_size=3, has_bias=True, bias_init="zeros", pad_mode="same", weight_init="normal"),
            nn.BatchNorm2d(F_int)
        )

        self.psi = nn.SequentialCell(
            nn.Conv2d(F_int, 1, kernel_size=3, has_bias=True, bias_init="zeros", pad_mode="same", weight_init="normal"),
            nn.BatchNorm2d(1),
        )
        self.relu = nn.ReLU()
        
    def construct(self,g,x):
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi = self.relu(g1+x1)
        psi = self.psi(psi)
        psi = ops.sigmoid(psi)

        return x*psi

        这个就是简单的根据模型去实现的AG-block, 接着就将这个插入原本的unet就可以了,但是需要注意的是, 因为这里进行拼接和叉乘的操作,需要tensor的大小一样,所以的话要进行padding,mindspore框架是与pytorch有点不一样的。

        上采样还是跟原来的unet一样,主要是在下采样里面进行修改,这里我就给出一个下采样的代码,其余的也是一样的。

class Up1(nn.Cell):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        self.concat = F.Concat(axis=1)
        self.factor = 56.0 / 64.0
        self.center_crop = CentralCrop(central_fraction=self.factor)
        self.print_fn = F.Print()
        self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        self.up = nn.Conv2dTranspose(in_channels, in_channels // 2, kernel_size=2, stride=2,
                                     weight_init="normal", bias_init="zeros")
        self.relu = nn.ReLU()
        self.Att1 = Attention_block(F_g=512,F_l=512,F_int=256)

    def construct(self, x1, x2):
        x1 = self.up(x1)
        x1 = self.relu(x1)
        x2 = self.center_crop(x2)
        x2 = self.Att1(g=x1, x=x2)
        x = self.concat((x1, x2))
        return self.conv(x)

 然后模型的框架没有改,直接就写在了block里面,这个就算具体模型的实现了,其余的就直接cv大法了(狗头)

class AttU_Net(nn.Cell):
    def __init__(self, n_channels, n_classes):
        super(AttU_Net, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024)

        
        self.up1 = Up1(1024, 512)
        self.up2 = Up2(512, 256)
        self.up3 = Up3(256, 128)
        self.up4 = Up4(128, 64)
        self.outc = OutConv(64, n_classes)

    def construct(self, x):

        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits

然后完整的网络就这样子啦,其实感觉是魔改unet第一个比较成功的,虽然没有用attention(虽然很快就有人直接将这个改成attention直接发paper了),但是也是一个注意力的方式,相当于一个筛选器,通过上下采样的结果更好的去进行concat的操作,有助于更加注意想要分割的目标和背景的关注。

然后我们使用的是细胞的数据集进行训练和推理测试

两者均训练600epoch,学习率为:0.0001A

 unet的训练结果:

unet的精度:

unet-AG 训练结果:

推理结果:

 

 然后我的代码也会上传到github上面:

数据集放在dataset里面了,用的是细胞分割的数据集tif格式

https://github.com/SzuPc/mindspore-UnetAG/tree/master

希望大家多star!!

  • 1
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值