transpose
和permute
都是转置函数,可以交换Tensor的维度。
1. transpose
torch.transpose(input, dim0, dim1, out=None)
→
\rightarrow
→Tensor
transpose
函数用于交换input
的维度dim0
和dim1
,只能交换两个维度,且dim0和dim1的参数位置没有顺序而言。
例子:
a = torch.arange(6).reshape((2, 3))
a, a.transpose(1, 0) # shape从(2, 3)变成(3, 2)
Out:tensor([[0, 1, 2],
[3, 4, 5]])
tensor([[0, 3],
[1, 4],
[2, 5]])
a.transpose(1, 0)
与a.transpose(0, 1)
相同,都是将第0个维度和第1个维度交换。
a.transpose(1, 0) == a.transpose(0, 1)
Out: tensor([[True, True],
[True, True],
[True, True]])
作用于高维:
b = torch.arange(24).reshape((2, 3, 4))
b.transpose(1, 2) # shape从(2, 3, 4)变成(2, 4, 3)
b.transpose(0, 1, 2) # 报错,只能输入两个维度进行交换
2. permute
torch.permute(input, dims)
→
\rightarrow
→ Tensor
permute
函数相当于把input
的各个维度进行了重排列,可以一次性交换多个维度,参数dims
的长度必须与input
的维度相同。
例子:
print(torch.permute(a, (1, 0))) # shape从(2, 3) 变成 (3, 2)
print(torch.permute(a, (0, 1)) == a) # 不变
Out:tensor([[0, 3],
[1, 4],
[2, 5]])
tensor([[True, True, True],
[True, True, True]])
作用于高维:
b.permute(1, 0, 2) # shape从(2, 3, 4)变成(3, 2, 4)
b.permute(2, 0, 1) # shape从(2, 3, 4)变成(4, 2, 3)
b.permute(1, 0) # 报错,必须要输入三个维度
由于permute
一次可以操作多个维度,因此在高维的功能性比transpose
更强,不过permute
能做到的transpose
也能做到,只不过transpose
可能要多调用几次,transpose
能做到的permute
也都能做到。