IRM(Invariant Risk Minimization)原理与最小实现


IRM(Invariant Risk Minimization)是2019年Martin Arjovsky等人提出的一种用于跨域图像分类的新方法,其提出的背景是当我们使用机器学习方法完成图片分类任务时,训练模型所使用的数据集与真实情况的数据集可能存在差别(数据集分布偏移),造成这种分布偏移的原因有很多,比如:数据选择的偏差(单一环境)、混淆因素等,该问题被称为跨域分类问题(注:跨域分类可能在其他的地方有其他的意思),目前大部分解决的方法是减小跨域分布偏差或者提取不变特征。而Martin提出的方法与之前很多跨域分类方法不同之处在于:为了提高机器学习的可解释性,并从根本上解决跨域分类问题,Martin考虑从数学方面推导出特征与标签预测的内在因果关系,即特征与标签之前存在与域无关的内在因果关系。

1、Invariant Risk Minimization原理

1.1提出问题

首先作者提出了一个问题,假设有一个SEM模型:
一个特征维度为2,输出维度为1的模型
如上式所示, X 1 X_1 X1是一组服从正态分布的数据, Y Y Y是由 X 1 X_1 X1加上一个服从正态分布的白噪声构成, X 2 X_2 X2 Y Y Y加上一个服从正态分布的白噪声构成。

当使用最小二乘方法由 ( X 1 , X 2 ) (X_1,X_2) (X1,X2) Y Y Y进行预测时,设其预测模型为: Y ^ e = X 1 e α 1 ^ + X 2 e α 2 ^ \hat{Y}^e=X_1^e\hat{\alpha_1}+X_2^e\hat{\alpha_2} Y^e=X1eα1^+X2eα2^,因此若对 X 1 X_1 X1 Y Y Y的噪声乘以一个与环境有关的系数,那么当使用 X 1 X_1 X1 X 2 X_2 X2预测 Y Y Y时,其根据算法是否能够识别出不变特征,回归系数会出现以下三种情况,因此作者的目标是得到第一种情况。
在这里插入图片描述

1.2提出模型

根据所总结的问题,作者做出如下定义,将模型分为两个部分,即数据表示 Φ \Phi Φ与分类器 ω ^ \hat{\omega} ω^
在这里插入图片描述
将定义转化为数学模型得IRM表达式,
在这里插入图片描述
但是由于上式是一个两层优化问题,因此将上式简化为单变量优化问题,
在这里插入图片描述
其中 Φ \Phi Φ成为不变预测器,其由两项组成,即经验风险最小项和不变风险最小项,而 λ \lambda λ作为平衡两项的一个超参数;由IRM到IRMv1的转变过程,作者还考虑了其他的因素,详细推导可看其论文第三章。

最终作者根据所提出的模型得到训练的损失函数表达式:
在这里插入图片描述

2、IRM最小实现

参照论文附录的基于Pytorch的IRM最小实现

import torch
from torch.autograd import grad
import numpy as np
import torchvision


def compute_penalty(losses, dummy_w):
    # print(np.shape(losses[0::2]))
    # print(dummy_w)
    g1 = grad(losses[0::2].mean(), dummy_w, create_graph=True)[0]
    g2 = grad(losses[1::2].mean(), dummy_w, create_graph=True)[0]
    # print(g1*g2)
    return (g1*g2).sum()


def example_1(n=10000, d=2, env=1):
    x = torch.randn(n, d)*env
    y = x + torch.randn(n, d)*env
    z = y + torch.randn(n, d)
    # z = y
    # print(np.shape(torch.cat((x, z), 1))) # torch.Size([10000, 4])
    return torch.cat((x, z), 1), y.sum(1, keepdim=True)


phi = torch.nn.Parameter(torch.ones(4, 1))
# print(phi)
dummy_w = torch.nn.Parameter(torch.Tensor([1.0]))
# print(dummy_w)
opt = torch.optim.SGD([phi], lr=1e-3)
mse = torch.nn.MSELoss(reduction="none")

environments = [example_1(env=0.1), example_1(env=1.0)]

# s = [[1, 2], [3, 4]]
# for i, j in s:
#     print(i)
#     print(j)
for iteration in range(50000):
    error = 0
    penalty = 0
    for x_e, y_e in environments:
        # print(np.shape(x_e))
        # print(np.shape(y_e))
        p = torch.randperm(len(x_e))
        error_e = mse(x_e[p]@phi*dummy_w, y_e[p])
        # error_e = mse(torch.matmul(x_e[p], phi) * dummy_w, y_e[p])
        # print(np.shape(error_e))
        penalty += compute_penalty(error_e, dummy_w)
        error += error_e.mean()
        # print(iteration)
        # print(error_e.mean())

        # print(error)

    opt.zero_grad()
    (1e-5 * error + penalty).backward()
    opt.step()
    if iteration % 1000 == 0:
        print(phi)

参考文献

Arjovsky, M., et al. (2019). “Invariant Risk Minimization.”

  • 12
    点赞
  • 28
    收藏
    觉得还不错? 一键收藏
  • 5
    评论
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值