之前我们从Pytorch基础(机器学习)了解了使用.t()
函数够对一维和二维tensor
进行操作。对于高阶tensor
的转置操作,我们应该使用.transpose()
,.T
或者.permute()
函数。
import torch
# 可以将高阶张量的任意两个方向的进行转置, 但是一次只能实现两方向之间的转置
t1 = torch.tensor([[[1, 2],
[3, 4]],
[[5, 6],
[7, 8]]])
print("将第二个维度与第一个维度上的元素位置进行交换后得: ")
print(torch.transpose(input=t1, dim0=2, dim1=1))
print("-" * 40)
print(t1.T) # 等同于tensor.permute(n-1, n-2 .... 0)
print(t1.permute(2, 1, 0))
print("上述两个结果一致")
print("-" * 40)
print("使用permute函数实现多维度重新排列后: ")
print(t1.permute(1, 2, 0)) # permute()函数可以实现多维度同时转置
"""
输出结果:
将第二个维度与第一个维度上的元素位置进行交换后得:
tensor([[[1, 3],
[2, 4]],
[[5, 7],
[6, 8]]])
----------------------------------------
tensor([[[1, 5],
[3, 7]],
[[2, 6],
[4, 8]]])
tensor([[[1, 5],
[3, 7]],
[[2, 6],
[4, 8]]])
上述两个结果一致
----------------------------------------
使用permute函数实现多维度重新排列后:
tensor([[[1, 5],
[2, 6]],
[[3, 7],
[4, 8]]])
"""
结果显而易见,这里不做过多说明。
码字不易,如果大家觉得有用,请高抬贵手给一个赞让我上推荐让更多的人看到吧~