域泛化与不变性风险最小化
论文链接:https://arxiv.org/abs/1907.02893
1. 域泛化
给定源数据集 E s = { e 1 s , ⋯ , e m s } \mathcal{E}^s=\{e_1^s,\cdots,e_m^s\} Es={e1s,⋯,ems},与目标数据集 E t = { e 1 t , ⋯ , e n t } \mathcal{E}^t=\{e_1^t,\cdots,e_n^t\} Et={e1t,⋯,ent},每个 e e e代表一个域或者环境。
同一环境下的样本与标签 ( X , Y ) (X,Y) (X,Y)服从同一分布;不同环境下的 ( X , Y ) (X,Y) (X,Y)服从不同分布,记作
( X i s , Y i s ) ∼ P i s ( X , Y ) i = 1 , ⋯ , m ; ( X j t , Y j t ) ∼ P j t ( X , Y ) j = 1 , ⋯ , n ; (1) (X_i^s,Y_i^s)\sim P_i^s(X,Y)\quad i=1,\cdots,m;\\ (X_j^t,Y_j^t)\sim P_j^t(X,Y)\quad j=1,\cdots,n;\tag{1} (Xis,Yis)∼Pis(X,Y)i=1,⋯,m;(Xjt,Yjt)∼Pjt(X,Y)j=1,⋯,n;(1)
其中的概率分布 P i s ( X , Y ) , P j t ( X , Y ) P_i^s(X,Y),P_j^t(X,Y) Pis(X,Y),Pjt(X,Y)各不相同。
域泛化可以理解为:通过学到 P i s P^s_i Pis中的分布,去尽可能模拟 P j t P_j^t Pjt的分布,思路的可行性是因为不同域中实际上存在公共信息,我们需要去学到这些公共信息
2. 不变性风险最小化
简单来讲,是指在训练集中的所有域下,存在某个风险最小化的共同最优解
我们将整个网络模型表示为: w ∘ Φ : X → Y w\circ \Phi:\mathcal{X}\to\mathcal{Y} w∘Φ:X→Y
其中
Φ
:
X
→
H
\Phi:\mathcal{X}\to\mathcal{H}
Φ:X→H为数据表示representation
,
w
:
H
→
Y
w:\mathcal H\to\mathcal Y
w:H→Y为不变性因子invariance
记域或环境 e e e中的风险函数为 R e ( w ∘ Φ ) R^e(w\circ \Phi) Re(w∘Φ),对于一般的学习模型,经验风险最小化表示为
min Φ , w ∑ e ∈ E s R e ( w ∘ Φ ) (ERM) \min_{\Phi,w}\quad\sum_{e\in\mathcal E^s}R^e(w\circ \Phi)\tag{ERM} Φ,wmine∈Es∑Re(w∘Φ)(ERM)
而不变性风险最小化表示为
min Φ , w ∑ e ∈ E s R e ( w ∘ Φ ) s . t . w ∈ ⋂ e ∈ E s arg min w R e ( w ∘ Φ ) (IRM) \begin{aligned} \min_{\Phi,w}\quad&\sum_{e\in\mathcal E^s}R^e(w\circ \Phi)\\ s.t.\quad&w\in\bigcap_{e\in\mathcal E^s}\arg\min_{w}R^e(w\circ\Phi) \end{aligned}\tag{IRM} Φ,wmins.t.e∈Es∑Re(w∘Φ)w∈e∈Es⋂argwminRe(w∘Φ)(IRM)
也就是在 ( E R M ) (ERM) (ERM)的基础上添加了约束条件
但是双优化问题在实际过程中是很难交给计算机处理的,尤其是在数据规模较大的情况下
根据个人理解,考虑到了一个非常通俗的数学知识
可微函数在某区域内的极值点一定是边界点或者驻点
所以第一步是将上述条件放宽为
min Φ , w ∑ e ∈ E s R e ( w ∘ Φ ) s . t . w ∈ ⋂ e ∈ E s { w : ∇ w R e ( w ∘ Φ ) = 0 } (2) \begin{aligned} \min_{\Phi,w}\quad&\sum_{e\in\mathcal E^s}R^e(w\circ \Phi)\\ s.t.\quad&w\in\bigcap_{e\in\mathcal E^s}\{w:\nabla_wR^e(w\circ \Phi)=\boldsymbol{0}\} \end{aligned}\tag{2} Φ,wmins.t.e∈Es∑Re(w∘Φ)w∈e∈Es⋂{w:∇wRe(w∘Φ)=0}(2)
第二步将约束条件换为惩罚项,改为
min Φ , w ∑ e ∈ E s R e ( w ∘ Φ ) + λ ∣ ∣ ∇ w R e ( w ∘ Φ ) ∣ ∣ 2 (3) \min_{\Phi,w}\quad\sum_{e\in\mathcal E^s}R^e(w\circ \Phi)+\lambda||\nabla_wR^e(w\circ \Phi)||^2\tag{3} Φ,wmine∈Es∑Re(w∘Φ)+λ∣∣∇wRe(w∘Φ)∣∣2(3)
最后固定 w = 1.0 w=1.0 w=1.0,得到论文中的实用公式
min Φ ∑ e ∈ E s R e ( Φ ) + λ ∣ ∣ ∇ w ∣ w = 1.0 R e ( w ∘ Φ ) ∣ ∣ 2 (IRMv1) \min_{\Phi}\quad\sum_{e\in\mathcal E^s}R^e(\Phi)+\lambda||\nabla_{w|w=1.0}R^e(w\circ \Phi)||^2\tag{IRMv1} Φmine∈Es∑Re(Φ)+λ∣∣∇w∣w=1.0Re(w∘Φ)∣∣2(IRMv1)
实验中也是用的这个公式,在colored_mnist
实验中,有代码
def penalty(logits, y):
scale = torch.tensor(1.).cuda().requires_grad_()
loss = mean_nll(logits * scale, y)
grad = autograd.grad(loss, [scale], create_graph=True)[0]
return torch.sum(grad**2)
原文有对惩罚项的形式做进一步探究(比如损失函数的凸性),讨论了什么时候惩罚项是有效的,满足预期需要
3. 一些思考
- 为什么固定 w = 1.0 w=1.0 w=1.0进行训练?
考虑 ( Φ , w ) → ( γ Φ , w / γ ) (\Phi,w)\to(\gamma\Phi,w/\gamma) (Φ,w)→(γΦ,w/γ),当 γ → 0 \gamma\to0 γ→0即也可能保证惩罚项趋于零,使得惩罚项对 R e ( Φ ) R^e(\Phi) Re(Φ)的约束不起作用
- w = 1.0 w=1.0 w=1.0是一个人工引入,会不会导致算法错过最优解?
可以将原来的表示改写为
w ∘ Φ = ( w ∘ Ψ − 1 ) ∘ ( Ψ ∘ Φ ) = w ~ ∘ Φ ~ (4) w\circ\Phi=(w\circ\Psi^{-1})\circ(\Psi\circ\Phi)=\tilde{w}\circ\tilde{\Phi}\tag{4} w∘Φ=(w∘Ψ−1)∘(Ψ∘Φ)=w~∘Φ~(4)
其中 w ~ = w ∘ Ψ − 1 , Φ ~ = Ψ ∘ Φ \tilde{w}=w\circ\Psi^{-1},\tilde{\Phi}=\Psi\circ\Phi w~=w∘Ψ−1,Φ~=Ψ∘Φ,有
∂ R ∂ w = ∂ R ∂ w ~ ∂ w ~ ∂ w = ∂ R ∂ w ~ (5) \frac{\partial R}{\partial w}=\frac{\partial R}{\partial\tilde{w}}\frac{\partial\tilde{w}}{\partial w}=\frac{\partial R}{\partial\tilde{w}}\tag{5} ∂w∂R=∂w~∂R∂w∂w~=∂w~∂R(5)
当存在可逆函数 Ψ \Psi Ψ时, R R R对 w w w或者 w ~ \tilde{w} w~求导,在数值上两者完全等价
求导是个人理解,原文好像没有,因为是复合多元函数求导,比矩阵更复杂,其中不太能把握的一点在于
w ~ ( H ) = w [ Ψ − 1 ( H ) ] ⇒ ∂ w ~ ∂ w = I ( 单位阵 ) (6) \tilde{w}(H)=w[\Psi^{-1}(H)]\Rightarrow\frac{\partial\tilde{w}}{\partial w}=I(单位阵)\tag{6} w~(H)=w[Ψ−1(H)]⇒∂w∂w~=I(单位阵)(6)
如果有问题,欢迎大家在评论区交流