半监督学习之Mean teachers

半监督学习Mean teachers

在这里插入图片描述
网络整体的架构包括两个部分student model和teacher model:

  1. student model的网络参数通过学习,梯度下降获得。

  2. teacher model的网络参数通过student model的网络参数的moving average得到。

student model的网络参数更新方法:

通过损失函数的梯度下降更新参数得到。
其中损失函数包括两个部分:

第一部分是有监督损失函数,保证有标签训练数据拟合;

第二部分是无监督损失函数,主要是保证student model的预测结果和teacher model的预测结果尽量的相似。因为teacher model的参数是student model的网络参数的moving average,所以,对于任何新来的数据,预测结果都不应该有太大的抖动。
如果如果模型是正确的,那么前后两个模型的预测标签应该是接近的,并且变化较小的,那么使模型向使两个模型预测结果接近的方向移动,就是向groudtruth model移动。

teacher model的网络参数的更新方法:

通过student model网络参数的moving average得到
θ t ′ = α θ t − 1 ′ + ( 1 − α ) θ t \theta_{t}^{\prime}= \alpha \theta _{t-1}^{\prime}+(1- \alpha)\theta _{t} θt=αθt1+(1α)θt

基本流程

假设有一批训练样本X1,X2,其中X1使有标签数据(对应标签是z1),X2使无标签数据。具体的训练过程如下:

  1. 把这一批样本作为student网络输入,然后分别得到输出的标签:ys1,ys2;

  2. 构造对于有标签数据X1的损失函数,有标签分类损失函数L1(z1,ys1);

  3. 把这批数据作为teacher model的输入,得到输出的标签yt1,yt2;

  4. 构造无监督损失函数L2,论文中采用MSE损失函数: J ( x , θ ) = E x , η ′ , η [ ∣ ∣ f ( x , θ ′ , η ′ ) − f ( x , θ , η ) ∣ ∣ 2 ] J(x, \theta)=E_{x, \eta ^{\prime}}, \eta \left[ ||f(x, \theta ^{\prime}, \eta ^{\prime})-f(x, \theta , \eta)||^{2}\right] J(x,θ)=Ex,η,η[f(x,θ,η)f(x,θ,η)2]

  5. 总损失函数L1+L2梯度下降,更新student model的网络参数,通过moving average更新teacher model的网络参数 θ t ′ = α θ t − 1 ′ + ( 1 − α ) θ t \theta_{t}^{\prime}= \alpha \theta _{t-1}^{\prime}+(1- \alpha)\theta _{t} θt=αθt1+(1α)θt

  • 2
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值