程序中常常会涉及到张量的维度变换,为区别常用变换函数reshape、permute、view的不同,使用以下代码直观说明。
先对二维张量进行变换:
import torch
c=torch.randn(2,3)
print("c:")
print(c)
c1=torch.reshape(c,(3,2))
print("reshape:")
print(c1)
c2=c.permute(1,0)
print("permute:")
print(c2)
c3=c.view(3,2)
print("view:")
print(c3)
c4=c.transpose(1,0)
print("transpose:")
print(c4)
代码结果:
c:
tensor([[-0.1440, 0.0783, 0.1932],
[-0.3421, -0.3892, 1.5916]])
reshape:
tensor([[-0.1440, 0.0783],
[ 0.1932, -0.3421],
[-0.3892, 1.5916]])
permute:
tensor([[-0.1440, -0.3421],
[ 0.0783, -0.3892],
[ 0.1932, 1.5916]])
view:
tensor([[-0.1440, 0.0783],
[ 0.1932, -0.3421],
[-0.3892, 1.5916]])
transpose:
tensor([[-0.1440, -0.3421],
[ 0.0783, -0.3892],
[ 0.1932, 1.5916]])
先对三维张量进行变换()
import torch
c=torch.randn(2,3,4)
print("c:")
print(c)
c1=torch.reshape(c,(4,3,2))
print("reshape:")
print(c1)
c2=c.permute(2,1,0)
print("permute:")
print(c2)
c3=c.view(4,3,2)
print("view:")
print(c3)
运行结果:
c:
tensor([[[-1.0615, 0.4882, 0.8409, 0.2768],
[ 0.9477, -0.3273, -1.0569, -0.8247],
[ 1.6808, 0.8768, -0.0501, 0.6021]],
[[-0.7654, -0.0558, 0.6285, -1.5840],
[ 0.0180, 0.0347, 0.8596, -1.1917],
[ 3.1598, -0.2194, 0.9064, 0.2113]]])
reshape:
tensor([[[-1.0615, 0.4882],
[ 0.8409, 0.2768],
[ 0.9477, -0.3273]],
[[-1.0569, -0.8247],
[ 1.6808, 0.8768],
[-0.0501, 0.6021]],
[[-0.7654, -0.0558],
[ 0.6285, -1.5840],
[ 0.0180, 0.0347]],
[[ 0.8596, -1.1917],
[ 3.1598, -0.2194],
[ 0.9064, 0.2113]]])
permute:
tensor([[[-1.0615, -0.7654],
[ 0.9477, 0.0180],
[ 1.6808, 3.1598]],
[[ 0.4882, -0.0558],
[-0.3273, 0.0347],
[ 0.8768, -0.2194]],
[[ 0.8409, 0.6285],
[-1.0569, 0.8596],
[-0.0501, 0.9064]],
[[ 0.2768, -1.5840],
[-0.8247, -1.1917],
[ 0.6021, 0.2113]]])
view:
tensor([[[-1.0615, 0.4882],
[ 0.8409, 0.2768],
[ 0.9477, -0.3273]],
[[-1.0569, -0.8247],
[ 1.6808, 0.8768],
[-0.0501, 0.6021]],
[[-0.7654, -0.0558],
[ 0.6285, -1.5840],
[ 0.0180, 0.0347]],
[[ 0.8596, -1.1917],
[ 3.1598, -0.2194],
[ 0.9064, 0.2113]]])
可以看出,reshape只是对形状做了一个变化,而permute,transpose则是对轴进行了一个互换。