代码:https://github.com/CuriousAI/mean-teacher
出处:NIPS2017
一、背景
本文的目标是在学生模型和教师模型完全相同的情况下,从学生模型中逐步形成一个表现更好的教师模型。
首先,由于模型的 softmax 输出通常不能获得在训练数据之外的准确预测,所以可以考虑在训练数据中添加一些噪声来缓解,有噪声的教师模型可以产生更准确的结果,如图 1d 所示
其次, π \pi π 模型 [13] 通过时间集成来进一步的改进教师模型。但每个目标每次迭代值更新一次,学习到的信息注入非常缓慢。
故本文提出了 Mean Teacher,通过平均模型的权重而不是预测的结果来更新 Teacher 模型。
二、方法
本文方法结构框架如图 2 所示
Teacher model 的模型参数是通过 Student model 的模型参数指数移动平均来获得的。
总体过程:
- 假设有带标签的数据 labeled data x 1 x_1 x1 和无标签的数据 unlabeled data x 2 x_2 x2,对种数据分别添加噪声
- 首先,将有标签的数据 x 1 x_1 x1 输入学生模型,得到预测结果 y 1 s y^s_1 y1s,并计算交叉熵损失为 loss1
- 然后,将无标签的数据 x 2 x_2 x2 输入学生模型,得到预测结果 y 2 s y^s_2 y2s,同时,将无标签的数据 x 2 x_2 x2 输入教师模型,得到预测结果 y 2 t y^t_2 y2t,求两个预测结果的损失 loss2,即 J ( θ ) J(\theta) J(θ),也就是求两个输出的均方误差
- 接着,计算总损失 loss1 + loss2
- 最后,学生模型的权重通过梯度反向传播更新,教师模型的权重通过指数移动平均来更新
两个分布的一致性的程度 J ( θ ) J(\theta) J(θ) 定义为学生模型的预测(权重为 θ \theta θ,噪声为 η \eta η)与教师模型的预测(权重为 θ ′ \theta' θ′,噪声为 η ′ \eta' η′)之间的期望差距:
定义教师网络第 t 个 training step 的参数 θ t ′ \theta_t' θt′ 为第 t-1 个 traing step 的参数加上当前学生网络的参数, α \alpha α 为系数