torch.gather的用法

torch.gather()函数用于收集数据。有两种用法,假如有一个tensor p, 则有:

torch.gather(p, dim = 1, index = p_i)

或者

p.gather(dim=1, index=p_i)

此外,实际上这个函数是为了在强化学习中从Q表中根据行动方便的选择对应的值函数。

例如我有如下Q表和对应的动作:

Q = torch.tensor([[0.1, 0.2],
                  [0.2, 0.0],
                  [0.3, 0.1]])

A = torch.tensor([0, 1, 1])

我想快速的根据动作A选择Q值怎么做呢?

首先,在使用gather时,A要是tensor类型,其次,A的维度应该和Q相同,这里A是一维而Q是二维,因此先对A进行转换,可采用.view方法:

A = A.view(-1, 1)

此时A变成3行1列数组:

A = tensor([[0],     # 第一行

                   [1],     # 第二行

                   [1]])     # 第三行

其实选择的应该是Q表中对应的[0, 0], [1, 1]和[2, 1]三个索引对应的数,即A中的0, 1, 1为列索引,而行索引本身与A中的行对应,而A中的元素索引为[0, 0], [1, 0]和[2, 0], 现在要将A中的值[0, 1, 1]分别代替A中的元素索引[0, 0], [1, 0]和[2, 0]中的第二个元素,即列索引[0, 0, 0]

所以有dim=1

因此有如下代码:

Q_choose_byA = Q.gather(1, A)

完整代码如下

import torch

Q = Q = torch.tensor([[0.1, 0.2],
                      [0.2, 0.0],
                      [0.3, 0.1]])

A = torch.tensor([0, 1, 1])

A = A.view(-1, 1)

Q_choose_byA = Q.gather(1, A)

print(Q_choose_byA)

结果为:

 

可见元素被正确选出

这里介绍gather更通用的用法

假如已知数列

a=torch.tensor([

[1, 2],

[3, 4],

[5, 6]])

一、选出a中每一行的第1, 0, 1个元素(对应值为2, 3, 6)

则可以创建

b = torch.tensor([1, 0, 1]).view(-1, 1)

此时b = 

torch.tensor([

[1],

[0],

[1]])

用 c = a.gather(1, b) 即可,即b替换每一行的列

二、选出a中第2行的1,0,1个元素(对应值为4, 3, 4)

由于是第二行,则可将第二行元素提出

a1 = a[2]

此时a1为一维tensor(即torch.tensor([3, 4]))

则此时b也是一维

b = torch.tensor([1, 0, 1])

用c = a1.gather(0, b)

三、选出a中第1列中的1,0,1个元素(4, 2, 4)

与二类似,可先将第二列元素取出

a2 = a[:, 1]

此时a2为1维tensor(即torch.tensor([2, 4, 6])

则这时b应当是一维(b的维度要与a2相同)

b=torch.tensor([1, 0, 1])

用c = a2.gather(0, b)

  • 2
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值