对抗训练(Adversarial Training)和虚拟对抗训练(Virtual Adversarial Training),都可以作为正则化方法,来增强机器学习模型的鲁棒性。
一.对抗训练
GAN之父Ian Goodfellow在15年的ICLR中 第一次提出了对抗训练这个概念,简而言之,就是在原始输入样本 x 上加一个扰动 radv ,得到对抗样本后,用其进行训练。也就是说,问题可以被抽象成这么一个模型:
minθ−logP(y|x+radv;θ)
其中,y为gold label,θ 为模型参数。那扰动要如何计算呢?Goodfellow认为,神经网络由于其线性的特点,很容易受到线性扰动的攻击。
This linear behavior suggests that cheap, analytical perturbations of a linear model should also damage neural networks.
于是,他提出了 Fast Gradient Sign Method (FGSM) ,来计算输入样本的扰动。扰动可以被定义为:
radv=ϵ⋅sgn(▽xL(θ,x,y))
其中,sgn为符号函数,L为损失函数。Goodfellow发现,令ϵ=0.25,用这个扰动能给一个单层分类器造成99.9%的错误率。看似这个扰动的发现有点拍脑门,但是仔细想想,其实这个扰动计算的思想可以理解为:将输入样本向着损失上升的方向再进一步,得到的对抗样本就能造成更大的损失,提高模型的错误率。
Goodfellow还总结了对抗训练的两个作用:
1.提高模型应对恶意对抗样本时的鲁棒性;
2.作为一种regularization,减少overfitting,提高泛化能力。
对抗训练的损失函数:
minθE(x,y)∼D[maxradv∈SL(θ,x+radv,y)]
该公式分为两个部分,一个是内部损失函数的最大化,一个是外部经验风险的最小化。
内部max是为了找到worst-case的扰动,也就是攻击,其中,L 为损失函数,S 为扰动的范围空间。
外部min是为了基于该攻击方式,找到最鲁棒的模型参数,也就是防御,其中D是输入样本的分布。
目标是希望的是预测加入扰动后的结果和真实的一致,考虑损失函数为:
Ladv(xl,θ):=D[q(y|xl),p(y|xl+radv,θ)
这里的q(y|xl)为真实数据的预测,p(y|xl+radv,θ)为加入扰动的预测,D是衡量的标准,这里可以是KL散度。 而对于扰动,则是尽可能的在扰动上使得这两个预测最大化,即:
radv:=argmaxr;∥r∥leqϵDq(y|xl),p(y|xl+radv,θ)
这个radv无法用封闭形式表达,可以做一下近似处理:
radv=ϵg∥g∥2,whereg=▽xlD[h(y;yl),p(y|xl,θ)
就是在梯度下获取radv
这里的h(y;yl)为q(y|xl)的ont-hot表现形式。构造了对抗样本,那么对抗训练如何进行呢?
对抗训练的本质上就是让模型具有较强的鲁棒性,可以抵抗对抗样本的干扰,采用的方式就是生成这些数据,并且把这些数据加入到训练数据中。这样模型就会正视这些数据, 并且尽可能地拟合这些数据,最终完成了模型拟合,这些盲区也就覆盖住了。将对抗样本和原有数据一起进行训练,对抗样本产生的损失作为原损失的一部分, 即在不修改原模型结构的情况下增加模型的loss,产生正则化的效果。
此时目标函数可以表示为:
我们记模型的损失函数为J(θ;x;y),其中负梯度方向−∇Jx(θ;x;y)是模型的损失下降最快的方向,那么也就是说负梯度上模型优化最快, 为了使x^对模型的output产生最大的改变,正梯度方向也就是模型梯度下降最慢的方向定为扰动方向,也就是∇Jx(θ;x;y) 方向上x^=x+ϵsign(∇Jx(θ;x;y))。这里的ϵ为一个超参,控制扰动的选取的界限。
J˜(θ,x,y)=αJ(θ,x,y)+(1−α)J(θ,x+ϵsign(∇Jx(θ,x,y)))
VAT
虚拟对抗训练是对抗训练的一个变种,唯一的不同是正则化过程只需要无标签数据即可完成
1.数学模型
损失函数为:
其中,加号左半部分为带标签数据损失,右半部分为扰动部分损失,相当于增加的正则项数据。R()通过LDS()获求和求平均获取,LDS()的计算类似之前对抗训练部分模型,只是输入的是无标签的数据
2.快速确定rvadv值
rvadv的计算不能像原始对抗训练那样使用线性近似。因为对于预训练模型D[p(y|x,θ^),p(y|x∗+r,θ)] 在r=0时,存在最小值0,对于r的一阶微分也为0,梯度上提供不了有用的信息。
推倒过程自行查阅论文
作者采用文章采用幂迭代法和有限差分法 来计算rvadv
3.效果
在训练开始时,分类器预测同一簇中输入数据点的不同标签, 并且边界上的LDS非常高(训练刚开时边界)。算法优化模型在具有大LDS值的点周围平滑。随着训练的进行,模型的演变使得具有大LDS值的点上的标签预测受到附近标记输入的强烈影响。 这鼓励模型预测属于同一群集的点集的相同标签,这是在半监督学习中经常需要的。在迭代1000次后,边界已经分的很清晰了,同时预测的标签也是越来越精准。
如何看出效果:第一,从图看,LDS值大的点(相对灰色值的点)是分类分界点,说明正则项(VAT损失)是有作用的;第二,分类任务边界越来越平滑,也说明了模型学的越来越好。
代码说明
1.训练部分截图
训练输入:有标签训练数据的输入、标签;无标签训练集的输入
将有标签数据的loss加上无标签数据的loss,进行反向传播
2.vat_loss ,可结合论文分析
试验效果
待更新
参考
http://www.twistedwg.com/2018/12/04/VAT.html
https://zhuanlan.zhihu.com/p/96106267