【PyTorch】transpose() 函数详解
函数原型
torch.transpose(input, dim0, dim1) → Tensor
函数详解
交换输入张量 input 的两个维度,两个维度分别通过参数 dim0 和 dim1 传入。
例子
>>> x = torch.randn(2, 3)
>>> x
tensor([[ 1.0028, -0.9893, 0.5809],
[-0.1669, 0.7299, 0.4942]])
>>> torch.transpose(x, 0, 1)
tensor([[ 1.0028, -0.1669],
[-0.9893, 0.7299],
[ 0.5809, 0.4942]])
——翻译自PyTorch官方文档