在 PyTorch 中,transpose
和 permute
都是用于调整张量维度的函数。它们在很多深度学习任务中非常有用,尤其是在处理张量维度和进行矩阵操作时。
1. transpose
函数
transpose
函数用来交换张量的两个维度。它接受两个参数,即需要交换的两个维度的索引。这个操作不会改变张量的数据本身,只是改变了张量的视图。
语法
torch.transpose(input, dim0, dim1)
- input:输入的张量。
- dim0:要交换的第一个维度的索引。
- dim1:要交换的第二个维度的索引。
返回值
返回一个新的张量,其中 dim0
和 dim1
的维度被交换了。
示例
import torch
# 创建一个 2x3 的张量
x = torch.tensor([[1, 2, 3], [4, 5, 6]])
# 交换第0维和第1维
y = torch.transpose(x, 0, 1)
print("Original Tensor:")
print(x)
print("\nTransposed Tensor:")
print(y)
输出:
Original Tensor:
tensor([[1, 2, 3],
[4, 5, 6]])
Transposed Tensor:
tensor([[1, 4],
[2, 5],
[3, 6]])
在这个例子中,我们交换了张量 x
的第 0 维(行)和第 1 维(列),所以得到的张量 y
是一个 3x2 的张量。
2. permute
函数
permute
函数可以重新排列张量的所有维度。与 transpose
仅能交换两个维度不同,permute
允许你指定任意的维度顺序。
语法
torch.permute(input, dims)
- input:输入的张量。
- dims:一个包含维度索引的元组,表示新的维度顺序。
返回值
返回一个新的张量,维度顺序根据 dims
进行调整。
示例
import torch
# 创建一个 2x3x4 的张量
x = torch.randn(2, 3, 4)
# 调整维度顺序
y = x.permute(2, 0, 1)
print("Original Tensor Shape:", x.shape)
print("Permuted Tensor Shape:", y.shape)
输出:
Original Tensor Shape: torch.Size([2, 3, 4])
Permuted Tensor Shape: torch.Size([4, 2, 3])
在这个例子中,原始张量的形状是 (2, 3, 4)
,我们通过 permute
调整维度顺序为 (4, 2, 3)
。
详细说明
transpose
只交换两个维度。permute
可以自由地重新排列所有维度。例如,x.permute(2, 0, 1)
将维度2
、0
和1
进行了交换。
总结
transpose(dim0, dim1)
用于交换张量的两个维度,适用于二维及以上的张量。permute(dims)
可以重新排列张量的所有维度,适用于任意维度的张量。