Pytorch 的 gather的直观理解(做个记录)

torch.gather(input, dim, index, out=None),注:pytorch要求,除了dim指定的维度外,input和index在其他维度上必须大小一致.

直观的通俗解释:

如果都是2维的,dim = 0, 表示从input中的每一列中,选择第index行的值,生成一个新tensor

                               dim = 1, 表示从input中的每一行中,选择第index列的值,生成一个新tensor

例如:

dim=1表示,从input中的每一行中,选择index=[[1],[0],[0]],即选择每一行的第1,0,0个元素,构成新tensor

input                                                                  
tensor([[0.2973, 0.6688],
        [0.6045, 0.5933],
        [0.6964, 0.2374]])

index                                                                  
tensor([[1],
        [0],
        [0]])

torch.gather(input,dim=1,index=index)                                   
tensor([[0.6688],
        [0.6045],
        [0.6964]])

dim=1表示,从input中的每一行中,选择index=[[1,0],[0,0],[0,1]],即选择第一行的第1,0个,第二行选择第0,0个,第三行选择第0,1个元素,构成新tensor

index                                                                  
tensor([[1, 0],
        [0, 0],
        [0, 1]])
                            
torch.gather(input,dim=1,index=index)                                   
tensor([[0.6688, 0.2973],
        [0.6045, 0.6045],
        [0.6964, 0.2374]])

dim=0表示,从input中的每一列中,选择index=[[1,0]],即选择每一列的第1,0个元素,构成新tensor

input                                                                  
tensor([[0.2973, 0.6688],
        [0.6045, 0.5933],
        [0.6964, 0.2374]])

index                                                                  
tensor([[1, 0]])

torch.gather(input,dim=0,index=index)                                  
tensor([[0.6045, 0.6688]])

dim=0表示,从input中的每一列中,选择index=[[1,0],[1,1]],即选择第一列的第1,1个元素,构成新tensor第一列,选择第二列的第0,1个元素构成新tensor的第二列,

换句话说:因为dim=0,所以index表示的数值都是第0维的,即:

index[0,0]=1,表示选input的第1行,列选择index[0,0]的列号,即0,所以index[0,0]在dim=0时,表示选择input的[1,0]

index[0,1]=0,表示选input的第0行,列选择index[0,1]的列号,即1,所以index[0,1]在dim=0时,表示选择input的[0,1]

index[1,0]=1,表示选input的第1行,列选择index[1,0]的列号,即0,所以index[1,0]在dim=0时,表示选择input的[1,0]

index[1,1]=1,表示选input的第1行,列选择index[1,1]的列号,即1,所以index[1,1]在dim=0时,表示选择input的[1,1]

input                                                                  
tensor([[0.2973, 0.6688],
        [0.6045, 0.5933],
        [0.6964, 0.2374]])

index                                                                   
tensor([[1, 0],
        [1, 1]])                              

torch.gather(input,dim=0,index=index)                                  
tensor([[0.6045, 0.6688],
        [0.6045, 0.5933]])

 





详细说明:

目的是从input中,选择index中所指示位置的数值,然后生成一个新的tensor.

再直观的说:

目的是生成一个tensor,

这个新的tensor里的数值都是来自input的,

到底tensor中的哪个位置用input的哪个元素,是由index和dim共同确定的.

这里面有两个问题:

第一, 从input中选择的数据,放到新的tensor中的哪个位置

第二,从input中选择哪个位置的元素.

1 先说第一个, 从input中选择的数据,放到新的tensor中的哪个位置

其实非常直观,新的tensor和index完全一一对应.也就是,根据index的(i,j)元素选出来的数字,就是新的tensor的(i,j)位置的元素.

所以index是4*5的,那么新的tensor也就是4*5的.两个张量完全一致.

例如:index= \begin{bmatrix} [0&2&3] \end{bmatrix},是个1行3列的张量,那么最终输出的tensor也是个1行3列的张量,

而且该张量的[0,0]是根据index的[0,0]的值(0),在input中查找得到的.

                       第[0,1]是根据index的[0,1]的值(2),在input中查找得到的.

                       第[0,2]是根据index的[0,2]的值(3),在input中查找得到的.

2 再说第二个问题,那么index表示的是什么意思,如何根据index从input中选择数据呢?

这里面要涉及到dim参数,他表示了index里面的数据到底标注了input的哪个维度

如果dim=0,表示index的元素表示的是第一个维度,也就是行的索引

如果dim=1,表示index的元素表示的是第二个维度,也就是列的索引

例如:index= \begin{bmatrix} [0&2&3] \end{bmatrix},

input = \begin{bmatrix} 9 & 8& 7 & 6&5 \\ 5& 4 &3 & 2 &1 \\ 4& 5&6 & 7 &8 \\ 1& 4& 6& 8 & 9 \end{bmatrix}

那么如果dim=0,index的值表示的是0维:行的标识,因为index的值是0,2,3,所以也就是取第0,2,3行的对应元素

对应哪个列的元素呢?对应于index那个元素的列号

比如index的(0,0)元素的列号是0,那么就选input的第0列, 结合index(0,0)=0,表示行的编号为0,所以选input[0,0]=9

        index的(0,1)元素的列号是1,那么就选input的第1列, 结合index(0,1)=2,表示行的编号为2,所以选input[2,1]=5

        index的(0,2)元素的列号是2,那么就选input的第2列, 结合index(0,2)=3,表示行的编号为3,所以选input[3,2]=6

再结合第一个问题的答案,新tensor与index一一对应.所以gather(input, dim=0, index)的输出为[[9,5,6]]

 

那么如果dim=1,index的值表示的是1维:列的标识,因为index的值是0,2,3,所以也就是取第0,2,3列的对应元素

对应哪个行的元素呢?对应于index那个元素的行号

比如index的(0,0)元素的行号是0,那么就选input的第0行, 结合index(0,0)=0,表示列的编号为0,所以选input[0,0]=9

        index的(0,1)元素的行号是0,那么就选input的第1行, 结合index(0,1)=2,表示列的编号为2,所以选input[0,1]=8

        index的(0,2)元素的行号是0,那么就选input的第2行, 结合index(0,2)=3,表示列的编号为3,所以选input[0,2]=7

再结合第一个问题的答案,新tensor与index一一对应.所以gather(input, dim=1, index)的输出为[[9,8,7]]

gather的使用场景

比如强化学习中计算了n步选择action的概率,假设action一共4个,那么input.size()为[n,4],现在我们记录了这n步真正执行的action为[0,2,3,3,......],大小为n,现在我们想拿到执行每个动作的概率值.(从input中的每一行,挑选对应那一个action的元素,即action的值表示第1维-列的标记)

那么我们可以用 gather(input,dim=1,action.view(-1,1)),

特别注意,必须把action设置为n行1列,这样才能和input的行对应起来,如果是gather(input,dim=1,action),因为action只有一维,与input不一致,所以会报错.

注:因为除了dim标注的那一维度是从index的值选择,其他维度都需要有对应关系,所以pytorch强制要求input,index和最后输出的tensor,除了dim标志的那一维,其他维度必须一样,否则报错.

In [1]: input                                                                  
Out[1]: 
tensor([[0.2542, 0.2493, 0.2469, 0.2495],
        [0.2493, 0.2494, 0.2525, 0.2488],
        [0.2506, 0.2518, 0.2500, 0.2476],
        [0.2466, 0.2527, 0.2522, 0.2484],
        [0.2466, 0.2487, 0.2492, 0.2555],
        [0.2528, 0.2480, 0.2491, 0.2501]])
In [2]: action                                                                 
Out[2]: tensor([0, 0, 3, 2, 1, 2])

In [3]: torch.gather(input, dim=1, index=action.view(-1,1))                    
Out[3]: 
tensor([[0.2542],
        [0.2493],
        [0.2476],
        [0.2522],
        [0.2487],
        [0.2491]])

最后列上官网解释,数学公式一行顶的上千言万语.....

torch.gather(input, dim, index, out=None) → Tensor

    Gathers values along an axis specified by dim.

    For a 3-D tensor the output is specified by:

    out[i][j][k] = input[index[i][j][k]][j][k]  # dim=0
    out[i][j][k] = input[i][index[i][j][k]][k]  # dim=1
    out[i][j][k] = input[i][j][index[i][j][k]]  # dim=2
 

总结一句话:

gather就是从input中,针对dim那一维的索引选用index的值,其他维度索引与index的其他维度一致,取出值构成新的tensor.

已标记关键词 清除标记
相关推荐
<p> <b><span style="background-color:#FFE500;">【超实用课程内容】</span></b> </p> <p> <br /> </p> <p> <br /> </p> <p> 本课程内容包含讲解<span>解读Nginx的基础知识,</span><span>解读Nginx的核心知识、带领学员进行</span>高并发环境下的Nginx性能优化实战,让学生能够快速将所学融合到企业应用中。 </p> <p> <br /> </p> <p style="font-family:Helvetica;color:#3A4151;font-size:14px;background-color:#FFFFFF;"> <b><br /> </b> </p> <p style="font-family:Helvetica;color:#3A4151;font-size:14px;background-color:#FFFFFF;"> <b><span style="background-color:#FFE500;">【课程如何观看?】</span></b> </p> <p style="font-family:Helvetica;color:#3A4151;font-size:14px;background-color:#FFFFFF;"> PC端:<a href="https://edu.csdn.net/course/detail/26277"><span id="__kindeditor_bookmark_start_21__"></span></a><a href="https://edu.csdn.net/course/detail/27216">https://edu.csdn.net/course/detail/27216</a> </p> <p style="font-family:Helvetica;color:#3A4151;font-size:14px;background-color:#FFFFFF;"> 移动端:CSDN 学院APP(注意不是CSDN APP哦) </p> <p style="font-family:Helvetica;color:#3A4151;font-size:14px;background-color:#FFFFFF;"> 本课程为录播课,课程永久有效观看时长,大家可以抓紧时间学习后一起讨论哦~ </p> <p style="font-family:"color:#3A4151;font-size:14px;background-color:#FFFFFF;"> <br /> </p> <p class="ql-long-24357476" style="font-family:"color:#3A4151;font-size:14px;background-color:#FFFFFF;"> <strong><span style="background-color:#FFE500;">【学员专享增值服务】</span></strong> </p> <p class="ql-long-24357476" style="font-family:"color:#3A4151;font-size:14px;background-color:#FFFFFF;"> <b>源码开放</b> </p> <p class="ql-long-24357476" style="font-family:"color:#3A4151;font-size:14px;background-color:#FFFFFF;"> 课件、课程案例代码完全开放给你,你可以根据所学知识,自行修改、优化 </p> <p class="ql-long-24357476" style="font-family:"color:#3A4151;font-size:14px;background-color:#FFFFFF;"> 下载方式:电脑登录<a href="https://edu.csdn.net/course/detail/26277"></a><a href="https://edu.csdn.net/course/detail/27216">https://edu.csdn.net/course/detail/27216</a>,播放页面右侧点击课件进行资料打包下载 </p> <p> <br /> </p> <p> <br /> </p> <p> <br /> </p>
©️2020 CSDN 皮肤主题: 大白 设计师:CSDN官方博客 返回首页