在 PyTorch 中,.T 和 .t() 都是用于对张量进行转置操作的,但它们有一些关键的区别。以下是它们的原理和区别:
.T 属性
.T 是一个属性,是 .permute 函数的简化版本。适用于所有维度的张量。
对于二维张量(矩阵),.T 将矩阵进行转置操作,即交换行和列。
对于高维张量,.T 将所有维度进行反转。例如,对于一个形状为 (2, 3, 4) 的张量,.T 的结果形状将是 (4, 3, 2)。
示例:
import torch
# 二维张量
matrix = torch.tensor([[1, 2], [3, 4]])
print(matrix.T)
# 输出:
# tensor([[1, 3],
# [2, 4]])
# 三维张量
tensor = torch.randn(2, 3, 4)
print(tensor.T.shape)
# 输出: torch.Size([4, 3, 2])
.t() 方法
.t() 是一个方法,只适用于二维张量(矩阵)。
它对二维张量进行转置操作,即交换行和列。
如果尝试对高维张量使用 .t() 方法,会抛出错误。
示例:
import torch
# 二维张量
matrix = torch.tensor([[1, 2], [3, 4]])
print(matrix.t())
# 输出:
# tensor([[1, 3],
# [2, 4]])
# 三维张量
tensor = torch.randn(2, 3, 4)
try:
print(tensor.t())
except RuntimeError as e:
print(e)
# 输出: t() expects a tensor with <= 2 dimensions, but self is 3D
总结
.T 属性:适用于任意维度的张量,对二维张量进行标准转置,对高维张量将所有维度反转。
.t() 方法:仅适用于二维张量,对其进行标准转置。如果在高维张量上调用会抛出错误。