【CNN记录】pytorch gather函数

torch.gather(input, dim, index, *, sparse_grad=False, out=None)
参数:
input:输入张量
dim:index按照哪个轴取值
index:取值用的索引张量

gather其实就是根据index中索引查找input中元素重排,数据都是原来的,只是重新查找形成新张量矩阵。

公式就是下面这样

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

这么看下来可能有点懵,我们举个栗子

import torch
t = torch.arange(0,32).view(1,2,4,4)

整个4维数据1x2x4x4

tensor([[[[ 0,  1,  2,  3],
          [ 4,  5,  6,  7],
          [ 8,  9, 10, 11],
          [12, 13, 14, 15]],

         [[16, 17, 18, 19],
          [20, 21, 22, 23],
          [24, 25, 26, 27],
          [28, 29, 30, 31]]]])


index为1x1x2x4
index = torch.LongTensor([[[[0,1,2,3],[2,2,2,2]]]])

执行gather,axis设在第3个维度上
a = torch.gather(t, 2, index)

结果为
tensor([[[[ 0,  5, 10, 15],
          [ 8,  9, 10, 11]]]])

index维度为1x1x2x4,所以gather输出也是这个dims

根据上面的公式,我们可以一个个来取值

a[0][0][0][0] = t[0][0][index[0][0][0][0]][0] =  t[0][0][0][0] = 0

a[0][0][0][1] = t[0][0][index[0][0][0][1]][1] =  t[0][0][1][1] = 5

a[0][0][0][2] = t[0][0][index[0][0][0][2]][2] =  t[0][0][1][1] = 10

...

a[0][0][1][2] = t[0][0][index[0][0][1][2]][2] =  t[0][0][2][2] = 10

a[0][0][1][3] = t[0][0][index[0][0][1][3]][3] =  t[0][0][2][3] = 11

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值