论文解读《Semi-supervised Le ft Atrium Segmentation with Mutual Consistency Training》
论文出处:MICCAI2022
论文地址:论文地址
代码地址:代码地址
一、摘要:
(1) 低估了训练中挑战性区域(如小分支或模糊边缘)的重要性。
(2) 提出了一种Mutual Consistency Network(MC-Net)用于从3D MR图像中分割半监督左心房。
(3) MC-Net优于最近六种半监督左心房分割方法,并在LA数据库上设置了最新的最先进的性能。
二、引言:
(1) 具有挑战性的区域包含更关键的信息,因为困难的样本可以使训练更有效。
(2) 我们认为认知不确定性能够评估模型的泛化能力。
(3) 因此,在本文中,我们提出了一种新的相互一致性模型(MC-Net,见图2),用于从3D MR图像中分割半监督左心房。MC-Net由一个编码器和两个略有不同的解码器组成,两个输出的差异用于捕获不确定性信息。
(4) 该模型的贡献包括:
(1)探索基于模型的不确定性信息,以强调训练过程中未标记的挑战区域;
(2)(sharpening function)设计了一种新的循环伪标签方案,通过鼓励相互一致性来促进模型的训练;
(3)实验表明,所提出的MC-Net在LA数据库上的半监督左心房分割任务中取得了最新的性能。
三、方法:
3.1 模型结构:
认知不确定性的测量有几种有代表性的方法。例如,Monte Carlo dropout[3]就是一个流行的例子。给定一个3D输入
样本
,我们可以用随机dropout执行N个随机正向传递,其中dropout层能够从原始模型θ中采样子模型N。这样,深度模型θ输出一组概率向量:{Pn}N N =1。不确定性u可以用所有子模型预测的统计量θn来近似。例如,相关工作[17]使用蒙特卡洛dropout来估计不确定性u为:
其中Pcn表示第C个类在第N次中的输出,μc是N个预测的平均值,C是类的数量,不确定性
本质上是体素级熵。该方法的一个问题是需要进行多次推断,例如在[17]中每次迭代都需要进行8次随机正向传递来估计不确定性,这带来了更多的计算成本。
其中概率输出PA和PB分别由深度特征FA和FB通过Sigmoid激活函数得到。
3.2 循环伪标签
锐化函数定义:
我们使用锐化函数将概率输出PA和PB转换为软伪标签sPLA和sPLB ,
。
其中T是一个常数,用来控制锐化温度。软伪标签对训练[5]的熵正则化有贡献。相比通过使用固定阈值生成伪标签,可以实现软伪标签
为了消除一些错误标记的训练数据[14]的影响。
def sharpening(P):
T = 1/args.temperature
P_sharpen = P ** T / (P ** T + (1-P) ** T)
return P_sharpen
sPLA来监督PB,再用sPLB来监督PA,达到相互一致。通过这种方式,两个解码器可以相互学习。
其中Dice代表Dice损失,L2是均方误差(MSE)损失,Y是ground truth,入是平衡Lseg和Lc的权重。注意,Lseg仅从标记数据中计算,Lc是无监督的,用于监督所有的训练数据。
volume_batch, label_batch = sampled_batch['image'], sampled_batch['label']
volume_batch, label_batch = volume_batch.cuda(), label_batch.cuda()
model.train()
outputs = model(volume_batch)
num_outputs = len(outputs)
y_ori = torch.zeros((num_outputs,) + outputs[0].shape)
y_pseudo_label = torch.zeros((num_outputs,) + outputs[0].shape)
loss_seg = 0
loss_seg_dice = 0
for idx in range(num_outputs):
y = outputs[idx][:labeled_bs,...]
y_prob = F.softmax(y, dim=1)
loss_seg += F.cross_entropy(y[:labeled_bs], label_batch[:labeled_bs])
loss_seg_dice += dice_loss(y_prob[:,1,...], label_batch[:labeled_bs,...] == 1)
y_all = outputs[idx]
y_prob_all = F.softmax(y_all, dim=1)
y_ori[idx] = y_prob_all
y_pseudo_label[idx] = sharpening(y_prob_all)
loss_consist = 0
for i in range(num_outputs):
for j in range(num_outputs):
if i != j:
loss_consist += consistency_criterion(y_ori[i], y_pseudo_label[j])
四、实验结果:
4.1 结果:
图3从左到右分别显示了UA-MT[17]、SASSNet[6]、DTC[7]、我们的MC-Net在LA数据库上获得的结果以及相应的ground truth。
表1给出了Dice、Jaccard、95% Hausdorff Distance (95HD)和平均表面距离(ASD)的定量结果。
4.2消融实验:
class MCNet3d_v1(nn.Module):
def __init__(self, n_channels=3, n_classes=2, n_filters=16, normalization='none', has_dropout=False, has_residual=False):
super(MCNet3d_v1, self).__init__()
self.encoder = Encoder(n_channels, n_classes, n_filters,normalization, has_dropout, has_residual)
self.decoder1 = Decoder(n_channels, n_classes, n_filters, normalization, has_dropout, has_residual, 0)
self.decoder2 = Decoder(n_channels, n_classes, n_filters, normalization, has_dropout, has_residual, 1)
def forward(self, input):
features = self.encoder(input)
out_seg1 = self.decoder1(features)
out_seg2 = self.decoder2(features)
return out_seg1, out_seg2
五、结论:
(1) 提出了一种相互一致性网络(MC-Net)用于半监督左心房分割。
(2) 因此,通过设计的循环伪标签方案,我们的模型被鼓励生成一致和低熵的预测,以便从这些关键区域捕获更多的广义特征来提高模型的训练。
(3) 所提出的MC-Net实现了LA数据库上最准确的半监督左心房分割性能。