在Pytorch中涉及张量的操作都会涉及“dim”的设置,虽然也理解个大差不差,但是偶尔还是有点犯迷糊,究其原因还是没有形象化的理解。
首先,张量的维度排序是有固定顺序的,0,1,2,......,是遵循一个从外到内的索引顺序;张量本身的维度越高,往内延伸的维度数越高。
“dim define what operation elements is”——这是我自己的形象化理解。
看一组代码:
>>> ones = torch.ones(3,4)
>>> ones
tensor([[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.]])
>>> zeros = torch.zeros(3,4)
>>> zeros
tensor([[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]])
>>> ra = torch.arange(12).view(3,4)
>>> ra
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]])
>>> torch.stack((ra,zeros),dim=0)
tensor([[[ 0., 1., 2., 3.],
[ 4., 5., 6., 7.],
[ 8., 9., 10., 11.]],
[[ 0., 0., 0., 0.],
[ 0., 0., 0., 0.],
[ 0., 0., 0., 0.]]])
>>> torch.stack((ones,zeros),dim=0)
tensor([[[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.]],
[[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]]])
>>> torch.stack((ones,zeros),dim=-1)
tensor([[[1., 0.],
[1., 0.],
[1., 0.],
[1., 0.]],
[[1., 0.],
[1., 0.],
[1., 0.],
[1., 0.]],
[[1., 0.],
[1., 0.],
[1., 0.],
[1., 0.]]])
>>> torch.stack((ra,zeros),dim=-1)
tensor([[[ 0., 0.],
[ 1., 0.],
[ 2., 0.],
[ 3., 0.]],
[[ 4., 0.],
[ 5., 0.],
[ 6., 0.],
[ 7., 0.]],
[[ 8., 0.],
[ 9., 0.],
[10., 0.],
[11., 0.]]])
>>> torch.stack((ra,zeros),dim=1)
tensor([[[ 0., 1., 2., 3.],
[ 0., 0., 0., 0.]],
[[ 4., 5., 6., 7.],
[ 0., 0., 0., 0.]],
[[ 8., 9., 10., 11.],
[ 0., 0., 0., 0.]]])
>>> print("dim define what operation elements is")
dim define what operation elements is
>>>
>>>
看完代码你应该会比较形象化的理解最后一句话:dim其实定义了参与操作的元素是什么样的。对于一个batch的数据来说,dim=0上定义的是一个个样本,dim=1定义了第二个维度即每个样本的特征维度,......, dim=-1代表了从最底层的逐个数值操作。