pytorch中unsqueeze和queeze函数的使用
直接上代码:
import torch
a = torch.randn([1,2,3])
print(a.shape)
a = a.squeeze(0) #这里注意一定要重新赋值
#在第“0”维减少一个维度
#unsqueeze()函数是增加一个维度
#实际上数据并没有变化,不存在多了或者少了
print(a.shape)
torch.Size([1, 2, 3])
torch.Size([2, 3])
所以看得出来tensor.squeeze()的作用就是减少一个维度,但要注意一个问题,只能删减维度大小为1的一维,维度不为1,说明存在有效数据,也就无法删除。
同时tensor.unsqueeze()的作用就是增加一个维度。里面是-1表示在最后一个维度增加。
该函数的使用:
构建神经网络时的输入形状都有假设:输入的不是单个数据,而是一个batch。这就意味这输入的维度不再是单个数据的维度
(比如tensor.size([3,28,28])),
而是(比如tensor.size([batch_size,3,28,28]))
所以当输入只有一个数据,则必须调用tensor.unsqueeze(0) 或 tensor[None]将数据伪装成batch_size=1的batch,不然就会出错。