1. squeeze()
压缩张量,去除size为1的维度
2.unsqueeze()
增加维度:在指定维度上增加维度,
如b = torch.tensor([1, 2, 3]), torch.Size([3])
在dim = 0的维度插入,B1 = torch.unsqueeze(b, 0),输出为[[1,2,3]]。torch.Size([1, 3])
加到另一个维度,B2 = B.unsqueeze(1),输出为[[1],[2],[3]]。torch.Size([3, 1])
在B1基础上加新的维度,B3 = B1.unsqueeze(2), 输出为[[[1],[2],[3]]]。torch.Size([1, 3, 1])
3.expand()
在指定维度上进行扩张
4.gather()
从tensor中按索引取值,注意!索引范围是[0, length - 1],如果范围超出则会报错
具体torch.gather()的使用方法这篇博客介绍十分详细。
图解PyTorch中的torch.gather函数 - 知乎1 背景 去年我理解了 torch.gather()用法,今年看到又给忘了,索性把自己的理解梳理出来,方便今后遗忘后