unet分割如何取其中一类_3D U-Net脑胶质瘤分割BraTs + Pytorch实现

75b1708b8346efd0d9c11412a02d185f.png

原论文地址: 连接

一、网络模型的分析和对比

原始2D-Unet网络模型

8fd59b29f24c37cf8299a26550e57a36.png

我的2D-Unet网络模型

1、和原来的2D-Unet网络不同的是,我输入通道为4,我这里应该改为4个通道,对应四个模态图像,而输出通道为3,我对应的是三个嵌套子区域标签(WT、TC、ET)

2、另外,最大不同的是我的3X3卷积后的图像尺寸与卷积前一致,所以不用像原来2D-Unet那样因为编码和解码的尺寸不一致,需要裁剪后再拼接.问题来了,为什么原来2D-Unet的卷积会导致卷积前后图尺寸发生改变,因为原来的卷积操作为 kernelsize = 3 ,stride =1 ,padding=0,此卷积为valid方式,这种卷积只能使得图的尺寸越卷越小.而我这里为 kernelsize = 3 ,stride =1 ,padding=1.根据公式可以得出卷积前后的特征图尺寸一致.这种卷积方式为same卷积.

[概念]卷积的三种模式:valid、same、full_程序猿的养生馆-CSDN博客​blog.csdn.net
cce11f09c1a07d4fc6d86f09632d2523.png

关于2D-Unet的讲解和Pytorch代码实现前面我也做了详细地讲解.

玖零猴:U-Net+与FCN的区别+医学表现+网络详解+创新​zhuanlan.zhihu.com
fd8ec234091f23c491e647457d3c956d.png
玖零猴:2D-UNet脑胶质瘤分割BraTs + Pytorch实现​zhuanlan.zhihu.com
1acfb1300bd6fac35d722d3d16c6b33e.png

原始3D-Unet网络模型

e15df553027b0a6ef8d087e90b7d0adc.png

1、和原始的2D-Unet对比,最显著的不同就是3D-Unet池化下采样共3次,所以这里一共有4个尺度,而原来2D-Unet有5个尺度

3、另外,需要注意的是在编码部分中每个尺度的两次卷积后特征图的通道数变化,很明显和2D-Unet不同.而解码部分却一样.

我的3D-Unet网络模型

1、原来3D-Unet输入通道为3,我这里应该改为4个通道,对应四个模态图像,而输出通道一样都为3,我对应的是三个嵌套子区域标签(WT、TC、ET).这里根据自己的需求设置自己网络的输入和输出通道.

2、原来3D-Unet和原来2D-Unet一样都采用了valid卷积,为了保证网络的输入输出分辨率一致,因此我的2D-Unet和3D-Unet都采用same卷积,所以就不用裁剪了,同时可以保证网络的输入输出分辨率一致

其余都是和原始3D-Unet一样的,为了更加好看,我把网络标注清晰点,如下图

b48145c7e5a2602b4d75b585b0bad213.png

二、预处理与数据的获取

玖零猴:(3D网络)医学三维数据且又多模态多标签该如何预处理​zhuanlan.zhihu.com
1acfb1300bd6fac35d722d3d16c6b33e.png

三、环境的配置

1、系统环境 WIN10 + CUDA 92 + CUDNN7 + ANACONDA

2、ANACONDA指令快速配置环境,先下载下面文件

https://download.csdn.net/download/weixin_40519315/12394604

a460141ad024caec63e649722372c340.png

四、3D-Unet模型代码

from torch import nn
from torch import cat

class pub(nn.Module):

    def __init__(self, in_channels, out_channels, batch_norm=True):
        super(pub, self).__init__()
        inter_channels = out_channels if in_channels > out_channels else out_channels//2

        layers = [
                    nn.Conv3d(in_channels, inter_channels, 3, stride=1, padding=1),
                    nn.ReLU(True),
                    nn.Conv3d(inter_channels, out_channels, 3, stride=1, padding=1),
                    nn.ReLU(True)
                 ]
        if batch_norm:
            layers.insert(1, nn.BatchNorm3d(inter_channels))
            layers.insert(len(layers)-1, nn.BatchNorm3d(out_channels))
        self.pub = nn.Sequential(*layers)

    def forward(self, x):
        return self.pub(x)


class unet3dEncoder(nn.Module):

    def __init__(self, in_channels, out_channels, batch_norm=True):
        super(unet3dEncoder, self).__init__()
        self.pub = pub(in_channels, out_channels, batch_norm)
        self.pool = nn.MaxPool3d(2, stride=2)

    def forward(self, x):
        x = self.pub(x)
        return x,self.pool(x)


class unet3dUp(nn.Module):
    def __init__(self, in_channels, out_channels, batch_norm=True, sample=True):
        super(unet3dUp, self).__init__()
        self.pub = pub(in_channels//2+in_channels, out_channels, batch_norm)
        if sample:
            self.sample = nn.Upsample(scale_factor=2, mode='nearest')
        else:
            self.sample = nn.ConvTranspose3d(in_channels, in_channels, 2, stride=2)

    def forward(self, x, x1):
        x = self.sample(x)
        #c1 = (x1.size(2) - x.size(2)) // 2
        #c2 = (x1.size(3) - x.size(3)) // 2
        #x1 = x1[:, :, c1:-c1, c2:-c2, c2:-c2]
        x = cat((x, x1), dim=1)
        x = self.pub(x)
        return x


class unet3d(nn.Module):
    def __init__(self, args):
        super(unet3d, self).__init__()
        init_channels = 4
        class_nums = 3
        batch_norm = True
        sample = True

        self.en1 = unet3dEncoder(init_channels, 64, batch_norm)
        self.en2 = unet3dEncoder(64, 128, batch_norm)
        self.en3 = unet3dEncoder(128, 256, batch_norm)
        self.en4 = unet3dEncoder(256, 512, batch_norm)

        self.up3 = unet3dUp(512, 256, batch_norm, sample)
        self.up2 = unet3dUp(256, 128, batch_norm, sample)
        self.up1 = unet3dUp(128, 64, batch_norm, sample)
        self.con_last = nn.Conv3d(64, class_nums, 1)
        #self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x1,x = self.en1(x)
        x2,x= self.en2(x)
        x3,x= self.en3(x)
        x4,_ = self.en4(x)

        x = self.up3(x4, x3)
        x = self.up2(x, x2)
        x = self.up1(x, x1)
        out = self.con_last(x)
        return out

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_uniform(m.weight.data)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm3d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

完整代码通过知乎@我(知乎号:玖零猴)(QQ:704783475、博主想恰杯奶茶 )

五、训练

python train.py --arch=“unet3d” --dataset=“Jiu0Monkey”

六、测试

python test.py --name="jiu0Monkey_unet3d_woDS"

七、和2D U-Net对比

在此之前,本专栏中的2D网络预测的时候,是把所有的切片预测完指标再求平均值,这样测的值极容易收到一些差的切片而影响整体的指标.所以以后的2D网络预测都采用下面方式进行计算指标,即把所有预测的切片拼接回3D,然后对3D数据整体进行计算指标.这样计算的值会偏高点.不只是2D网络这样,3D网络也是如此,把所有分块拼接后再对整体进行指标的计算.这样统一之后,我们就可以将2D和3D网络进行对比了.此外,代码预测生成的数据都是NII格式的,可以通过ITK-SNAP软件查看三维的分割效果,如果想看2D切片的分割效果,可以用该软件导出即可.

2D网络新的预测代码(test.py)如下

https://download.csdn.net/download/weixin_40519315/12466322​download.csdn.net

通过实验得出,2D U-Net、3D U-Net分割指标表如下:

94708d34c548ca7d3ab1254fd240797c.png

分割效果对比图如下,可见3D网络提高了肿瘤周围的预测,少了很多小渣点.

bbfcd13008255afec93c93f3c3e5b65a.png
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值