pytorch基础函数解析##
1.view函数
1.1.view(参数a,参数b,…)
import torch
temp=torch.tensor[1,2,3,4,5,6] #temp的类型为list,非tensor
temp=torch.tensor(temp)
print(temp.view(2,3))
print(temp.view(1,2,3))
print(temp.view(2,3,1,1))
#view用于对tensor维度的重构,返回一个有相同数据但是不同维度的tensor,操作对象应该是tensor类型,可以通过data=torch.tensor(data)来转换
1.2 view(参数a,参数b,-1),其中如果某个参数为-1,则表示该维度取决于其他维度,由pytorch自己补充。
import torch
temp=list[1,2,3,4,5,6]
temp=torch.tensor(temp)
print(temp)
print(temp.view(-1))
temp1=torch.tensor([1,2,3],[3,5,6])
print(temp1)
print(temp1.view(-1))
- permute函数
将tensor的维度换位
x=torch.randn(2,3,5)
x.size()
x.permute(2,0,1).size()
2.1 transpose与permute的异同
tensor.permute(a,b,c,b,…):permute函数可以对任意高维矩阵进行转置,但没有torch.permute()这个调用方式,只能用tensor.permute()
torch.transpose(tensor,a,b)
:
transpose操作只能操作2D矩阵的转置,有两种调用方式
>>>torch.randn(2,3,4,5).transpose(3,0).transpose(2,1).transpose(3,2).shape
torch.size([5,4,2,3])
permute操作相当于可以同时操作于tensor的若干维度,transpose只能作用于tensor的两个维度