域泛化与不变性风险最小化

域泛化与不变性风险最小化

论文链接:https://arxiv.org/abs/1907.02893

1. 域泛化

给定源数据集 E s = { e 1 s , ⋯   , e m s } \mathcal{E}^s=\{e_1^s,\cdots,e_m^s\} Es={e1s,,ems},与目标数据集 E t = { e 1 t , ⋯   , e n t } \mathcal{E}^t=\{e_1^t,\cdots,e_n^t\} Et={e1t,,ent},每个 e e e代表一个域或者环境。

同一环境下的样本与标签 ( X , Y ) (X,Y) (X,Y)服从同一分布;不同环境下的 ( X , Y ) (X,Y) (X,Y)服从不同分布,记作

( X i s , Y i s ) ∼ P i s ( X , Y ) i = 1 , ⋯   , m ; ( X j t , Y j t ) ∼ P j t ( X , Y ) j = 1 , ⋯   , n ; (1) (X_i^s,Y_i^s)\sim P_i^s(X,Y)\quad i=1,\cdots,m;\\ (X_j^t,Y_j^t)\sim P_j^t(X,Y)\quad j=1,\cdots,n;\tag{1} (Xis,Yis)Pis(X,Y)i=1,,m;(Xjt,Yjt)Pjt(X,Y)j=1,,n;(1)

其中的概率分布 P i s ( X , Y ) , P j t ( X , Y ) P_i^s(X,Y),P_j^t(X,Y) Pis(X,Y),Pjt(X,Y)各不相同。

域泛化可以理解为:通过学到 P i s P^s_i Pis中的分布,去尽可能模拟 P j t P_j^t Pjt的分布,思路的可行性是因为不同域中实际上存在公共信息,我们需要去学到这些公共信息

2. 不变性风险最小化

简单来讲,是指在训练集中的所有域下,存在某个风险最小化的共同最优解

我们将整个网络模型表示为: w ∘ Φ : X → Y w\circ \Phi:\mathcal{X}\to\mathcal{Y} wΦ:XY

其中 Φ : X → H \Phi:\mathcal{X}\to\mathcal{H} Φ:XH为数据表示representation w : H → Y w:\mathcal H\to\mathcal Y w:HY为不变性因子invariance

记域或环境 e e e中的风险函数为 R e ( w ∘ Φ ) R^e(w\circ \Phi) Re(wΦ),对于一般的学习模型,经验风险最小化表示为

min ⁡ Φ , w ∑ e ∈ E s R e ( w ∘ Φ ) (ERM) \min_{\Phi,w}\quad\sum_{e\in\mathcal E^s}R^e(w\circ \Phi)\tag{ERM} Φ,wmineEsRe(wΦ)(ERM)

而不变性风险最小化表示为

min ⁡ Φ , w ∑ e ∈ E s R e ( w ∘ Φ ) s . t . w ∈ ⋂ e ∈ E s arg ⁡ min ⁡ w R e ( w ∘ Φ ) (IRM) \begin{aligned} \min_{\Phi,w}\quad&\sum_{e\in\mathcal E^s}R^e(w\circ \Phi)\\ s.t.\quad&w\in\bigcap_{e\in\mathcal E^s}\arg\min_{w}R^e(w\circ\Phi) \end{aligned}\tag{IRM} Φ,wmins.t.eEsRe(wΦ)weEsargwminRe(wΦ)(IRM)

也就是在 ( E R M ) (ERM) (ERM)的基础上添加了约束条件

但是双优化问题在实际过程中是很难交给计算机处理的,尤其是在数据规模较大的情况下

根据个人理解,考虑到了一个非常通俗的数学知识

可微函数在某区域内的极值点一定是边界点或者驻点

所以第一步是将上述条件放宽为

min ⁡ Φ , w ∑ e ∈ E s R e ( w ∘ Φ ) s . t . w ∈ ⋂ e ∈ E s { w : ∇ w R e ( w ∘ Φ ) = 0 } (2) \begin{aligned} \min_{\Phi,w}\quad&\sum_{e\in\mathcal E^s}R^e(w\circ \Phi)\\ s.t.\quad&w\in\bigcap_{e\in\mathcal E^s}\{w:\nabla_wR^e(w\circ \Phi)=\boldsymbol{0}\} \end{aligned}\tag{2} Φ,wmins.t.eEsRe(wΦ)weEs{w:wRe(wΦ)=0}(2)

第二步将约束条件换为惩罚项,改为

min ⁡ Φ , w ∑ e ∈ E s R e ( w ∘ Φ ) + λ ∣ ∣ ∇ w R e ( w ∘ Φ ) ∣ ∣ 2 (3) \min_{\Phi,w}\quad\sum_{e\in\mathcal E^s}R^e(w\circ \Phi)+\lambda||\nabla_wR^e(w\circ \Phi)||^2\tag{3} Φ,wmineEsRe(wΦ)+λ∣∣wRe(wΦ)2(3)

最后固定 w = 1.0 w=1.0 w=1.0,得到论文中的实用公式

min ⁡ Φ ∑ e ∈ E s R e ( Φ ) + λ ∣ ∣ ∇ w ∣ w = 1.0 R e ( w ∘ Φ ) ∣ ∣ 2 (IRMv1) \min_{\Phi}\quad\sum_{e\in\mathcal E^s}R^e(\Phi)+\lambda||\nabla_{w|w=1.0}R^e(w\circ \Phi)||^2\tag{IRMv1} ΦmineEsRe(Φ)+λ∣∣ww=1.0Re(wΦ)2(IRMv1)

实验中也是用的这个公式,在colored_mnist实验中,有代码

def penalty(logits, y):
	scale = torch.tensor(1.).cuda().requires_grad_()
    loss = mean_nll(logits * scale, y)
    grad = autograd.grad(loss, [scale], create_graph=True)[0]
    return torch.sum(grad**2)

原文有对惩罚项的形式做进一步探究(比如损失函数的凸性),讨论了什么时候惩罚项是有效的,满足预期需要

3. 一些思考

  • 为什么固定 w = 1.0 w=1.0 w=1.0进行训练?

考虑 ( Φ , w ) → ( γ Φ , w / γ ) (\Phi,w)\to(\gamma\Phi,w/\gamma) (Φ,w)(γΦ,w/γ),当 γ → 0 \gamma\to0 γ0即也可能保证惩罚项趋于零,使得惩罚项对 R e ( Φ ) R^e(\Phi) Re(Φ)的约束不起作用

  • w = 1.0 w=1.0 w=1.0是一个人工引入,会不会导致算法错过最优解?

可以将原来的表示改写为

w ∘ Φ = ( w ∘ Ψ − 1 ) ∘ ( Ψ ∘ Φ ) = w ~ ∘ Φ ~ (4) w\circ\Phi=(w\circ\Psi^{-1})\circ(\Psi\circ\Phi)=\tilde{w}\circ\tilde{\Phi}\tag{4} wΦ=(wΨ1)(ΨΦ)=w~Φ~(4)

其中 w ~ = w ∘ Ψ − 1 , Φ ~ = Ψ ∘ Φ \tilde{w}=w\circ\Psi^{-1},\tilde{\Phi}=\Psi\circ\Phi w~=wΨ1,Φ~=ΨΦ,有

∂ R ∂ w = ∂ R ∂ w ~ ∂ w ~ ∂ w = ∂ R ∂ w ~ (5) \frac{\partial R}{\partial w}=\frac{\partial R}{\partial\tilde{w}}\frac{\partial\tilde{w}}{\partial w}=\frac{\partial R}{\partial\tilde{w}}\tag{5} wR=w~Rww~=w~R(5)

当存在可逆函数 Ψ \Psi Ψ时, R R R w w w或者 w ~ \tilde{w} w~求导,在数值上两者完全等价

求导是个人理解,原文好像没有,因为是复合多元函数求导,比矩阵更复杂,其中不太能把握的一点在于

w ~ ( H ) = w [ Ψ − 1 ( H ) ] ⇒ ∂ w ~ ∂ w = I ( 单位阵 ) (6) \tilde{w}(H)=w[\Psi^{-1}(H)]\Rightarrow\frac{\partial\tilde{w}}{\partial w}=I(单位阵)\tag{6} w~(H)=w[Ψ1(H)]ww~=I(单位阵)(6)

如果有问题,欢迎大家在评论区交流

  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值