参考:torch.squeeze() 和torch.unsqueeze()用法的通俗解释
import torch
x = torch.tensor([[1, 2, 3],[1, 2, 3],[1, 2, 3]])
print(x)
print(torch.unsqueeze(x,0))
输出:
tensor([[1, 2, 3],
[1, 2, 3],
[1, 2, 3]])
tensor([[[1, 2, 3],
[1, 2, 3],
[1, 2, 3]]])