Pythorch gather函数

Pytorch torch.gather 函数


官方文档


torch.gather(input, dim, index, out=None) → Tensor

    Gathers values along an axis specified by dim.

    For a 3-D tensor the output is specified by:

    out[i][j][k] = input[index[i][j][k]][j][k]  # dim=0
    out[i][j][k] = input[i][index[i][j][k]][k]  # dim=1
    out[i][j][k] = input[i][j][index[i][j][k]]  # dim=2

    Parameters:	

        input (Tensor) – The source tensor
        dim (int) – The axis along which to index
        index (LongTensor) – The indices of elements to gather
        out (Tensor, optional) – Destination tensor

    Example:

    >>> t = torch.Tensor([[1,2],[3,4]])
    >>> torch.gather(t, 1, torch.LongTensor([[0,0],[1,0]]))
     1  1
     4  3
    [torch.FloatTensor of size 2x2]



由上面的公式可以看出,输出的维度和index的维度是一致的。
对于dim=0时,out[i][j][k]的取值过程为:
首先找到index[i][j][k]的值,然后作为input的第一个维度的索引,后面维度的索引j,k。最后按照索引找到input对应位置的值,作为out[i][j][k]的值。
对于dim=1时,首先找到index[i][j][k]的值,然后作为input的第二个维度的索引,其他维度的索引为i,k。最后按照索引找到input对应位置的值,作为out[i][j][k]的值。
对于dim=2时,同理。

示例

import  torch

b = torch.Tensor([[1,2,3],[4,5,6]])

print(b)
index_2 = torch.LongTensor([[0,1,1],[0,0,0]])

a = torch.gather(b,dim=0,index=index_2)
print(a)

输出为

tensor([[1., 2., 3.],
        [4., 5., 6.]])
tensor([[1., 5., 6.],
        [1., 2., 3.]])

此时dim=0,是将index对应位置的值,作为input第一个维度的索引。比如,index第一行第二列[0][1]的值为1,那么out的第一行第二列的值为inpu第二行第二列的值5.

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值