torch.gather(input, dim, index, out=None),注:pytorch要求,除了dim指定的维度外,input和index在其他维度上必须大小一致.
直观的通俗解释:
如果都是2维的,dim = 0, 表示从input中的每一列中,选择第index行的值,生成一个新tensor
dim = 1, 表示从input中的每一行中,选择第index列的值,生成一个新tensor
例如:
dim=1表示,从input中的每一行中,选择index=[[1],[0],[0]],即选择每一行的第1,0,0个元素,构成新tensor
input
tensor([[0.2973, 0.6688],
[0.6045, 0.5933],
[0.6964, 0.2374]])
index
tensor([[1],
[0],
[0]])
torch.gather(input,dim=1,index=index)
tensor([[0.6688],
[0.6045],
[0.6964]])
dim=1表示,从input中的每一行中,选择index=[[1,0],[0,0],[0,1]],即选择第一行的第1,0个,第二行选择第0,0个,第三行选择第0,1个元素,构成新tensor
index
tensor([[1, 0],
[0, 0],
[0, 1]])
torch.gather(input,dim=1,index=index)
tensor([[0.6688, 0.2973],
[0.6045, 0.6045],
[0.6964, 0.2374]])
dim=0表示,从input中的每一列中,选择index=[[1,0]],即选择每一列的第1,0个元素,构成新tensor
input
tensor([[0.2973, 0.6688],
[0.6045, 0.5933],
[0.6964, 0.2374]])
index
tensor([[1, 0]])
torch.gather(input,dim=0,index=index)
tensor([[0.6045, 0.6688]])
dim=0表示,从input中的每一列中,选择index=[[1,0],[1,1]],即选择第一列的第1,1个元素,构成新tensor第一列,选择第二列的第0,1个元素构成新tensor的第二列,
换句话说:因为dim=0,所以index表示的数值都是第0维的,即:
index[0,0]=1,表示选input的第1行,列选择index[0,0]的列号,即0,所以index[0,0]在dim=0时,表示选择input的[1,0]
index[0,1]=0,表示选input的第0行,列选择index[0,1]的列号,即1,所以index[0,1]在dim=0时,表示选择input的[0,1]
index[1,0]=1,表示选input的第1行,列选择index[1,0]的列号,即0,所以index[1,0]在dim=0时,表示选择input的[1,0]
index[1,1]=1,表示选input的第1行,列选择index[1,1]的列号,即1,所以index[1,1]在dim=0时,表示选择input的[1,1]
input
tensor([[0.2973, 0.6688],
[0.6045, 0.5933],
[0.6964, 0.2374]])
index
tensor([[1, 0],
[1, 1]])
torch.gather(input,dim=0,index=index)
tensor([[0.6045, 0.6688],
[0.6045, 0.5933]])
详细说明:
目的是从input中,选择index中所指示位置的数值,然后生成一个新的tensor.
再直观的说:
目的是生成一个tensor,
这个新的tensor里的数值都是来自input的,
到底tensor中的哪个位置用input的哪个元素,是由index和dim共同确定的.
这里面有两个问题:
第一, 从input中选择的数据,放到新的tensor中的哪个位置
第二,从input中选择哪个位置的元素.
1 先说第一个, 从input中选择的数据,放到新的tensor中的哪个位置
其实非常直观,新的tensor和index完全一一对应.也就是,根据index的(i,j)元素选出来的数字,就是新的tensor的(i,j)位置的元素.
所以index是4*5的,那么新的tensor也就是4*5的.两个张量完全一致.
例如:,是个1行3列的张量,那么最终输出的tensor也是个1行3列的张量,
而且该张量的[0,0]是根据index的[0,0]的值(0),在input中查找得到的.
第[0,1]是根据index的[0,1]的值(2),在input中查找得到的.
第[0,2]是根据index的[0,2]的值(3),在input中查找得到的.
2 再说第二个问题,那么index表示的是什么意思,如何根据index从input中选择数据呢?
这里面要涉及到dim参数,他表示了index里面的数据到底标注了input的哪个维度
如果dim=0,表示index的元素表示的是第一个维度,也就是行的索引
如果dim=1,表示index的元素表示的是第二个维度,也就是列的索引
例如:,
那么如果dim=0,index的值表示的是0维:行的标识,因为index的值是0,2,3,所以也就是取第0,2,3行的对应元素
对应哪个列的元素呢?对应于index那个元素的列号
比如index的(0,0)元素的列号是0,那么就选input的第0列, 结合index(0,0)=0,表示行的编号为0,所以选input[0,0]=9
index的(0,1)元素的列号是1,那么就选input的第1列, 结合index(0,1)=2,表示行的编号为2,所以选input[2,1]=5
index的(0,2)元素的列号是2,那么就选input的第2列, 结合index(0,2)=3,表示行的编号为3,所以选input[3,2]=6
再结合第一个问题的答案,新tensor与index一一对应.所以gather(input, dim=0, index)的输出为[[9,5,6]]
那么如果dim=1,index的值表示的是1维:列的标识,因为index的值是0,2,3,所以也就是取第0,2,3列的对应元素
对应哪个行的元素呢?对应于index那个元素的行号
比如index的(0,0)元素的行号是0,那么就选input的第0行, 结合index(0,0)=0,表示列的编号为0,所以选input[0,0]=9
index的(0,1)元素的行号是0,那么就选input的第1行, 结合index(0,1)=2,表示列的编号为2,所以选input[0,1]=8
index的(0,2)元素的行号是0,那么就选input的第2行, 结合index(0,2)=3,表示列的编号为3,所以选input[0,2]=7
再结合第一个问题的答案,新tensor与index一一对应.所以gather(input, dim=1, index)的输出为[[9,8,7]]
gather的使用场景
比如强化学习中计算了n步选择action的概率,假设action一共4个,那么input.size()为[n,4],现在我们记录了这n步真正执行的action为[0,2,3,3,......],大小为n,现在我们想拿到执行每个动作的概率值.(从input中的每一行,挑选对应那一个action的元素,即action的值表示第1维-列的标记)
那么我们可以用 gather(input,dim=1,action.view(-1,1)),
特别注意,必须把action设置为n行1列,这样才能和input的行对应起来,如果是gather(input,dim=1,action),因为action只有一维,与input不一致,所以会报错.
注:因为除了dim标注的那一维度是从index的值选择,其他维度都需要有对应关系,所以pytorch强制要求input,index和最后输出的tensor,除了dim标志的那一维,其他维度必须一样,否则报错.
In [1]: input
Out[1]:
tensor([[0.2542, 0.2493, 0.2469, 0.2495],
[0.2493, 0.2494, 0.2525, 0.2488],
[0.2506, 0.2518, 0.2500, 0.2476],
[0.2466, 0.2527, 0.2522, 0.2484],
[0.2466, 0.2487, 0.2492, 0.2555],
[0.2528, 0.2480, 0.2491, 0.2501]])
In [2]: action
Out[2]: tensor([0, 0, 3, 2, 1, 2])
In [3]: torch.gather(input, dim=1, index=action.view(-1,1))
Out[3]:
tensor([[0.2542],
[0.2493],
[0.2476],
[0.2522],
[0.2487],
[0.2491]])
最后列上官网解释,数学公式一行顶的上千言万语.....
torch.gather(input, dim, index, out=None) → Tensor
Gathers values along an axis specified by dim.
For a 3-D tensor the output is specified by:
out[i][j][k] = input[index[i][j][k]][j][k] # dim=0
out[i][j][k] = input[i][index[i][j][k]][k] # dim=1
out[i][j][k] = input[i][j][index[i][j][k]] # dim=2
总结一句话:
gather就是从input中,针对dim那一维的索引选用index的值,其他维度索引与index的其他维度一致,取出值构成新的tensor.