torch.squeeze()和unsqueeze()
unsqueeze()
函数功能:与squeeze()函数功能相反,用于添加维度。
queeze()
函数功能:去除size为1的维度,包括行和列。当维度大于等于2时,squeeze()无作用。
其中squeeze(0)代表若第一维度值为1则去除第一维度,squeeze(1)代表若第二维度值为1则去除第二维度。
eg1:
a = torch.Tensor(1,3)
print a
tensor([[-1.37,4.56,-3.57]])
print a.squeeze(0)
tensor([-1.37,4.56,-3.57])
print a.squeeze(1)
tensor([[-1.37