unsqueeze()
import torch
a = torch.arange(0,6)
print(a)
print(a.view(1,6))
print(a.view(6,1))
print(a.unsqueeze(0))
print(a.unsqueeze(1))
print(a.unsqueeze(0).shape)
print(a.unsqueeze(1).shape)
可以看到,原来就是一个一维的数组,我们可以shape=[6],然后unsqueeze(0),就是在第0位纬度加上1,shape=[1,6];unsqueeze(1),就是在第1位纬度加上1,shape=[6,1]。
squeeze()
print(a.unsqueeze(0).shape)
print(a.unsqueeze(1).shape)
print(a.unsqueeze(0).squeeze(0))
print(a.unsqueeze(1).squeeze(0))
print(a.unsqueeze(0).squeeze(0).shape)
print(a.unsqueeze(1).squeeze(0).shape)
squeeze()
是去掉一个维度,并且去掉的这个维度的数值必须是1,否则无效。