论文解读《Bidirectional Copy-Paste for Semi-Supervised Medical Image Segmentation》

论文解读《基于双向复制粘贴的半监督医学图像分割》
论文地址论文地址
代码地址代码地址
论文出处:CVPR2023
一、摘要
(1) 本文提出一种直接的方法来缓解这个问题——在简单的Mean Teacher架构中双向复制粘贴有标签和无标签的数据。该方法鼓励未标记数据从标记数据中从内向和外向两个方向学习全面的公共语义。
(2) 具体来说,我们将随机裁剪的图像从标记图像(前景)复制粘贴到未标记图像(背景)上,并将未标记图像(前景)复制粘贴到标记图像(背景)上。将两幅混合图像送入学生网络,并由伪标签和真实值的混合监督信号进行监督。
(3) 实验表明,在各种半监督医学图像分割数据集上,与其他先进技术相比,有增益(例如,在ACDC数据集上,有5%的标记数据的Dice提高了超过21%)。
在这里插入图片描述
图2 LA数据集上不同模型的未标记和标记训练数据的Dice。在我们的方法中观察到更小的性能差距。

我们基于由现有技术和我们的方法训练的模型,从LA数据集[39]计算标记和未标记训练集的Dice分数,如图2所示。以前分别处理标记数据和未标记数据的模型在标记数据和非标记数据之间存在很大的性能差距。例如,MC-Net对标记数据获得95.59%的骰子,但对未标记数据仅获得87.63%的骰子
二、具体工作:
在这里插入图片描述
CutMix[42]是一种简单而强大的数据处理方法,也被称为复制粘贴(CP),它有可能鼓励未标记的数据从标记的数据中学习通用语义,因为同一地图中的像素共享的语义更接近[29]。

(1) 为了缓解标记数据和未标记数据之间的不匹配问题,我们通过提出一种令人惊讶的简单但非常有效的双向复制粘贴(BCP)方法来实现这一点。

(2) 为了训练学生网络,我们通过从有标签的图像(前景)复制粘贴随机作物到无标签的图像(背景),并相反地,从无标签的图像(前景)复制粘贴随机作物到有标签的图像(背景)来增加输入。学生网络由来自教师网络的未标记图像的伪标签和标记图像的标签映射之间的双向复制粘贴生成的监督信号来监督。

三、方法
在这里插入图片描述
图3。Mean Teacher架构中的双向复制粘贴框架概述,使用2D输入绘制,以便更好的可视化。学生网络的输入由两幅有标签和两幅无标签图像以双向复制粘贴的方式混合而成。然后,为了向学生网络提供监督信号,我们通过相同的双向复制粘贴将教师网络产生的真值和伪标签组合为一个监督信号,使真值的强监督帮助伪标签的弱监督。

input data:
1、有标签的图像(前景)复制粘贴到无标签的图像(背景)
2、从无标签的图像(前景)复制粘贴到有标签的图像(背景)来增加输入

存在教师网络Ft(Xu p,Xu q;θt)和学生网络Fs(Xin,Xout;θs),其中θt和θs是参数。

在数学上,我们将医学图像的3D体积定义为在这里插入图片描述。半监督医学图像分割的目标是预测每个体素的标签映射(Y∈{0,1,…,K−1}W×H×L),指示背景和目标在X中的位置。K是类数。我们的训练集D由N个标记数据和M个未标记数据(N《M)组成,表示为两个子集:在这里插入图片描述,其中在这里插入图片描述在这里插入图片描述

所提出的双向复制粘贴方法的总体流程如图3所示。3,在Mean Teacher架构中。我们从训练集中随机挑选两个未标记的图像(Xu p,Xu q)和两个标记的图像在这里插入图片描述。然后,我们将随机裁剪从Xl i(前景)复制粘贴到Xu q(背景)上以生成混合图像Xout,并从Xu p(前景)拷贝粘贴到Xl j(背景)上来生成另一个混合图像Xin。未标记的图像能够从标记的图像中从向内(Xin)和向外(Xout)两个方向学习全面的公共语义。然后将图像Xin和Xout输入到Student网络中,以预测分割掩码Yin和Yout。通过双向复制粘贴来自教师网络的未标记图像的预测和标记图像的标签图来监督分割掩模。

3.1 双向复制粘贴
3.1.1 双向复制粘贴图像
为了在一对图像之间进行复制粘贴,我们首先生成零中心掩码M∈{0,1}W×H×L,指示体素来自前景(0)还是背景(1)图像。零值区域的大小为βH×βW×βL,其中β∈(0,1)。然后,我们双向复制粘贴标记和未标记的图像,如下所示:
在这里插入图片描述
Xl i,Xl j∈Dl
,i 6=j,Xu p,Xu q∈Du,p 6=q1∈{1}W×H×L在这里插入图片描述表示逐元素乘法。采用两个标记和未标记的图像来保持输入的多样性。

def context_mask(img, mask_ratio):
    batch_size, channel, img_x, img_y, img_z = img.shape[0],img.shape[1],img.shape[2],img.shape[3],img.shape[4]
    loss_mask = torch.ones(batch_size, img_x, img_y, img_z).cuda()
    mask = torch.ones(img_x, img_y, img_z).cuda()
    patch_pixel_x, patch_pixel_y, patch_pixel_z = int(img_x*mask_ratio), int(img_y*mask_ratio), int(img_z*mask_ratio)
    w = np.random.randint(0, 112 - patch_pixel_x)
    h = np.random.randint(0, 112 - patch_pixel_y)
    z = np.random.randint(0, 80 - patch_pixel_z)
    mask[w:w+patch_pixel_x, h:h+patch_pixel_y, z:z+patch_pixel_z] = 0
    loss_mask[:, w:w+patch_pixel_x, h:h+patch_pixel_y, z:z+patch_pixel_z] = 0
    return mask.long(), loss_mask.long()
    
volume_batch, label_batch = sampled_batch['image'][:args.labeled_bs], sampled_batch['label'][:args.labeled_bs]
volume_batch, label_batch = volume_batch.cuda(), label_batch.cuda()
img_a, img_b = volume_batch[:sub_bs], volume_batch[sub_bs:]
lab_a, lab_b = label_batch[:sub_bs], label_batch[sub_bs:]
with torch.no_grad():
    img_mask, loss_mask = context_mask(img_a, args.mask_ratio)#mask_ratio=2/3

"""Mix Input"""
volume_batch = img_a * img_mask + img_b * (1 - img_mask)
label_batch = lab_a * img_mask + lab_b * (1 - img_mask)

在这里插入图片描述
在这里插入图片描述

3.1.2 双向复制粘贴监督信号
为了训练学生网络,还通过BCP操作生成监控信号。将未标记的图像Xu pXu q输入到教师网络中,并计算它们的概率图
在这里插入图片描述

with torch.no_grad():
    unoutput_a, _ = ema_model(unimg_a)
    unoutput_b, _ = ema_model(unimg_b)

在这里插入图片描述标准标签在这里插入图片描述伪标签在这里插入图片描述

Yin和Yout将作为监督,监督Xin和Xout的学生网络预测
在这里插入图片描述

3.1.3损失函数
学生网络的每个输入图像由来自标记图像和未标记图像的分量组成。直观地说,标记图像的真实掩模通常比未标记图像的伪标记更准确。我们使用α来控制未标记图像像素对损失函数的贡献。Xin和Xout的损失函数分别由:
在这里插入图片描述
其中Lseg是Dice损失和交叉熵损失的线性组合。Qin和Qout的计算公式为:
Qin=Fs(Xin;θs),Qout=Fs(Xout;Θs)。

outputs_l, _ = model(mixl_img)#######model学生模型
outputs_u, _ = model(mixu_img)

在这里插入图片描述

(8) 在每次迭代中,我们通过损失函数的随机梯度下降更新学生网络中的参数θs:
Lall=Lin+Lout。
(9)然后,更新第(k+1)次迭代时的教师网络参数在这里插入图片描述在这里插入图片描述其中λ是平滑系数参数。

四、与其他优秀模型做对比
在这里插入图片描述
图4 在LA数据集上使用10%标记数据和地面实况的几种半监督分割方法的可视化
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
4.1 消融实验
我们进行消融研究,以显示BCP中各成分的影响。包括CP方向、掩蔽策略的设计选择。我们还逐步研究了在ACDC数据集上,与5%标记率的竞争对手相比,我们的方法的显著改进。补充材料中显示了对ACDC数据集的一些消融研究。
复制粘贴方向
在这里插入图片描述
表4。消融研究的复制粘贴方向。In:向内复制粘贴(前景:未标记,背景:已标记)。Out:向外复制粘贴(前景:已标记,背景:未标记)。CP:直接复制粘贴(背景和前景:已标记和已标记以及未标记和未标记)。

我们设计了三个实验来研究表4中不同复制粘贴方向的影响。向内和向外复制粘贴(表中的In和Out)是指分别使用Xl M+Xu(1−M)Xu M+Xl(1−M)来训练网络。

教师模型的初始化策略
在这里插入图片描述

def pre_train(args, snapshot_path):
    model = net_factory(net_type=args.model, in_chns=1, class_num=num_classes, mode="train")
    sub_bs = int(args.labeled_bs/2
    model.train()
    iter_num = 0
    best_dice = 0
    max_epoch = pre_max_iterations // len(trainloader) + 1
    iterator = tqdm(range(max_epoch), ncols=70)
    for epoch_num in iterator:
        for _, sampled_batch in enumerate(trainloader):
            volume_batch, label_batch = sampled_batch['image'][:args.labeled_bs], sampled_batch['label'][:args.labeled_bs]
            volume_batch, label_batch = volume_batch.cuda(), label_batch.cuda()
            img_a, img_b = volume_batch[:sub_bs], volume_batch[sub_bs:]
            lab_a, lab_b = label_batch[:sub_bs], label_batch[sub_bs:]
            with torch.no_grad():
                img_mask, loss_mask = context_mask(img_a, args.mask_ratio)

            """Mix Input"""
            volume_batch = img_a * img_mask + img_b * (1 - img_mask)
            label_batch = lab_a * img_mask + lab_b * (1 - img_mask)

            outputs, _ = model(volume_batch)
            loss_ce = F.cross_entropy(outputs, label_batch)
            loss_dice = DICE(outputs, label_batch)
            loss = (loss_ce + loss_dice) / 2

def self_train(args, pre_snapshot_path, self_snapshot_path):
    model = net_factory(net_type=args.model, in_chns=1, class_num=num_classes, mode="train")
    ema_model = net_factory(net_type=args.model, in_chns=1, class_num=num_classes, mode="train")
    iterator = tqdm(range(max_epoch), ncols=70)
    for epoch in iterator:
        for _, sampled_batch in enumerate(trainloader):
            volume_batch, label_batch = sampled_batch['image'], sampled_batch['label']
            volume_batch, label_batch = volume_batch.cuda(), label_batch.cuda()
            img_a, img_b = volume_batch[:sub_bs], volume_batch[sub_bs:args.labeled_bs]
            lab_a, lab_b = label_batch[:sub_bs], label_batch[sub_bs:args.labeled_bs]
            unimg_a, unimg_b = volume_batch[args.labeled_bs:args.labeled_bs+sub_bs], volume_batch[args.labeled_bs+sub_bs:]
            with torch.no_grad():
                unoutput_a, _ = ema_model(unimg_a)
                unoutput_b, _ = ema_model(unimg_b)
                plab_a = get_cut_mask(unoutput_a, nms=1)
                plab_b = get_cut_mask(unoutput_b, nms=1)
                img_mask, loss_mask = context_mask(img_a, args.mask_ratio)
            consistency_weight = get_current_consistency_weight(iter_num // 150)

            mixl_img = img_a * img_mask + unimg_a * (1 - img_mask)
            mixu_img = unimg_b * img_mask + img_b * (1 - img_mask)
            mixl_lab = lab_a * img_mask + plab_a * (1 - img_mask)
            mixu_lab = plab_b * img_mask + lab_b * (1 - img_mask)
            outputs_l, _ = model(mixl_img)
            outputs_u, _ = model(mixu_img)
            loss_l = mix_loss(outputs_l, lab_a, plab_a, loss_mask, u_weight=args.u_weight)
            loss_u = mix_loss(outputs_u, plab_b, lab_b, loss_mask, u_weight=args.u_weight, unlab=True)

            loss = loss_l + loss_u
        

五 结论
我们提出了用于半监督医学图像分割的双向复制粘贴(BCP)。我们以双向方式扩展了基于复制粘贴的方法,减少了标记数据和未标记数据之间的分布差距。在LA、NIH胰腺和ACDC数据集上的实验表明了所提出的BCP的优越性,在具有5%标记数据的ACDC数据集中,Dice甚至提高了21%以上。注意,与骨干网络相比,BCP没有引入新的参数或计算成本。
局限性我们没有专门设计一个模块来增强局部特征学习。尽管BCP的表现比所有竞争对手都好,但对比度极低的目标零件仍然很难很好地分割(例如,图4第二行的左下角零件缺失)。

  • 3
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值