[pytorch] torch.gather()函数

记录一下torch.gather函数

用法torch.gather(input: Tensor, dim: int, index: LongTensor, *, sparse_grad=False, out=None) -> Tensor

功能:指定张量index,根据其元素的值来获取输入矩阵input上的值。

注意

  • index需要与input有相同的维度,并且对d!=dim时要求index.size(d)<=input.size(d)
    意思就是说如果input的size为(2, 3, 4),如果dim指定为1,那么需要index.size(0)<=2以及index.size(2)<=4.
  • 函数输出的Tensor与index的shape相同

举个例子:

>>> t = torch.tensor([[1, 2], [3, 4]])
>>> torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]]))
tensor([[ 1,  1],
        [ 4,  3]])

怎么得到这个结果的呢,可以这样记忆:index现在是[[0, 0], [1, 0]],它的每个元素在index中都有其索引,比如元素1索引是[1, 0](index[1, 0]=1),由于现在指定的dim=1,那么就用1代替[1, 0]dim=1处的0,变成[1, 1],即获取到input[1, 1] ,如下图所示。
示意图

对于多维矩阵也是一样的流程,用index的每个元素的值代替该元素在index上的索引在dim维度上的值,便能得到在input上的索引。
也就是官方举的例子:

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

再有一个例子

>>> input_ = [[2, 3, 4, 5, 0, 0],
		 	  [1, 4, 3, 0, 0, 0],
    	 	  [4, 2, 2, 5, 7, 0],
    	 	  [1, 0, 0, 0, 0, 0]]
>>> input_ = torch.tensor(input_)
>>> index = torch.LongTensor([[3],[2],[4],[0]])
>>> torch.gather(input_, 1, index)
tensor([[5],
        [3],
        [7],
        [1]])
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值