- squeeze(): 去除size为1的维度,包括行和列。至于维度大于等于2时,squeeze()不起作用
- unsqueeze():在指定维度增加一个维度
import torch
a = torch.rand(4, 1, 3)
print('a.shape:',a.shape)
print('a:',a)
b = a.squeeze() #除去维度为1的维度
print('b.shape:',b.shape)
print('b:',b)
c = a.unsqueeze(0) #在指定维度上增加一个维度
print('c:',c)
print('c.shape:',c.shape)
a.shape: torch.Size([4, 1, 3])
a: tensor([[[0.2007, 0.2601, 0.2374]],
[[0.9118, 0.9556, 0.6185]],
[[0.7858, 0.8942, 0.0059]],
[[0.2043, 0.9204, 0.6308]]])
b.shape: torch.Size([4, 3])
b: tensor([[0.2007, 0.2601, 0.2374],
[0.9118, 0.9556, 0.6185],
[0.7858, 0.8942, 0.0059],
[0.2043, 0.9204, 0.6308]])
c: tensor([[[[0.2007, 0.2601, 0.2374]],
[[0.9118, 0.9556, 0.6185]],
[[0.7858, 0.8942, 0.0059]],
[[0.2043, 0.9204, 0.6308]]]])
c.shape: torch.Size([1, 4, 1, 3])