transpose 函数的功能是重新排列维度,但是在 numpy 和 pytorch 中用法不同。
numpy transpose
输入是重新排列的维度的下标
b = a.transpose(1, 2, 0)
import numpy as np
a = np.random.randn(3, 200, 200)
b = a.transpose(1, 2, 0)
print(a.shape)
print(b.shape)
或者
b = np.transpose(a, (1, 2, 0))
效果相同
import numpy as np
a = np.random.randn(3, 200, 200)
b = np.transpose(a, (1, 2, 0))
print(a.shape)
print(b.shape)
pytorch transpose
输入是要交换的两个维度的下标
b = a.transpose(0, 2)
import torch
a = torch.randn(3, 200, 200)
b = a.transpose(0, 2)
print(a.shape)
print(b.shape)