torch.gather()的理解

在看蘑菇书的时候出现了这样的代码:

 '''计算当前(s_t,a)对应的Q(s_t, a)'''
        '''torch.gather:对于a=torch.Tensor([[1,2],[3,4]]),那么a.gather(1,torch.Tensor([[0],[1]]))=torch.Tensor([[1],[3]])'''
        q_values = self.policy_net(state_batch).gather(dim=1, index=action_batch)  # 等价于self.forward

注意这里代码第一行的备注的代码似乎不太对,如果报错:RuntimeError: gather(): Expected dtype int64 for index可以尝试把torch.Tensor()改为torch.LongTensor(),可能是高版本的问题。还有就是这个注释的结果是错的,正常输出应该是tensor([[1.], [4.]]),具体原因看下文。

函数torch.gather(input, dim, index, out=None) → Tensor沿给定轴 dim ,将输入索引张量 index 指定位置的值进行聚合。对一个 3 维张量,输出可以定义为:

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2
Parameters:
input (Tensor) – 源张量
dim (int) – 索引的轴
index (LongTensor) – 聚合元素的下标(index需要是torch.longTensor类型)
out (Tensor, optional) – 目标张量

使用说明举例:

dim = 1

import torch
a = torch.randint(0, 30, (2, 3, 5))
print(a)
'''
tensor([[[ 18.,   5.,   7.,   1.,   1.],
         [  3.,  26.,   9.,   7.,   9.],
         [ 10.,  28.,  22.,  27.,   0.]],

        [[ 26.,  10.,  20.,  29.,  18.],
         [  5.,  24.,  26.,  21.,   3.],
         [ 10.,  29.,  10.,   0.,  22.]]])
'''
index = torch.LongTensor([[[0,1,2,0,2],
                          [0,0,0,0,0],
                          [1,1,1,1,1]],
                        [[1,2,2,2,2],
                         [0,0,0,0,0],
                         [2,2,2,2,2]]])
print(a.size()==index.size())
b = torch.gather(a, 1,index)
print(b)
'''
True
tensor([[[ 18.,  26.,  22.,   1.,   0.],
         [ 18.,   5.,   7.,   1.,   1.],
         [  3.,  26.,   9.,   7.,   9.]],

        [[  5.,  29.,  10.,   0.,  22.],
         [ 26.,  10.,  20.,  29.,  18.],
         [ 10.,  29.,  10.,   0.,  22.]]])
'''

可以看到沿着dim=1,也就是列的时候。输出tensor第一页内容,第一行分别是 按照index指定的,
input tensor的第一页 第一列的下标为0的元素 第二列的下标为1元素 第三列的下标为2的元素,第四列下标为0元素,第五列下标为2元素index-->0,1,2,0,2 output--> 18., 26., 22., 1., 0.

dim =2

c = torch.gather(a, 2,index)
print(c)
'''
tensor([[[ 18.,   5.,   7.,  18.,   7.],
         [  3.,   3.,   3.,   3.,   3.],
         [ 28.,  28.,  28.,  28.,  28.]],

        [[ 10.,  20.,  20.,  20.,  20.],
         [  5.,   5.,   5.,   5.,   5.],
         [ 10.,  10.,  10.,  10.,  10.]]])
'''

dim = 2的时候就安装 行 聚合了。参照上面的举一反三。

dim = 0

index2 = torch.LongTensor([[[0,1,1,0,1],
                          [0,1,1,1,1],
                          [1,1,1,1,1]],
                        [[1,0,0,0,0],
                         [0,0,0,0,0],
                         [1,1,0,0,0]]])
d = torch.gather(a, 0,index2)
print(d)
'''
tensor([[[ 18.,  10.,  20.,   1.,  18.],
         [  3.,  24.,  26.,  21.,   3.],
         [ 10.,  29.,  10.,   0.,  22.]],

        [[ 26.,   5.,   7.,   1.,   1.],
         [  3.,  26.,   9.,   7.,   9.],
         [ 10.,  29.,  22.,  27.,   0.]]])
'''

这个有点特殊,dim = 0的时候(三维情况下),是从不同的页收集元素的。这里举的例子只有两页。所有index在0,1两个之间选择。输出的矩阵元素也是按照index的指定。分别在第一页和第二页之间跳着选的。index [0,1,1,0,1]的意思就是。在第一页选这个位置的元素,在第二页选这个位置的元素,在第二页选,第一页选,第二页选。

开头的问题

torch.gather:对于a=torch.Tensor([[1,2],[3,4]]),那么a.gather(1,torch.Tensor([[0],[1]]))=torch.Tensor([[1],[3]])为什么有问题
当dim=1时,应该是按列来的。

a=torch.Tensor([[1,2],[3,4]])
print(a)
"""
tensor([[1., 2.],
        [3., 4.]])
"""
index1=torch.LongTensor([[0],[1]])
print(index1)
"""
tensor([[0],
        [1]])
这是二维的,由于dim=1,所以是在对应向量内查找,只查找当前向量,因此0查找[[1.,2.]]中索引为0的,即1;而1查找[[3.,4.]]中索引为1的,即4
"""
c=torch.gather(a, 1,index1)
print(c)
"""
tensor([[1.],
        [4.]])
"""
d=torch.gather(a, 0,index1)
prind(d)
"""
tensor([[1.],
        [3.]])
由于dim=0,所以是在页内查找,只查找对应页,在这里是对应列,因此0查找[[1.],[3.]]中索引为0的,即1;而1查找[[1.],[3.]]中索引为1的,即3
"""

反正不管是多少维的Tensor,记住都是先列后行,二维就是先列后行,三维就是先页后列在行,高维的也是这样。

  • 3
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

WHUT米肖雄

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值