Invariant Risk Minimization原理与最小实现
IRM(Invariant Risk Minimization)是2019年Martin Arjovsky等人提出的一种用于跨域图像分类的新方法,其提出的背景是当我们使用机器学习方法完成图片分类任务时,训练模型所使用的数据集与真实情况的数据集可能存在差别(数据集分布偏移),造成这种分布偏移的原因有很多,比如:数据选择的偏差(单一环境)、混淆因素等,该问题被称为跨域分类问题(注:跨域分类可能在其他的地方有其他的意思),目前大部分解决的方法是减小跨域分布偏差或者提取不变特征。而Martin提出的方法与之前很多跨域分类方法不同之处在于:为了提高机器学习的可解释性,并从根本上解决跨域分类问题,Martin考虑从数学方面推导出特征与标签预测的内在因果关系,即特征与标签之前存在与域无关的内在因果关系。
1、Invariant Risk Minimization原理
1.1提出问题
首先作者提出了一个问题,假设有一个SEM模型:
如上式所示,
X
1
X_1
X1是一组服从正态分布的数据,
Y
Y
Y是由
X
1
X_1
X1加上一个服从正态分布的白噪声构成,
X
2
X_2
X2是
Y
Y
Y加上一个服从正态分布的白噪声构成。
当使用最小二乘方法由
(
X
1
,
X
2
)
(X_1,X_2)
(X1,X2)对
Y
Y
Y进行预测时,设其预测模型为:
Y
^
e
=
X
1
e
α
1
^
+
X
2
e
α
2
^
\hat{Y}^e=X_1^e\hat{\alpha_1}+X_2^e\hat{\alpha_2}
Y^e=X1eα1^+X2eα2^,因此若对
X
1
X_1
X1与
Y
Y
Y的噪声乘以一个与环境有关的系数,那么当使用
X
1
X_1
X1与
X
2
X_2
X2预测
Y
Y
Y时,其根据算法是否能够识别出不变特征,回归系数会出现以下三种情况,因此作者的目标是得到第一种情况。
1.2提出模型
根据所总结的问题,作者做出如下定义,将模型分为两个部分,即数据表示
Φ
\Phi
Φ与分类器
ω
^
\hat{\omega}
ω^。
将定义转化为数学模型得IRM表达式,
但是由于上式是一个两层优化问题,因此将上式简化为单变量优化问题,
其中
Φ
\Phi
Φ成为不变预测器,其由两项组成,即经验风险最小项和不变风险最小项,而
λ
\lambda
λ作为平衡两项的一个超参数;由IRM到IRMv1的转变过程,作者还考虑了其他的因素,详细推导可看其论文第三章。
最终作者根据所提出的模型得到训练的损失函数表达式:
2、IRM最小实现
参照论文附录的基于Pytorch的IRM最小实现
import torch
from torch.autograd import grad
import numpy as np
import torchvision
def compute_penalty(losses, dummy_w):
# print(np.shape(losses[0::2]))
# print(dummy_w)
g1 = grad(losses[0::2].mean(), dummy_w, create_graph=True)[0]
g2 = grad(losses[1::2].mean(), dummy_w, create_graph=True)[0]
# print(g1*g2)
return (g1*g2).sum()
def example_1(n=10000, d=2, env=1):
x = torch.randn(n, d)*env
y = x + torch.randn(n, d)*env
z = y + torch.randn(n, d)
# z = y
# print(np.shape(torch.cat((x, z), 1))) # torch.Size([10000, 4])
return torch.cat((x, z), 1), y.sum(1, keepdim=True)
phi = torch.nn.Parameter(torch.ones(4, 1))
# print(phi)
dummy_w = torch.nn.Parameter(torch.Tensor([1.0]))
# print(dummy_w)
opt = torch.optim.SGD([phi], lr=1e-3)
mse = torch.nn.MSELoss(reduction="none")
environments = [example_1(env=0.1), example_1(env=1.0)]
# s = [[1, 2], [3, 4]]
# for i, j in s:
# print(i)
# print(j)
for iteration in range(50000):
error = 0
penalty = 0
for x_e, y_e in environments:
# print(np.shape(x_e))
# print(np.shape(y_e))
p = torch.randperm(len(x_e))
error_e = mse(x_e[p]@phi*dummy_w, y_e[p])
# error_e = mse(torch.matmul(x_e[p], phi) * dummy_w, y_e[p])
# print(np.shape(error_e))
penalty += compute_penalty(error_e, dummy_w)
error += error_e.mean()
# print(iteration)
# print(error_e.mean())
# print(error)
opt.zero_grad()
(1e-5 * error + penalty).backward()
opt.step()
if iteration % 1000 == 0:
print(phi)
参考文献
Arjovsky, M., et al. (2019). “Invariant Risk Minimization.”