1.unsqueeze函数(给指定的维度增加一维)
import torch
a = torch.arange(0,6)
b = a.view(2,3)
可以得到b为:
tensor([[0,1,2],
[3,4,5]])
可以看到a的维度为(2,3)
c = b.unsqueeze(1) ###在第一维增加一维
可以得到c为:
tensor([[[0, 1, 2]],
[[3, 4, 5]]])
可以看到b的形状变成(2,1,3)
2.squeeze函数(删除指定维度)
import torch
d = c.squeeze(1)
可以得到d的结果为:
tensor([[0, 1, 2],
[3, 4, 5]])
可以看到d的维度变成了(2,3)