在numpy和tensorflow中都有扩展维度操作:expand_dims操作
pytorch中也有这个操作,但是命名不一样,pytorch中的命名是:unsqueeze,直接放在tensor后面即可。
示例如下:
import torch
x1 = torch.zeros(10, 10)
x2 = x1.unsqueeze(0) # 括号里的参数是扩展的维度的位置
print(x2.size())
"""
返回:torch.Size([1, 10, 10])
"""
unsqueeze_与unsqueeze有同样的效果
import torch
x1 = torch.zeros(10, 10)
x2 = x1.unsqueeze_(0) # 括号里的参数是扩展的维度的位置
print(x2.size())
"""
返回:torch.Size([1, 10, 10])
"""
参考:https://jbencook.com/adding-a-dimension-to-a-tensor-in-pytorch/