pytorch中squeeze()和unsqueeze()函数
就我的理解,这两个函数其实是对tensor做的一个降维和升维的操作。
之所以这样做,是为了矩阵运算的需要。需要将tensor的size变换得满足矩阵运算的要求。
一、Squeeze()
squeeze字面意思就是压缩,挤压。在pytorch中的理解就是,将一个高维的tenso降下来,降到低维。默认无参的话,表示将tensor中所有维度是1的都压缩掉。可以结合具体代码看一下:
In: import torch
a = torch.arange(2,7)
In: a
Out: tensor([2, 3, 4, 5, 6])
In: b=a.unsqueeze(1)
c=</