简单入门理解半监督中的Mean Teacher

Mean Teacher出自此文。本文所用代码引用自此处。接下来我们以偏伪代码的风格来通俗解释Mean Teacher。

首先,Mean Teacher中有两个网络,一个称为Teacher,一个称为Student,其结构完全一致,只是网络权重更新方法不同:

model = create_model()	# Student Model
ema_model = create_model(ema=True)	# Teacher Model (Equipped with EMA)

先暂时不管EMA是什么意思。一般来讲,在半监督中,每个输入Batch包含一半已标注的图像与一般未标注的图像。首先,整个Batch会被送入Student Model中,得到一个预测结果。对于Batch中的已标注部分,利用结果与真值计算loss,进行梯度反传,从而更新Student Model的参数,如下所示:

outputs = model(volume_batch)	# 将图像输入Student中
supervised_loss = ce_loss(outputs[:args.labeled_bs], label_batch[:][:args.labeled_bs].long())	# 计算已标注部分的loss

而对于Batch中的未标注部分,其输入Student Model也会得到一个结果(记为A),这个结果有什么用呢?现在我们来看Teacher Model。具体来说,未标注的图像会在加入随机噪声后,会被送入Teacher Model中,得到一个预测结果(记为B):

ema_output = ema_model(ema_inputs)	# ema_inputs通过batch中的未标注图像加噪得到

那么我们希望A与B的结果保持一致,如下所示:

consistency_loss = torch.mean((outputs[args.labeled_bs:]-ema_output)**2)

几个常见问题:

  • Q1:EMA是什么?Teacher模型不通过Loss反传更新梯度,那么其参数是怎么更新的?
  • A1:EMA即Exponential Moving Average,指数移动平均。通俗来讲的话,Teacher模型的参数由Student模型过去一段时间的参数共同决定,可以通过拷贝Student模型的参数并计算以得到。这么设计可以使Teacher模型反映Student在过去一段时间内的状态。
  • Q2:上面提到的Consistency为何会对半监督起到帮助?
  • A2:有很多种理解。比方说,如果Teacher与Student能对相同的样本得到一致的结果,说明网络当前的参数比较鲁棒泛化——加噪前后的结果一致,说明网络不太可能overfit到一些特殊特征,在这种情况下网络的预测结果一般是比较好的;另一种理解是,此时可以将Unlabeled Data视为一种特殊的有监督训练。无标注数据送入Student得到预测结果后,此时对应的GT为Teacher生成的伪标签。
  • Q3:为什么不将已标注数据送入Student中去和Teacher算一致性?
  • A3:已标注数据已经有真实GT了,因此没必要去用Student产生的(可能有误)的伪标签。
  • 15
    点赞
  • 54
    收藏
    觉得还不错? 一键收藏
  • 10
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 10
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值