torch.unsqueeze()
torch.unsqueeze()这个函数主要是对数据维度进行扩充。给指定位置加上维数为一的维度
import torch
a = torch.tensor([[1,2,3],[4,5,6]]) # size([2,3])
print(a.unsqueeze(0)) #扩增0维
>>> tensor([[[1,2,3],[4,5,6]]]) # size([1,2,3])
print(a.unsqueeze(1)) #增加1维
>>> tensor([[[1,2,3]],[[4,5,6]]]) # size(2,1,3)
print(a.unsqueeze(2)) # 增加2维
>>>tensor([[[1],[2],[3]],[[4],[5],[6]]]) # size(2,3,1)