torch.gather()函数用于收集数据。有两种用法,假如有一个tensor p, 则有:
torch.gather(p, dim = 1, index = p_i)
或者
p.gather(dim=1, index=p_i)
此外,实际上这个函数是为了在强化学习中从Q表中根据行动方便的选择对应的值函数。
例如我有如下Q表和对应的动作:
Q = torch.tensor([[0.1, 0.2],
[0.2, 0.0],
[0.3, 0.1]])
A = torch.tensor([0, 1, 1])
我想快速的根据动作A选择Q值怎么做呢?
首先,在使用gather时,A要是tensor类型,其次,A的维度应该和Q相同,这里A是一维而Q是二维,因此先对A进行转换,可采用.view方法:
A = A.view(-1, 1)
此时A变成3行1列数组:
A = tensor([[0], # 第一行
[1], # 第二行
[1]]) # 第三行
其实选择的应该是Q表中对应的[0, 0], [1, 1]和[2, 1]三个索引对应的数,即A中的0, 1, 1为列索引,而行索引本身与A中的行对应,而A中的元素索引为[0, 0], [1, 0]和[2, 0], 现在要将A中的值[0, 1, 1]分别代替A中的元素索引[0, 0], [1, 0]和[2, 0]中的第二个元素,即列索引[0, 0, 0]
所以有dim=1
因此有如下代码:
Q_choose_byA = Q.gather(1, A)
完整代码如下
import torch
Q = Q = torch.tensor([[0.1, 0.2],
[0.2, 0.0],
[0.3, 0.1]])
A = torch.tensor([0, 1, 1])
A = A.view(-1, 1)
Q_choose_byA = Q.gather(1, A)
print(Q_choose_byA)
结果为:
可见元素被正确选出
这里介绍gather更通用的用法
假如已知数列
a=torch.tensor([
[1, 2],
[3, 4],
[5, 6]])
一、选出a中每一行的第1, 0, 1个元素(对应值为2, 3, 6)
则可以创建
b = torch.tensor([1, 0, 1]).view(-1, 1)
此时b =
torch.tensor([
[1],
[0],
[1]])
用 c = a.gather(1, b) 即可,即b替换每一行的列
二、选出a中第2行的1,0,1个元素(对应值为4, 3, 4)
由于是第二行,则可将第二行元素提出
a1 = a[2]
此时a1为一维tensor(即torch.tensor([3, 4]))
则此时b也是一维
b = torch.tensor([1, 0, 1])
用c = a1.gather(0, b)
三、选出a中第1列中的1,0,1个元素(4, 2, 4)
与二类似,可先将第二列元素取出
a2 = a[:, 1]
此时a2为1维tensor(即torch.tensor([2, 4, 6])
则这时b应当是一维(b的维度要与a2相同)
b=torch.tensor([1, 0, 1])
用c = a2.gather(0, b)