在看蘑菇书的时候出现了这样的代码:
'''计算当前(s_t,a)对应的Q(s_t, a)'''
'''torch.gather:对于a=torch.Tensor([[1,2],[3,4]]),那么a.gather(1,torch.Tensor([[0],[1]]))=torch.Tensor([[1],[3]])'''
q_values = self.policy_net(state_batch).gather(dim=1, index=action_batch) # 等价于self.forward
注意这里代码第一行的备注的代码似乎不太对,如果报错:RuntimeError: gather(): Expected dtype int64 for index
可以尝试把torch.Tensor()
改为torch.LongTensor()
,可能是高版本的问题。还有就是这个注释的结果是错的,正常输出应该是tensor([[1.], [4.]])
,具体原因看下文。
函数torch.gather(input, dim, index, out=None) → Tensor沿给定轴 dim ,将输入索引张量 index 指定位置的值进行聚合。对一个 3 维张量,输出可以定义为:
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
Parameters:
input (Tensor) – 源张量
dim (int) – 索引的轴
index (LongTensor) – 聚合元素的下标(index需要是torch.longTensor类型)
out (Tensor, optional) – 目标张量
使用说明举例:
dim = 1
import torch
a = torch.randint(0, 30, (2, 3, 5))
print(a)
'''
tensor([[[ 18., 5., 7., 1., 1.],
[ 3., 26., 9., 7., 9.],
[ 10., 28., 22., 27., 0.]],
[[ 26., 10., 20., 29., 18.],
[ 5., 24., 26., 21., 3.],
[ 10., 29., 10., 0., 22.]]])
'''
index = torch.LongTensor([[[0,1,2,0,2],
[0,0,0,0,0],
[1,1,1,1,1]],
[[1,2,2,2,2],
[0,0,0,0,0],
[2,2,2,2,2]]])
print(a.size()==index.size())
b = torch.gather(a, 1,index)
print(b)
'''
True
tensor([[[ 18., 26., 22., 1., 0.],
[ 18., 5., 7., 1., 1.],
[ 3., 26., 9., 7., 9.]],
[[ 5., 29., 10., 0., 22.],
[ 26., 10., 20., 29., 18.],
[ 10., 29., 10., 0., 22.]]])
'''
可以看到沿着dim=1,也就是列的时候。输出tensor第一页内容,第一行分别是 按照index指定的,
input tensor的第一页 第一列的下标为0的元素 第二列的下标为1元素 第三列的下标为2的元素,第四列下标为0元素,第五列下标为2元素index-->0,1,2,0,2 output--> 18., 26., 22., 1., 0.
dim =2
c = torch.gather(a, 2,index)
print(c)
'''
tensor([[[ 18., 5., 7., 18., 7.],
[ 3., 3., 3., 3., 3.],
[ 28., 28., 28., 28., 28.]],
[[ 10., 20., 20., 20., 20.],
[ 5., 5., 5., 5., 5.],
[ 10., 10., 10., 10., 10.]]])
'''
dim = 2的时候就安装 行 聚合了。参照上面的举一反三。
dim = 0
index2 = torch.LongTensor([[[0,1,1,0,1],
[0,1,1,1,1],
[1,1,1,1,1]],
[[1,0,0,0,0],
[0,0,0,0,0],
[1,1,0,0,0]]])
d = torch.gather(a, 0,index2)
print(d)
'''
tensor([[[ 18., 10., 20., 1., 18.],
[ 3., 24., 26., 21., 3.],
[ 10., 29., 10., 0., 22.]],
[[ 26., 5., 7., 1., 1.],
[ 3., 26., 9., 7., 9.],
[ 10., 29., 22., 27., 0.]]])
'''
这个有点特殊,dim = 0的时候(三维情况下),是从不同的页收集元素的。这里举的例子只有两页。所有index在0,1两个之间选择。输出的矩阵元素也是按照index的指定。分别在第一页和第二页之间跳着选的。index [0,1,1,0,1]的意思就是。在第一页选这个位置的元素,在第二页选这个位置的元素,在第二页选,第一页选,第二页选。
开头的问题
torch.gather:对于a=torch.Tensor([[1,2],[3,4]]),那么a.gather(1,torch.Tensor([[0],[1]]))=torch.Tensor([[1],[3]])
为什么有问题
当dim=1时,应该是按列来的。
a=torch.Tensor([[1,2],[3,4]])
print(a)
"""
tensor([[1., 2.],
[3., 4.]])
"""
index1=torch.LongTensor([[0],[1]])
print(index1)
"""
tensor([[0],
[1]])
这是二维的,由于dim=1,所以是在对应向量内查找,只查找当前向量,因此0查找[[1.,2.]]中索引为0的,即1;而1查找[[3.,4.]]中索引为1的,即4
"""
c=torch.gather(a, 1,index1)
print(c)
"""
tensor([[1.],
[4.]])
"""
d=torch.gather(a, 0,index1)
prind(d)
"""
tensor([[1.],
[3.]])
由于dim=0,所以是在页内查找,只查找对应页,在这里是对应列,因此0查找[[1.],[3.]]中索引为0的,即1;而1查找[[1.],[3.]]中索引为1的,即3
"""
反正不管是多少维的Tensor,记住都是先列后行,二维就是先列后行,三维就是先页后列在行,高维的也是这样。