PyTorch 中的 torch.eye()
函数
torch.eye()
是 PyTorch 提供的一个 创建单位矩阵(Identity Matrix) 的函数,返回一个对角线上全为 1,其他元素为 0 的方阵(即单位矩阵)。
1. torch.eye()
的基本语法
torch.eye(n, m=None, dtype=None, device=None, requires_grad=False)
参数说明
n
(int):行数,即矩阵的大小(n × m
)。m
(int,可选):列数。默认为None
,表示创建n × n
的方阵(即正方形单位矩阵)。dtype
(可选):数据类型,如torch.float32
、torch.int64
等。device
(可选):指定计算设备,如cuda
(GPU)或cpu
。requires_grad
(bool,可选):是否需要计算梯度(用于自动求导)。
2. torch.eye()
的示例
(1) 创建 n × n
单位矩阵
如果 m=None
,则创建 n × n
的单位矩阵。
import torch
I = torch.eye(3)
print(I)
输出
tensor([[1., 0., 0.],
[0., 1., 0.],
[0., 0., 1.]])
(2) 创建 n × m
矩阵
如果 m
不是 None
,则创建 n × m
的矩阵,其中主对角线仍然是 1,超出部分补 0。
I = torch.eye(3, 4) # 3 行 4 列
print(I)
输出
tensor([[1., 0., 0., 0.],
[0., 1., 0., 0.],
[0., 0., 1., 0.]])
(3) 指定数据类型 dtype
可以使用 dtype
指定矩阵的数据类型,如 torch.int
、torch.float
等。
I = torch.eye(3, dtype=torch.int)
print(I)
输出
tensor([[1, 0, 0],
[0, 1, 0],
[0, 0, 1]], dtype=torch.int32)
(4) 在 GPU 上创建单位矩阵
可以使用 device='cuda'
在 GPU 上创建单位矩阵。
I = torch.eye(3, device='cuda')
print(I.device) # 输出: cuda:0
(5) requires_grad=True
用于梯度计算
如果 requires_grad=True
,那么 PyTorch 将会跟踪该张量的计算图,并支持自动求导。
I = torch.eye(3, requires_grad=True)
print(I.requires_grad) # 输出: True
3. torch.eye()
在机器学习中的应用
(1) 生成 one-hot
编码
单位矩阵可用于 one-hot 编码,因为 torch.eye(n)
生成的矩阵每行正好是一个 one-hot 向量。
labels = torch.tensor([0, 2, 1]) # 3 个类别
one_hot = torch.eye(3)[labels] # 选择索引对应的 one-hot 向量
print(one_hot)
输出
tensor([[1., 0., 0.],
[0., 0., 1.],
[0., 1., 0.]])
这里
torch.eye(3)[labels]
通过索引获取one-hot
编码。
(2) 用于神经网络的单位矩阵初始化
在某些神经网络层中,初始化权重为单位矩阵可以帮助模型稳定训练,例如 RNN
、LSTM
。
rnn_weight = torch.eye(10, 10) # 10×10 单位矩阵
(3) 计算矩阵逆(torch.inverse()
)
如果一个矩阵是可逆的,它的逆矩阵乘以自身应该得到单位矩阵。
A = torch.tensor([[4., 7.], [2., 6.]])
A_inv = torch.inverse(A) # 计算逆矩阵
I = torch.mm(A, A_inv) # 矩阵乘法
print(I) # 结果接近单位矩阵
4. torch.eye()
vs torch.ones()
vs torch.zeros()
函数 | 作用 | 示例 |
---|---|---|
torch.eye(n, m) | 生成单位矩阵 | torch.eye(3) |
torch.ones(n, m) | 生成全 1 矩阵 | torch.ones(3, 3) |
torch.zeros(n, m) | 生成全 0 矩阵 | torch.zeros(3, 3) |
5. 总结
功能 | 代码 |
---|---|
创建 3×3 单位矩阵 | torch.eye(3) |
创建 3×4 矩阵 | torch.eye(3, 4) |
指定数据类型 dtype | torch.eye(3, dtype=torch.float64) |
在 GPU 上创建 | torch.eye(3, device='cuda') |
用于 one-hot 编码 | torch.eye(n)[labels] |
torch.eye()
在 PyTorch 神经网络初始化、one-hot 编码、矩阵计算 等任务中非常有用,是创建单位矩阵的最佳选择。