Meta-Semi: A Meta-learning Approach for Semi-supervised Learning

论文地址

Abstract

半监督学习近年来得到广泛关注,然而他们往往引入多个超参数,但一般半监督问题有标注的数据是稀缺的,没有足够的标注用来调整这些超参数。本文提出了一种基于meta-learning的半监督算法,只用调整一种超参数就可以在很多半监督学习场景中获得不错的表现。我们定义了一个meta optimization problem,通过动态加权未标注样本的损失来最小化有标注的数据上的loss,未标注样本在训练时和伪标注相关联。由于直接求解meta问题计算量很大,本文提出了一种有效的方法来获得近似解。

1 Introduction

全监督的标注花费大量时间,因此半监督(SSL)受到了青睐。

在深度学习领域,许多成功的SSL方法通过应用无监督一致性正则化来合并无标注数据。具体来说是在无标签样本中加入一些扰动,然后强化原始数据和加入扰动的数据之间的模型预测结果的一致性。这种方法引入了多种超参数,但在实际中没有足够的样本来调整这些超参数。

SSL领域的另一挑战是如何高效地得到labeled data。一般的SSL算法在计算loss是都是把有标注的数据和无标注的数据分开成不同的项计算,因为无标注的数据没有监督啊,这样的话就没有充分利用有标记的数据。

本文提出了一种基于meta-learning的SSL算法,叫做Meta-Semi,来高效利用labeled data,然而只需要调整一种超参数就能获得比较好的性能。这种方法基于一个简单的直觉:*if the network is trained with correctly “pseudo-labeled” unannotated samples, the final loss on labeled data should be minimized.*具体来说,我们首先明确定义meta reweighting objective:为不同的伪标记样本找到最佳权重以训练网络,从而使标记数据的最终损失最小化。注意到该算法通过优化器求解计算量很大,我们提出了一种近似公式,在此基础上可以获得相似度结果。我们从理论上表明,one meta gradient step足以在每次训练迭代中获得近似解。最后我们提出一种动态权重方法用0-1的权重来为伪标注样本重新分配权重。理论分析表明,我们的方法可收敛到监督损失函数的驻点。

在image classification benchmarks上做了实验。

2 Related Work

3 Method

我们先用伪标注计算未标注数据的交叉熵损失,然后通过解决一个meta optimization problem来reweight每个未标注样本的loss,从而最小化标记数据监督下的loss。由于这个问题计算量,我们踢出一个近似求法只需要一个meta的梯度下降步骤就可能获得0-1的动态权重。

3.1 Meta Optimization Problem

假设我们用SGD做梯度下降。在每个iteration我们送入一些labeled样本 X = { ( x i , y i ) } \mathcal{X}=\{(x_i,y_i)\} X={(xi,yi)}和一些unlabeled样本 U = { ( u j , y ^ j ) } \mathcal U=\{(u_j,\hat y_j)\} U={(uj,y^j)}。我们使用MixUp进行数据增广,而不是直接使用这些样本,进行增广操作后的样本记作 X ~ \widetilde {\mathcal X} X U ~ \widetilde {\mathcal U} U

考虑到用参数 θ \theta θ训练深度网络。我们首先把一些未标注样本 u ~ j \widetilde {u}_j u j送入网络,生成预测值 p ( u ~ j ∣ θ ) p(\widetilde u_j|\theta) p(u jθ)。然后用伪标注 y ^ j \hat y_j y^j计算交叉熵损失 L ( y ^ j , p ( u ~ j ∣ θ ) ) L(\hat y_j,p(\widetilde u_j|\theta)) L(y^j,p(u jθ))。这些样本的损失之后通过 w j ∗ ∈ [ 0 , 1 ] w^*_j\in[0,1] wj[0,1]加权得到最终的损失函数:
L m e t a = 1 ∑ j = 1 ∣ u ~ ∣ w j ∗ ∑ j = 1 u ~ w j ∗ L ( y ^ j , p ( u ~ j ∣ θ ) ) . \mathcal L_{meta}=\frac{1}{\sum^{|\widetilde u|}_{j=1}w^*_j}\sum^{\widetilde u}_{j=1}w^*_jL(\hat y_j,p(\widetilde u_j|\theta)). Lmeta=j=1u wj1j=1u wjL(y^j,p(u jθ)).
不失一般性,假定当 ∑ j = 1 ∣ u ~ ∣ w j ∗ = 0 \sum^{|\widetilde u|}_{j=1}w^*_j=0 j=1u wj=0 L m e t a = 0 \mathcal L_{meta}=0 Lmeta=0。这个权重 w j ∗ w^*_j wj是通过从标记的数据上训练得来的。为了说明这一点,我们首先考虑用类似的加权损失训练网络:
θ ∗ ( w ) = arg min ⁡ θ ∑ j = 1 ∣ u ~ ∣ w j L ( y ^ j , p ( u ~ j ∣ θ ) ) \theta^*(w)=\argmin_{\theta}\sum^{|\widetilde u|}_{j=1}w_jL(\hat y_j,p(\widetilde u_j|\theta)) θ(w)=θargminj=1u wjL(y^j,p(u jθ))
其中 θ ∗ ( w ) \theta^*(w) θ(w)就是weighted loss的最优解,它是关于权重向量 w = [ w 1 , w 2 , ⋯   ] T w=[w_1,w_2,\cdots]^T w=[w1,w2,]T的函数。权重 w ∗ w^* w通过最小化labeled data X ~ \widetilde {\mathcal X} X θ ∗ ( w ) \theta^*(w) θ(w)的loss得到:
w ∗ = arg min ⁡ w j ∈ [ 0 , 1 ] , j = 1 , ⋯   , ∣ u ~ ∣ ∑ i = 1 ∣ X ~ ∣ L ( y ~ i , p ( x ~ i ∣ θ ∗ ( w ) ) ) w^*=\argmin_{w_j\in[0,1],j=1,\cdots,|\widetilde u|}\sum^{|\widetilde {\mathcal X}|}_{i=1}L(\widetilde y_i,p(\widetilde x_i|\theta^*(w))) w=wj[0,1],j=1,,u argmini=1X L(y i,p(x iθ(w)))
小结一下,也就是选择j个unlabeled data的样本,在labeled data中训练出j个权重,从而最小化unlabeled data的样本得到的loss值。

直观来说,labeled data的作用就是决定哪些伪标注样本应该被使用,哪些不该被使用,这样可以更好地利用监督信息。

3.2 Approximating the Meta Solution

在这里插入图片描述
3.4节证明了为什么我们的方法最终会收敛,感兴趣可以看看,关系到Lipschitz-smooth连续假设,同时作者做的实验也证明了这一点。

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值