Pytorch中的dim问题

pytorch函数中,很多都有dim这个控制参数,但是其具体含义是有时很迷惑,这里记录一下我的理解。

以tensor的sum函数为例,我们构建一个234的三维张量,分别在dim=0,1,2上求和。

>>x = [[[1, 2, 3, 4], [5, 6, 7, 8], [4, 3, 2, 1]],
    [[2, 2, 2, 2], [1, 1, 1, 1], [3, 3, 3, 3]]]
>>x = torch.tensor(x)
>>x.shape
torch.Size([2, 3, 4])
>>x.sum(dim=(0), keepdim=True)
tensor([[[3, 4, 5, 6],
         [6, 7, 8, 9],
         [7, 6, 5, 4]]])
>>x.sum(dim=(0), keepdim=True).shape
torch.Size([1, 3, 4])
>>x.sum(dim=(1), keepdim=True)
tensor([[[10, 11, 12, 13]],

        [[ 6,  6,  6,  6]]])
>>x.sum(dim=(1), keepdim=True).shape
torch.Size([2, 1, 4])
>>x.sum(dim=(2), keepdim=True)
tensor([[[10],
         [26],
         [10]],

        [[ 8],
         [ 4],
         [12]]])
>>x.sum(dim=(2), keepdim=True).shape
torch.Size([2, 3, 1])

从维度上可以看到,dim等于几表明对哪个维度进行操作,sum的最终结果就是将该维上数值进行求和,其他维度不受干扰。我们再看一个例子:

>>x.argmax(dim=(1), keepdim=True)
tensor([[[1, 1, 1, 1]],

        [[2, 2, 2, 2]]])
>>x.argmax(dim=(1), keepdim=True).shape
torch.Size([2, 1, 4])
>>x.argmax(dim=(0), keepdim=True)
tensor([[[1, 0, 0, 0],
         [0, 0, 0, 0],
         [0, 0, 1, 1]]])
>>x.argmax(dim=(0), keepdim=True).shape
torch.Size([1, 3, 4])
>>x.argmax(dim=(2), keepdim=True)
tensor([[[3],
         [3],
         [0]],

        [[0],
         [0],
         [0]]])
>>x.argmax(dim=(2), keepdim=True).shape
torch.Size([2, 3, 1])

这个例子也可以印证以上想法,即所有的操作只对设置的dim维度进行,其他维度上不进行诸如比较、累加等操作,还是保留在这些维度上的尺寸信息。

>>x.sum(dim=-1, keepdim=True)
tensor([[[10],
         [26],
         [10]],

        [[ 8],
         [ 4],
         [12]]])
>>x.sum(dim=-1, keepdim=True).shape
torch.Size([2, 3, 1])

这里dim=-1即最后一个维度

  • 3
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值