unsqueeze函数:就是在指定的位置插入一个维度,有两个参数,input是输入的tensor,dim是要插到的维度。
例子:
<<<
a = torch.arange(0, 5)
b = torch.arange(0,5).unsqueeze(1)
print(a.size())
print(b.size())
<<<
输出:
torch.Size([5])
torch.Size([5, 1])
'''
正数的话很好理解,如果是负数呢?
如下:
import torch
a=torch.rand(2,3,1)
print(a.unsqueeze(-3).size()) #torch.Size([2, 1, 3, 1])
print(a.unsqueeze(-2).size()) #torch.Size([2, 3, 1, 1])
print(a.unsqueeze(-1).size()) #torch.Size([2, 3, 1, 1])
通过上述例子可以说明,正数是向右数插入的维度,负数则是向左数插入的维度。
上述例子原维度是(2,3,1),对应下标为(0,1,2),-2表示从下表0开始向左数两位,即在原下标为1的地方插入一个维度,原维度向左移动,即得出(2,3,1,1)。