IRM最小实现代码详解

在上一篇的基础上,详细解释IRM代码中每一个函数的用法与作用。

import torch
from torch.autograd import grad
import numpy as np
import torchvision


def compute_penalty(losses, dummy_w):
    # print(np.shape(losses[0::2]))
    # print(dummy_w)
    g1 = grad(losses[0::2].mean(), dummy_w, create_graph=True)[0]
    g2 = grad(losses[1::2].mean(), dummy_w, create_graph=True)[0]
    # print(g1*g2)
    return (g1*g2).sum()


def example_1(n=10000, d=2, env=1):
    x = torch.randn(n, d)*env
    y = x + torch.randn(n, d)*env
    z = y + torch.randn(n, d)
    # z = y
    # print(np.shape(torch.cat((x, z), 1))) # torch.Size([10000, 4])
    return torch.cat((x, z), 1), y.sum(1, keepdim=True)


phi = torch.nn.Parameter(torch.ones(4, 1))
# print(phi)
dummy_w = torch.nn.Parameter(torch.Tensor([1.0]))
# print(dummy_w)
opt = torch.optim.SGD([phi], lr=1e-3)
mse = torch.nn.MSELoss(reduction="none")

environments = [example_1(env=0.1), example_1(env=1.0)]

# s = [[1, 2], [3, 4]]
# for i, j in s:
#     print(i)
#     print(j)
for iteration in range(50000):
    error = 0
    penalty = 0
    for x_e, y_e in environments:
        # print(np.shape(x_e))
        # print(np.shape(y_e))
        p = torch.randperm(len(x_e))
        error_e = mse(x_e[p]@phi*dummy_w, y_e[p])
        # error_e = mse(torch.matmul(x_e[p], phi) * dummy_w, y_e[p])
        # print(np.shape(error_e))
        penalty += compute_penalty(error_e, dummy_w)
        error += error_e.mean()
        # print(iteration)
        # print(error_e.mean())

        # print(error)

    opt.zero_grad()
    (1e-5 * error + penalty).backward()
    opt.step()
    if iteration % 1000 == 0:
        print(phi)

torch.nn.Parameter

torch.nn.Parameter是一种用于定义模型参数的Tensor,当其被指定为模型的属性时,它们会自动添加到参数列表中(PS:目前最新的API调用是torch.nn.parameter.Paramater)。其包含两个形参,分别为data和require_grad,data是所定义参数的值或者tensor,require_grad是设置该参数是否需要梯度,该形参可以提高计算效率。

phi = torch.nn.Parameter(torch.ones(4, 1))
dummy_w = torch.nn.Parameter(torch.Tensor([1.0]))

首先定义两个参数,第一个是 Φ \Phi Φ,其形状是[4,1],其原因是所定义的输入 X 1 X_1 X1 X 2 X_2 X2的维度均为2,其次定义了 ω \omega ω,但是根据IRMv1,该参数是一个常数,因此这里将其设置为1.0

torch.optim.SGD

torch.optim.SGD是Pytorch自带的优化器,SGD是随机梯度下降法。 torch.optim.SGD有六个参数,
在这里插入图片描述

opt = torch.optim.SGD([phi], lr=1e-3)

该语句是定义了随机梯度下降优化器,优化参数是 Φ \Phi Φ,学习率是0.003

torch.nn.MSELoss

torch.nn.MSELoss是均方误差函数,其详细资料如下图
在这里插入图片描述

mse = torch.nn.MSELoss(reduction="none")

当reduction为none时,返回值为一个序列。

example_1函数

根据论文所提出的例子1所定义的数据,torch.randn生成数量为n,维度为d的服从正态分布的数据

def example_1(n=10000, d=2, env=1):
    x = torch.randn(n, d)*env
    y = x + torch.randn(n, d)*env
    z = y + torch.randn(n, d)
    # z = y
    # print(np.shape(torch.cat((x, z), 1))) # torch.Size([10000, 4])
    return torch.cat((x, z), 1), y.sum(1, keepdim=True)
  • 1
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值