说到扩展维度,可能第一想法是调用sequeeze()
函数,但是实际上会有更简单的方式:
x = torch.Tensor(3, 2)
print(x.size())
print(x[:, :, None].size())
print(x[:, None, :, None].size())
打印结果:
torch.Size([3, 2])
torch.Size([3, 2, 1])
torch.Size([3, 1, 2, 1])
- 很简单,看代码就能理解,不再赘述。
另外,在PyTorch里,repeat
和expend
函数的区别在于,对于dim为1的某个维度,后者不会申请更多的内存。这个机制有利于节省空间。
x = torch.Tensor(3, 1)
print(x.repeat(1, 4).size())
print(x.expand(3, 4).size())
print(x.repeat(1, 4))
print(x.expand(3, 4)) # 等价于 x.expand(-1,4)
print:
torch.Size([3, 4])
torch.Size([3, 4])
tensor([[-7.6648e+06, -7.6648e+06, -7.6648e+06, -7.6648e+06],
[ 3.0844e-41, 3.0844e-41, 3.0844e-41, 3.0844e-41],
[-7.7269e+05, -7.7269e+05, -7.7269e+05, -7.7269e+05]])
tensor([[-7.6648e+06, -7.6648e+06, -7.6648e+06, -7.6648e+06],
[ 3.0844e-41, 3.0844e-41, 3.0844e-41, 3.0844e-41],
[-7.7269e+05, -7.7269e+05, -7.7269e+05, -7.7269e+05]])
实际上,这是numpy.array的一种特性,PyTorch中的Tensor继承了Numpy中array的很多特性,因此你在PyTorch中也可以这么用。