pytorch函数中,很多都有dim这个控制参数,但是其具体含义是有时很迷惑,这里记录一下我的理解。
以tensor的sum函数为例,我们构建一个234的三维张量,分别在dim=0,1,2上求和。
>>x = [[[1, 2, 3, 4], [5, 6, 7, 8], [4, 3, 2, 1]],
[[2, 2, 2, 2], [1, 1, 1, 1], [3, 3, 3, 3]]]
>>x = torch.tensor(x)
>>x.shape
torch.Size([2, 3, 4])
>>x.sum(dim=(0), keepdim=True)
tensor([[[3, 4, 5, 6],
[6, 7, 8, 9],
[7, 6, 5, 4]]])
>>x.sum(dim=(0), keepdim=True).shape
torch.Size([1, 3, 4])
>>x.sum(dim=(1), keepdim=True)
tensor([[[10, 11, 12, 13]],
[[ 6, 6, 6, 6]]])
>>x.sum(dim=(1), keepdim=True).shape
torch.Size([2, 1, 4])
>>x.sum(dim=(2), keepdim=True)
tensor([[[10],
[26],
[10]],
[[ 8],
[ 4],
[12]]])
>>x.sum(dim=(2), keepdim=True).shape
torch.Size([2, 3, 1])
从维度上可以看到,dim等于几表明对哪个维度进行操作,sum的最终结果就是将该维上数值进行求和,其他维度不受干扰。我们再看一个例子:
>>x.argmax(dim=(1), keepdim=True)
tensor([[[1, 1, 1, 1]],
[[2, 2, 2, 2]]])
>>x.argmax(dim=(1), keepdim=True).shape
torch.Size([2, 1, 4])
>>x.argmax(dim=(0), keepdim=True)
tensor([[[1, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 1, 1]]])
>>x.argmax(dim=(0), keepdim=True).shape
torch.Size([1, 3, 4])
>>x.argmax(dim=(2), keepdim=True)
tensor([[[3],
[3],
[0]],
[[0],
[0],
[0]]])
>>x.argmax(dim=(2), keepdim=True).shape
torch.Size([2, 3, 1])
这个例子也可以印证以上想法,即所有的操作只对设置的dim维度进行,其他维度上不进行诸如比较、累加等操作,还是保留在这些维度上的尺寸信息。
>>x.sum(dim=-1, keepdim=True)
tensor([[[10],
[26],
[10]],
[[ 8],
[ 4],
[12]]])
>>x.sum(dim=-1, keepdim=True).shape
torch.Size([2, 3, 1])
这里dim=-1即最后一个维度