论文解读《基于双向复制粘贴的半监督医学图像分割》
论文地址:论文地址
代码地址:代码地址
论文出处:CVPR2023
一、摘要:
(1) 本文提出一种直接的方法来缓解这个问题——在简单的Mean Teacher架构中双向复制粘贴有标签和无标签的数据。该方法鼓励未标记数据从标记数据中从内向和外向两个方向学习全面的公共语义。
(2) 具体来说,我们将随机裁剪的图像从标记图像(前景)复制粘贴到未标记图像(背景)上,并将未标记图像(前景)复制粘贴到标记图像(背景)上。将两幅混合图像送入学生网络,并由伪标签和真实值的混合监督信号进行监督。
(3) 实验表明,在各种半监督医学图像分割数据集上,与其他先进技术相比,有增益(例如,在ACDC数据集上,有5%的标记数据的Dice提高了超过21%)。
图2 LA数据集上不同模型的未标记和标记训练数据的Dice。在我们的方法中观察到更小的性能差距。
我们基于由现有技术和我们的方法训练的模型,从LA数据集[39]计算标记和未标记训练集的Dice分数,如图2所示。以前分别处理标记数据和未标记数据的模型在标记数据和非标记数据之间存在很大的性能差距。例如,MC-Net对标记数据获得95.59%的骰子,但对未标记数据仅获得87.63%的骰子
二、具体工作:
CutMix[42]是一种简单而强大的数据处理方法,也被称为复制粘贴(CP),它有可能鼓励未标记的数据从标记的数据中学习通用语义,因为同一地图中的像素共享的语义更接近[29]。
(1) 为了缓解标记数据和未标记数据之间的不匹配问题,我们通过提出一种令人惊讶的简单但非常有效的双向复制粘贴(BCP)方法来实现这一点。
(2) 为了训练学生网络,我们通过从有标签的图像(前景)复制粘贴随机作物到无标签的图像(背景),并相反地,从无标签的图像(前景)复制粘贴随机作物到有标签的图像(背景)来增加输入。学生网络由来自教师网络的未标记图像的伪标签和标记图像的标签映射之间的双向复制粘贴生成的监督信号来监督。
三、方法:
图3。Mean Teacher架构中的双向复制粘贴框架概述,使用2D输入绘制,以便更好的可视化。学生网络的输入由两幅有标签和两幅无标签图像以双向复制粘贴的方式混合而成。然后,为了向学生网络提供监督信号,我们通过相同的双向复制粘贴将教师网络产生的真值和伪标签组合为一个监督信号,使真值的强监督帮助伪标签的弱监督。
input data:
1、有标签的图像(前景)复制粘贴到无标签的图像(背景)
2、从无标签的图像(前景)复制粘贴到有标签的图像(背景)来增加输入
存在教师网络Ft(Xu p,Xu q;θt)和学生网络Fs(Xin,Xout;θs),其中θt和θs是参数。
在数学上,我们将医学图像的3D体积定义为。半监督医学图像分割的目标是预测每个体素的标签映射,指示背景和目标在X中的位置。K是类数。我们的训练集D由N个标记数据和M个未标记数据(N《M)组成,表示为两个子集:,其中和。
所提出的双向复制粘贴方法的总体流程如图3所示。3,在Mean Teacher架构中。我们从训练集中随机挑选两个未标记的图像和两个标记的图像。然后,我们将随机裁剪从(前景)复制粘贴到(背景)上以生成混合图像,并从(前景)拷贝粘贴到(背景)上来生成另一个混合图像。未标记的图像能够从标记的图像中从向内(Xin)和向外(Xout)两个方向学习全面的公共语义。然后将图像Xin和Xout输入到Student网络中,以预测分割掩码Yin和Yout。通过双向复制粘贴来自教师网络的未标记图像的预测和标记图像的标签图来监督分割掩模。
3.1 双向复制粘贴
3.1.1 双向复制粘贴图像
为了在一对图像之间进行复制粘贴,我们首先生成零中心掩码,指示体素来自前景(0)还是背景(1)图像。零值区域的大小为βH×βW×βL,其中β∈(0,1)。然后,我们双向复制粘贴标记和未标记的图像,如下所示:
,,表示逐元素乘法。采用两个标记和未标记的图像来保持输入的多样性。
def context_mask(img, mask_ratio):
batch_size, channel, img_x, img_y, img_z = img.shape[0],img.shape[1],img.shape[2],img.shape[3],img.shape[4]
loss_mask = torch.ones(batch_size, img_x, img_y, img_z).cuda()
mask = torch.ones(img_x, img_y, img_z).cuda()
patch_pixel_x, patch_pixel_y, patch_pixel_z = int(img_x*mask_ratio), int(img_y*mask_ratio), int(img_z*mask_ratio)
w = np.random.randint(0, 112 - patch_pixel_x)
h = np.random.randint(0, 112 - patch_pixel_y)
z = np.random.randint(0, 80 - patch_pixel_z)
mask[w:w+patch_pixel_x, h:h+patch_pixel_y, z:z+patch_pixel_z] = 0
loss_mask[:, w:w+patch_pixel_x, h:h+patch_pixel_y, z:z+patch_pixel_z] = 0
return mask.long(), loss_mask.long()
volume_batch, label_batch = sampled_batch['image'][:args.labeled_bs], sampled_batch['label'][:args.labeled_bs]
volume_batch, label_batch = volume_batch.cuda(), label_batch.cuda()
img_a, img_b = volume_batch[:sub_bs], volume_batch[sub_bs:]
lab_a, lab_b = label_batch[:sub_bs], label_batch[sub_bs:]
with torch.no_grad():
img_mask, loss_mask = context_mask(img_a, args.mask_ratio)#mask_ratio=2/3
"""Mix Input"""
volume_batch = img_a * img_mask + img_b * (1 - img_mask)
label_batch = lab_a * img_mask + lab_b * (1 - img_mask)
3.1.2 双向复制粘贴监督信号
为了训练学生网络,还通过BCP操作生成监控信号。将未标记的图像和输入到教师网络中,并计算它们的概率图
with torch.no_grad():
unoutput_a, _ = ema_model(unimg_a)
unoutput_b, _ = ema_model(unimg_b)
标准标签伪标签
将作为监督,监督的学生网络预测
3.1.3损失函数
学生网络的每个输入图像由来自标记图像和未标记图像的分量组成。直观地说,标记图像的真实掩模通常比未标记图像的伪标记更准确。我们使用α来控制未标记图像像素对损失函数的贡献。Xin和Xout的损失函数分别由:
其中Lseg是Dice损失和交叉熵损失的线性组合。Qin和Qout的计算公式为:
outputs_l, _ = model(mixl_img)#######model学生模型
outputs_u, _ = model(mixu_img)
(8) 在每次迭代中,我们通过损失函数的随机梯度下降更新学生网络中的参数θs:
(9)然后,更新第(k+1)次迭代时的教师网络参数:其中λ是平滑系数参数。
四、与其他优秀模型做对比
图4 在LA数据集上使用10%标记数据和地面实况的几种半监督分割方法的可视化
4.1 消融实验
我们进行消融研究,以显示BCP中各成分的影响。包括CP方向、掩蔽策略的设计选择。我们还逐步研究了在ACDC数据集上,与5%标记率的竞争对手相比,我们的方法的显著改进。补充材料中显示了对ACDC数据集的一些消融研究。
复制粘贴方向
表4。消融研究的复制粘贴方向。In:向内复制粘贴(前景:未标记,背景:已标记)。Out:向外复制粘贴(前景:已标记,背景:未标记)。CP:直接复制粘贴(背景和前景:已标记和已标记以及未标记和未标记)。
我们设计了三个实验来研究表4中不同复制粘贴方向的影响。向内和向外复制粘贴(表中的In和Out)是指分别使用或来训练网络。
教师模型的初始化策略
def pre_train(args, snapshot_path):
model = net_factory(net_type=args.model, in_chns=1, class_num=num_classes, mode="train")
sub_bs = int(args.labeled_bs/2
model.train()
iter_num = 0
best_dice = 0
max_epoch = pre_max_iterations // len(trainloader) + 1
iterator = tqdm(range(max_epoch), ncols=70)
for epoch_num in iterator:
for _, sampled_batch in enumerate(trainloader):
volume_batch, label_batch = sampled_batch['image'][:args.labeled_bs], sampled_batch['label'][:args.labeled_bs]
volume_batch, label_batch = volume_batch.cuda(), label_batch.cuda()
img_a, img_b = volume_batch[:sub_bs], volume_batch[sub_bs:]
lab_a, lab_b = label_batch[:sub_bs], label_batch[sub_bs:]
with torch.no_grad():
img_mask, loss_mask = context_mask(img_a, args.mask_ratio)
"""Mix Input"""
volume_batch = img_a * img_mask + img_b * (1 - img_mask)
label_batch = lab_a * img_mask + lab_b * (1 - img_mask)
outputs, _ = model(volume_batch)
loss_ce = F.cross_entropy(outputs, label_batch)
loss_dice = DICE(outputs, label_batch)
loss = (loss_ce + loss_dice) / 2
def self_train(args, pre_snapshot_path, self_snapshot_path):
model = net_factory(net_type=args.model, in_chns=1, class_num=num_classes, mode="train")
ema_model = net_factory(net_type=args.model, in_chns=1, class_num=num_classes, mode="train")
iterator = tqdm(range(max_epoch), ncols=70)
for epoch in iterator:
for _, sampled_batch in enumerate(trainloader):
volume_batch, label_batch = sampled_batch['image'], sampled_batch['label']
volume_batch, label_batch = volume_batch.cuda(), label_batch.cuda()
img_a, img_b = volume_batch[:sub_bs], volume_batch[sub_bs:args.labeled_bs]
lab_a, lab_b = label_batch[:sub_bs], label_batch[sub_bs:args.labeled_bs]
unimg_a, unimg_b = volume_batch[args.labeled_bs:args.labeled_bs+sub_bs], volume_batch[args.labeled_bs+sub_bs:]
with torch.no_grad():
unoutput_a, _ = ema_model(unimg_a)
unoutput_b, _ = ema_model(unimg_b)
plab_a = get_cut_mask(unoutput_a, nms=1)
plab_b = get_cut_mask(unoutput_b, nms=1)
img_mask, loss_mask = context_mask(img_a, args.mask_ratio)
consistency_weight = get_current_consistency_weight(iter_num // 150)
mixl_img = img_a * img_mask + unimg_a * (1 - img_mask)
mixu_img = unimg_b * img_mask + img_b * (1 - img_mask)
mixl_lab = lab_a * img_mask + plab_a * (1 - img_mask)
mixu_lab = plab_b * img_mask + lab_b * (1 - img_mask)
outputs_l, _ = model(mixl_img)
outputs_u, _ = model(mixu_img)
loss_l = mix_loss(outputs_l, lab_a, plab_a, loss_mask, u_weight=args.u_weight)
loss_u = mix_loss(outputs_u, plab_b, lab_b, loss_mask, u_weight=args.u_weight, unlab=True)
loss = loss_l + loss_u
五 结论
我们提出了用于半监督医学图像分割的双向复制粘贴(BCP)。我们以双向方式扩展了基于复制粘贴的方法,减少了标记数据和未标记数据之间的分布差距。在LA、NIH胰腺和ACDC数据集上的实验表明了所提出的BCP的优越性,在具有5%标记数据的ACDC数据集中,Dice甚至提高了21%以上。注意,与骨干网络相比,BCP没有引入新的参数或计算成本。
局限性我们没有专门设计一个模块来增强局部特征学习。尽管BCP的表现比所有竞争对手都好,但对比度极低的目标零件仍然很难很好地分割(例如,图4第二行的左下角零件缺失)。