[Brain Tumor Segmentation] Brain Tumor Segmentation and Radiomics Survival Prediction: Contribution to the BRATS 2017 Challenge

Overview

This paper is relatively straightforward: it uses a modified 3D U-Net for the brain tumor segmentation task.

Details

Network Architecture

At first glance this looks like a simple extension of U-Net from 2D to 3D, but a closer look reveals many small details: residual connections, deep supervision, and several more that only show up in the code, such as the normalization method, the number of feature maps per stage, the upsampling method, and the activation function.
[Figure: the modified 3D U-Net architecture from the paper]
Some explanations:

Residual connections: many studies add residual connections to U-Net.
Deep supervision: each decoder stage produces its own segmentation map, and these maps are upsampled and summed element-wise to form the final segmentation.
Normalization: instance normalization (IN) is used instead of batch normalization (BN). With a small batch size (2), the batch statistics (mean and standard deviation) are too noisy, so IN gives more stable and consistent results. Note: BN normalizes a given channel across every image in the batch, whereas IN normalizes each channel of each image independently. Concretely, BN pools channel 1 of sample 1, channel 1 of sample 2, ..., channel 1 of sample n, computes a single mean and variance, and normalizes with them; IN takes channel 1 of the current sample alone, computes its mean and variance, and normalizes (subtract the mean, divide by the standard deviation). The snippet right after this list makes the difference concrete.
Upsampling: plain interpolation upsampling is used instead of transposed convolution (the reference code below uses nn.Upsample with mode='nearest'); the authors don't say why, presumably it simply worked better in their experiments.
Activation: leaky ReLU instead of ReLU.
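
To make the IN-versus-BN point concrete, here is a minimal sketch (my own, not from the paper's code): PyTorch's InstanceNorm3d computes one mean/variance per (sample, channel) pair, which we can reproduce by hand.

import torch
import torch.nn as nn

x = torch.randn(2, 8, 16, 16, 16)  # (N, C, D, H, W); batch size 2, as in the paper

bn = nn.BatchNorm3d(8)        # one mean/var per channel, pooled over (N, D, H, W)
inorm = nn.InstanceNorm3d(8)  # one mean/var per (sample, channel), over (D, H, W)

# Reproduce instance normalization by hand
mean = x.mean(dim=(2, 3, 4), keepdim=True)
var = x.var(dim=(2, 3, 4), keepdim=True, unbiased=False)
x_in = (x - mean) / torch.sqrt(var + 1e-5)  # 1e-5 is PyTorch's default eps

print(torch.allclose(inorm(x), x_in, atol=1e-4))  # True: IN matches per-sample stats
print(torch.allclose(bn(x), x_in, atol=1e-4))     # False: BN pools the whole batch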

Other Details

Preventing overfitting: heavy data augmentation (a minimal sketch is given at the end of this subsection).
Class imbalance: a multi-class dice loss.
[Figure: the multi-class dice loss from the paper, L_dc = -(2/|K|) Σ_{k∈K} (Σ_i u_{i,k} v_{i,k}) / (Σ_i u_{i,k} + Σ_i v_{i,k})]
A comparison of the two losses, the original (binary) dice loss and the multi-class dice loss:
[Figure: the original dice coefficient, D = 2 Σ_i u_i v_i / (Σ_i u_i + Σ_i v_i)]

import torch
import torch.nn as nn


class DiceLoss(nn.Module):
    # Binary dice loss; expects `input` to already be probabilities (e.g. after a sigmoid)
    def __init__(self):
        super(DiceLoss, self).__init__()

    def forward(self, input, target):
        N = target.size(0)
        smooth = 1  # avoids division by zero and smooths the gradient
        input_flat = input.view(N, -1)
        target_flat = target.view(N, -1)
        intersection = input_flat * target_flat
        dice = 2 * (intersection.sum(1) + smooth) / (input_flat.sum(1) + target_flat.sum(1) + smooth)
        loss = 1 - dice.sum() / N
        return loss
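
A quick sanity check of the binary version (shapes are illustrative, not from the paper):

x = torch.sigmoid(torch.randn(2, 1, 8, 8, 8))  # predicted foreground probabilities
y = (torch.rand(2, 1, 8, 8, 8) > 0.5).float()  # binary ground-truth mask
print(DiceLoss()(x, y))  # a scalar; near 0 means near-perfect overlap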

Below, the multi-class dice loss, here combined with a weighted cross-entropy term:

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange


class DiceLoss(nn.Module):
    # Multi-class dice loss combined with (optionally weighted) cross-entropy.
    # input:  raw logits, shape (B, n_classes, H, W, S)
    # target: integer labels, shape (B, H, W, S)
    def __init__(self, n_classes, weight=None, alpha=0.5):
        super(DiceLoss, self).__init__()
        self.n_classes = n_classes
        self.weight = weight  # per-class weights for the cross-entropy term
        self.alpha = alpha    # mixing coefficient between the two terms

    def forward(self, input, target):
        smooth = 0.01  # keeps the denominator away from zero
        input1 = F.softmax(input, dim=1)
        target1 = F.one_hot(target, self.n_classes)
        input1 = rearrange(input1, 'b n h w s -> b n (h w s)')
        target1 = rearrange(target1, 'b h w s n -> b n (h w s)')

        # note: this sums over all classes and voxels at once (a "global" dice)
        inter = torch.sum(input1 * target1)
        union = torch.sum(input1) + torch.sum(target1) + smooth
        dice = 2.0 * inter / union

        loss = F.cross_entropy(input, target, weight=self.weight)

        total_loss = (1 - self.alpha) * loss + (1 - dice) * self.alpha
        return total_loss
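
Finally, closing out the "preventing overfitting" point above, a minimal sketch of random 3D augmentation. This is illustrative only: the paper's actual pipeline is richer (it includes things like rotations, scaling, elastic deformations, gamma corrections, and mirroring), and none of the parameters below are the authors'.

import torch

def augment_3d(volume, label):
    # volume: (C, D, H, W) float image; label: (D, H, W) integer mask
    # Random mirroring along each spatial axis
    for axis in (1, 2, 3):
        if torch.rand(1).item() < 0.5:
            volume = torch.flip(volume, dims=[axis])
            label = torch.flip(label, dims=[axis - 1])
    # Random gamma correction applied to intensities rescaled to [0, 1]
    gamma = 0.7 + 0.6 * torch.rand(1).item()
    vmin, vmax = volume.min(), volume.max()
    volume = ((volume - vmin) / (vmax - vmin + 1e-8)) ** gamma * (vmax - vmin) + vmin
    return volume, label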

A Simple Implementation and Comparison

First, a direct conversion of U-Net from 2D to 3D, which just means replacing every 2D operation with its 3D counterpart (an introduction to the 2D version is linked here).

import torch
import torch.nn as nn


# Two successive convolutions
# Convolution output-size formula:
# output = (input - kernel + 2 * padding) / stride + 1
# With kernel=3, stride=1, padding=1: output = input, so the spatial size is preserved
class VGGBlock(nn.Module):
    def __init__(self,in_channels,out_channels):
        super(VGGBlock, self).__init__()
        self.layer=nn.Sequential(
            nn.Conv3d(in_channels, out_channels, 3, 1, 1),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True),

            nn.Conv3d(out_channels, out_channels, 3, 1, 1),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True)
        )
    def forward(self,x):
        return self.layer(x)

# Upsample the current decoder feature and concatenate it with the matching encoder feature
# Transposed-convolution (deconvolution) output-size formula:
# output = (input - 1) * stride + kernel - 2 * padding
# With kernel=4, stride=2, padding=1: output = 2 * input, i.e. the spatial size doubles,
# so the upsampled x1 matches the corresponding encoder feature map x2
class Up(nn.Module):
    def __init__(self,in_channels,out_channels):
        super(Up, self).__init__()
        self.layer=nn.Sequential(
            nn.ConvTranspose3d(in_channels, out_channels, 4, 2, 1)
        )

    def forward(self,x1,x2):
        x1=self.layer(x1)
        # tensors are laid out as NCDHW; we concatenate along the channel dimension, so dim=1
        return torch.cat([x2,x1],dim=1)


class UNet(nn.Module):
    def __init__(self,in_channels,num_classes=2):
        super(UNet, self).__init__()

        filters=[64, 128, 256, 512, 1024]
        self.pool= nn.MaxPool3d(2)
        ## -------------encoder-------------
        self.encoder_1=VGGBlock(in_channels,filters[0])
        self.encoder_2=VGGBlock(filters[0],filters[1])
        self.encoder_3=VGGBlock(filters[1],filters[2])
        self.encoder_4=VGGBlock(filters[2],filters[3])
        self.encoder_5=VGGBlock(filters[3],filters[4])

        ## -------------decoder-------------
        self.up_4=Up(filters[4],filters[3])
        self.up_3=Up(filters[3],filters[2])
        self.up_2=Up(filters[2],filters[1])
        self.up_1=Up(filters[1],filters[0])

        self.decoder_4 = VGGBlock(filters[4],filters[3])
        self.decoder_3 = VGGBlock(filters[3],filters[2])
        self.decoder_2 = VGGBlock(filters[2],filters[1])
        self.decoder_1 = VGGBlock(filters[1],filters[0])

        self.final = nn.Sequential(
            nn.Conv3d(filters[0],num_classes,3,1,1),
        )
    def forward(self,x):
        ## -------------encoder-------------
        encoder_1=self.encoder_1(x)
        encoder_2=self.encoder_2(self.pool(encoder_1))
        encoder_3=self.encoder_3(self.pool(encoder_2))
        encoder_4=self.encoder_4(self.pool(encoder_3))
        encoder_5=self.encoder_5(self.pool(encoder_4))
        ## -------------decoder-------------
        decoder_4=self.up_4(encoder_5,encoder_4)
        decoder_4=self.decoder_4(decoder_4)

        decoder_3 = self.up_3(decoder_4,encoder_3)
        decoder_3=self.decoder_3(decoder_3)

        decoder_2 = self.up_2(decoder_3,encoder_2)
        decoder_2=self.decoder_2(decoder_2)

        decoder_1 = self.up_1(decoder_2,encoder_1)
        decoder_1=self.decoder_1(decoder_1)

        output = self.final(decoder_1)
        return output

if __name__ == '__main__':
    # Note: at this resolution the 64-to-1024-filter network is very memory-hungry;
    # shrink the input for a quick CPU test
    x = torch.randn(1, 4, 160, 160, 128)
    net = UNet(in_channels=4, num_classes=4)
    y = net(x)
    print(y.shape)  # torch.Size([1, 4, 160, 160, 128])
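
As a quick size check (my addition, not from the paper): counting the parameters of this vanilla 3D U-Net shows why the paper cuts the number of feature maps down (base_n_filter=8 in the modified network below).

# Count trainable parameters of the UNet defined above
n_params = sum(p.numel() for p in UNet(in_channels=4, num_classes=4).parameters())
print(f"{n_params / 1e6:.1f}M parameters")  # on the order of 10^8 with filters up to 1024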


Below is the paper's modified 3D U-Net (broadly the same shape, but with many differing details):

import torch
import torch.nn as nn
from torchsummary import summary


class Modified3DUNet(nn.Module):
    def __init__(self, in_channels, n_classes, base_n_filter=8):
        super(Modified3DUNet, self).__init__()
        self.in_channels = in_channels
        self.n_classes = n_classes
        self.base_n_filter = base_n_filter

        self.lrelu = nn.LeakyReLU()
        self.dropout3d = nn.Dropout3d(p=0.6)
        self.upscale = nn.Upsample(scale_factor=2, mode='nearest')
        self.softmax = nn.Softmax(dim=1)

        # Level 1 context pathway
        self.conv3d_c1_1 = nn.Conv3d(self.in_channels, self.base_n_filter, kernel_size=3, stride=1, padding=1,
                                     bias=False)
        self.conv3d_c1_2 = nn.Conv3d(self.base_n_filter, self.base_n_filter, kernel_size=3, stride=1, padding=1,
                                     bias=False)
        self.lrelu_conv_c1 = self.lrelu_conv(self.base_n_filter, self.base_n_filter)
        self.inorm3d_c1 = nn.InstanceNorm3d(self.base_n_filter)

        # Level 2 context pathway
        self.conv3d_c2 = nn.Conv3d(self.base_n_filter, self.base_n_filter * 2, kernel_size=3, stride=2, padding=1,
                                   bias=False)
        self.norm_lrelu_conv_c2 = self.norm_lrelu_conv(self.base_n_filter * 2, self.base_n_filter * 2)
        self.inorm3d_c2 = nn.InstanceNorm3d(self.base_n_filter * 2)

        # Level 3 context pathway
        self.conv3d_c3 = nn.Conv3d(self.base_n_filter * 2, self.base_n_filter * 4, kernel_size=3, stride=2, padding=1,
                                   bias=False)
        self.norm_lrelu_conv_c3 = self.norm_lrelu_conv(self.base_n_filter * 4, self.base_n_filter * 4)
        self.inorm3d_c3 = nn.InstanceNorm3d(self.base_n_filter * 4)

        # Level 4 context pathway
        self.conv3d_c4 = nn.Conv3d(self.base_n_filter * 4, self.base_n_filter * 8, kernel_size=3, stride=2, padding=1,
                                   bias=False)
        self.norm_lrelu_conv_c4 = self.norm_lrelu_conv(self.base_n_filter * 8, self.base_n_filter * 8)
        self.inorm3d_c4 = nn.InstanceNorm3d(self.base_n_filter * 8)

        # Level 5 context pathway, level 0 localization pathway
        self.conv3d_c5 = nn.Conv3d(self.base_n_filter * 8, self.base_n_filter * 16, kernel_size=3, stride=2, padding=1,
                                   bias=False)
        self.norm_lrelu_conv_c5 = self.norm_lrelu_conv(self.base_n_filter * 16, self.base_n_filter * 16)
        self.norm_lrelu_upscale_conv_norm_lrelu_l0 = self.norm_lrelu_upscale_conv_norm_lrelu(self.base_n_filter * 16,
                                                                                             self.base_n_filter * 8)

        self.conv3d_l0 = nn.Conv3d(self.base_n_filter * 8, self.base_n_filter * 8, kernel_size=1, stride=1, padding=0,
                                   bias=False)
        self.inorm3d_l0 = nn.InstanceNorm3d(self.base_n_filter * 8)

        # Level 1 localization pathway
        self.conv_norm_lrelu_l1 = self.conv_norm_lrelu(self.base_n_filter * 16, self.base_n_filter * 16)
        self.conv3d_l1 = nn.Conv3d(self.base_n_filter * 16, self.base_n_filter * 8, kernel_size=1, stride=1, padding=0,
                                   bias=False)
        self.norm_lrelu_upscale_conv_norm_lrelu_l1 = self.norm_lrelu_upscale_conv_norm_lrelu(self.base_n_filter * 8,
                                                                                             self.base_n_filter * 4)

        # Level 2 localization pathway
        self.conv_norm_lrelu_l2 = self.conv_norm_lrelu(self.base_n_filter * 8, self.base_n_filter * 8)
        self.conv3d_l2 = nn.Conv3d(self.base_n_filter * 8, self.base_n_filter * 4, kernel_size=1, stride=1, padding=0,
                                   bias=False)
        self.norm_lrelu_upscale_conv_norm_lrelu_l2 = self.norm_lrelu_upscale_conv_norm_lrelu(self.base_n_filter * 4,
                                                                                             self.base_n_filter * 2)

        # Level 3 localization pathway
        self.conv_norm_lrelu_l3 = self.conv_norm_lrelu(self.base_n_filter * 4, self.base_n_filter * 4)
        self.conv3d_l3 = nn.Conv3d(self.base_n_filter * 4, self.base_n_filter * 2, kernel_size=1, stride=1, padding=0,
                                   bias=False)
        self.norm_lrelu_upscale_conv_norm_lrelu_l3 = self.norm_lrelu_upscale_conv_norm_lrelu(self.base_n_filter * 2,
                                                                                             self.base_n_filter)

        # Level 4 localization pathway
        self.conv_norm_lrelu_l4 = self.conv_norm_lrelu(self.base_n_filter * 2, self.base_n_filter * 2)
        self.conv3d_l4 = nn.Conv3d(self.base_n_filter * 2, self.n_classes, kernel_size=1, stride=1, padding=0,
                                   bias=False)

        self.ds2_1x1_conv3d = nn.Conv3d(self.base_n_filter * 8, self.n_classes, kernel_size=1, stride=1, padding=0,
                                        bias=False)
        self.ds3_1x1_conv3d = nn.Conv3d(self.base_n_filter * 4, self.n_classes, kernel_size=1, stride=1, padding=0,
                                        bias=False)

    def conv_norm_lrelu(self, feat_in, feat_out):
        return nn.Sequential(
            nn.Conv3d(feat_in, feat_out, kernel_size=3, stride=1, padding=1, bias=False),
            nn.InstanceNorm3d(feat_out),
            nn.LeakyReLU())

    def norm_lrelu_conv(self, feat_in, feat_out):
        return nn.Sequential(
            nn.InstanceNorm3d(feat_in),
            nn.LeakyReLU(),
            nn.Conv3d(feat_in, feat_out, kernel_size=3, stride=1, padding=1, bias=False))

    def lrelu_conv(self, feat_in, feat_out):
        return nn.Sequential(
            nn.LeakyReLU(),
            nn.Conv3d(feat_in, feat_out, kernel_size=3, stride=1, padding=1, bias=False))

    def norm_lrelu_upscale_conv_norm_lrelu(self, feat_in, feat_out):
        return nn.Sequential(
            nn.InstanceNorm3d(feat_in),
            nn.LeakyReLU(),
            nn.Upsample(scale_factor=2, mode='nearest'),
            # should be feat_in*2 or feat_in
            nn.Conv3d(feat_in, feat_out, kernel_size=3, stride=1, padding=1, bias=False),
            nn.InstanceNorm3d(feat_out),
            nn.LeakyReLU())

    def forward(self, x):
        # x[N,c,d,h,w]:[1,3,64, 64, 64]
        #  Level 1 context pathway
        out = self.conv3d_c1_1(x)
        residual_1 = out
        out = self.lrelu(out)
        out = self.conv3d_c1_2(out)
        out = self.dropout3d(out)
        out = self.lrelu_conv_c1(out)
        # Element Wise Summation
        out += residual_1
        context_1 = self.lrelu(out)
        out = self.inorm3d_c1(out)
        out = self.lrelu(out)

        # out[N,c,d,h,w]:[1,8,64, 64, 64]
        # Level 2 context pathway
        out = self.conv3d_c2(out)
        residual_2 = out
        out = self.norm_lrelu_conv_c2(out)
        out = self.dropout3d(out)
        out = self.norm_lrelu_conv_c2(out)
        out += residual_2
        out = self.inorm3d_c2(out)
        out = self.lrelu(out)
        context_2 = out

        # out[N,c,d,h,w]:[1,16,32, 32, 32]
        # Level 3 context pathway
        out = self.conv3d_c3(out)
        residual_3 = out
        out = self.norm_lrelu_conv_c3(out)
        out = self.dropout3d(out)
        out = self.norm_lrelu_conv_c3(out)
        out += residual_3
        out = self.inorm3d_c3(out)
        out = self.lrelu(out)
        context_3 = out

        # out[N,c,d,h,w]:[1,32,16, 16, 16]
        # Level 4 context pathway
        out = self.conv3d_c4(out)
        residual_4 = out
        out = self.norm_lrelu_conv_c4(out)
        out = self.dropout3d(out)
        out = self.norm_lrelu_conv_c4(out)
        out += residual_4
        out = self.inorm3d_c4(out)
        out = self.lrelu(out)
        context_4 = out

        # out[N,c,d,h,w]:[1,64,8, 8, 8]
        # Level 5
        out = self.conv3d_c5(out)
        residual_5 = out
        out = self.norm_lrelu_conv_c5(out)
        out = self.dropout3d(out)
        out = self.norm_lrelu_conv_c5(out)
        out += residual_5
        out = self.norm_lrelu_upscale_conv_norm_lrelu_l0(out)

        out = self.conv3d_l0(out)
        out = self.inorm3d_l0(out)
        out = self.lrelu(out)

        # out[N,c,d,h,w]:[1,64,8, 8, 8]
        # Level 1 localization pathway
        out = torch.cat([out, context_4], dim=1)
        out = self.conv_norm_lrelu_l1(out)
        out = self.conv3d_l1(out)
        out = self.norm_lrelu_upscale_conv_norm_lrelu_l1(out)

        # out[N,c,d,h,w]:[1,32,16, 16, 16]
        # Level 2 localization pathway
        out = torch.cat([out, context_3], dim=1)
        out = self.conv_norm_lrelu_l2(out)
        ds2 = out
        out = self.conv3d_l2(out)
        out = self.norm_lrelu_upscale_conv_norm_lrelu_l2(out)

        # out[N,c,d,h,w]:[1,16,32, 32, 32]
        # Level 3 localization pathway
        out = torch.cat([out, context_2], dim=1)
        out = self.conv_norm_lrelu_l3(out)
        ds3 = out
        out = self.conv3d_l3(out)
        out = self.norm_lrelu_upscale_conv_norm_lrelu_l3(out)

        # out[N,c,d,h,w]:[1,8,64, 64, 64]
        # Level 4 localization pathway
        out = torch.cat([out, context_1], dim=1)
        out = self.conv_norm_lrelu_l4(out)
        out_pred = self.conv3d_l4(out)

        # Deep supervision: project intermediate decoder features to class maps,
        # upsample them, and sum them into the final prediction
        ds2_1x1_conv = self.ds2_1x1_conv3d(ds2)
        ds1_ds2_sum_upscale = self.upscale(ds2_1x1_conv)
        ds3_1x1_conv = self.ds3_1x1_conv3d(ds3)
        ds1_ds2_sum_upscale_ds3_sum = ds1_ds2_sum_upscale + ds3_1x1_conv
        ds1_ds2_sum_upscale_ds3_sum_upscale = self.upscale(ds1_ds2_sum_upscale_ds3_sum)

        out = out_pred + ds1_ds2_sum_upscale_ds3_sum_upscale
        seg_layer = out
        # permute only changes how the data is indexed, not the underlying storage;
        # view operates on the storage directly, so contiguous() is needed to
        # materialize the permuted layout first
        out = out.permute(0, 2, 3, 4, 1).contiguous().view(-1, self.n_classes)
        # out = out.view(-1, self.n_classes)
        out = self.softmax(out)
        return out, seg_layer


def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Modified3DUNet(3, 2).to(device)
    x = torch.randn(1, 3, 64, 64, 64).to(device)  # input must be on the same device as the model
    out, seg_layer = model(x)
    print(out.shape)        # torch.Size([262144, 2]): per-voxel class probabilities
    print(seg_layer.shape)  # torch.Size([1, 2, 64, 64, 64])
    # summary(model, (3, 64, 64, 64), 1)


if __name__ == '__main__':
    main()
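
A tiny standalone demo (my addition) of the permute/contiguous/view point made in the comment inside forward:

import torch

t = torch.randn(2, 3, 4)
p = t.permute(0, 2, 1)    # a strided view over the same storage
print(p.is_contiguous())  # False
# p.view(-1) would raise a RuntimeError here: view requires contiguous storage
print(p.contiguous().view(-1).shape)  # torch.Size([24]), after copying to contiguous memory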

The Complete Pipeline

For reference, someone has published a complete hands-on project (linked here) whose network is a plain 3D U-Net; we can simply swap in the network from this paper. What's worth studying there is everything besides the architecture: data loading, training, and evaluation. The code is extremely well written, and there's even an accompanying blog post walking through it. Impressive work!
