【PyTorch】torch.eye() 函数:创建单位矩阵

PyTorch 中的 torch.eye() 函数

torch.eye() 是 PyTorch 提供的一个 创建单位矩阵(Identity Matrix) 的函数,返回一个对角线上全为 1,其他元素为 0 的方阵(即单位矩阵)。


1. torch.eye() 的基本语法

torch.eye(n, m=None, dtype=None, device=None, requires_grad=False)

参数说明

  • n(int):行数,即矩阵的大小(n × m)。
  • m(int,可选):列数。默认为 None,表示创建 n × n 的方阵(即正方形单位矩阵)。
  • dtype(可选):数据类型,如 torch.float32torch.int64 等。
  • device(可选):指定计算设备,如 cuda(GPU)或 cpu
  • requires_grad(bool,可选):是否需要计算梯度(用于自动求导)。

2. torch.eye() 的示例

(1) 创建 n × n 单位矩阵

如果 m=None,则创建 n × n 的单位矩阵。

import torch

I = torch.eye(3)
print(I)

输出

tensor([[1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 1.]])

(2) 创建 n × m 矩阵

如果 m 不是 None,则创建 n × m 的矩阵,其中主对角线仍然是 1,超出部分补 0。

I = torch.eye(3, 4)  # 3 行 4 列
print(I)

输出

tensor([[1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.]])

(3) 指定数据类型 dtype

可以使用 dtype 指定矩阵的数据类型,如 torch.inttorch.float 等。

I = torch.eye(3, dtype=torch.int)
print(I)

输出

tensor([[1, 0, 0],
        [0, 1, 0],
        [0, 0, 1]], dtype=torch.int32)

(4) 在 GPU 上创建单位矩阵

可以使用 device='cuda' 在 GPU 上创建单位矩阵。

I = torch.eye(3, device='cuda')
print(I.device)  # 输出: cuda:0

(5) requires_grad=True 用于梯度计算

如果 requires_grad=True,那么 PyTorch 将会跟踪该张量的计算图,并支持自动求导。

I = torch.eye(3, requires_grad=True)
print(I.requires_grad)  # 输出: True

3. torch.eye() 在机器学习中的应用

(1) 生成 one-hot 编码

单位矩阵可用于 one-hot 编码,因为 torch.eye(n) 生成的矩阵每行正好是一个 one-hot 向量。

labels = torch.tensor([0, 2, 1])  # 3 个类别
one_hot = torch.eye(3)[labels]  # 选择索引对应的 one-hot 向量
print(one_hot)

输出

tensor([[1., 0., 0.],
        [0., 0., 1.],
        [0., 1., 0.]])

这里 torch.eye(3)[labels] 通过索引获取 one-hot 编码。


(2) 用于神经网络的单位矩阵初始化

在某些神经网络层中,初始化权重为单位矩阵可以帮助模型稳定训练,例如 RNNLSTM

rnn_weight = torch.eye(10, 10)  # 10×10 单位矩阵

(3) 计算矩阵逆(torch.inverse()

如果一个矩阵是可逆的,它的逆矩阵乘以自身应该得到单位矩阵。

A = torch.tensor([[4., 7.], [2., 6.]])
A_inv = torch.inverse(A)  # 计算逆矩阵
I = torch.mm(A, A_inv)  # 矩阵乘法
print(I)  # 结果接近单位矩阵

4. torch.eye() vs torch.ones() vs torch.zeros()

函数作用示例
torch.eye(n, m)生成单位矩阵torch.eye(3)
torch.ones(n, m)生成全 1 矩阵torch.ones(3, 3)
torch.zeros(n, m)生成全 0 矩阵torch.zeros(3, 3)

5. 总结

功能代码
创建 3×3 单位矩阵torch.eye(3)
创建 3×4 矩阵torch.eye(3, 4)
指定数据类型 dtypetorch.eye(3, dtype=torch.float64)
在 GPU 上创建torch.eye(3, device='cuda')
用于 one-hot 编码torch.eye(n)[labels]

torch.eye() 在 PyTorch 神经网络初始化、one-hot 编码、矩阵计算 等任务中非常有用,是创建单位矩阵的最佳选择。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

彬彬侠

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值