torch.gather理解

torch.gather

先看看定义

torch.gather(input, dim, index, out=None) → Tensor

沿给定轴dim,将输入索引张量index指定位置的值进行聚合。

对一个3维张量,输出可以定义为:

out[i][j][k] = tensor[index[i][j][k]][j][k]  # dim=0
out[i][j][k] = tensor[i][index[i][j][k]][k]  # dim=1
out[i][j][k] = tensor[i][j][index[i][j][k]]  # dim=3

例子

>>> t = torch.Tensor([[1,2],[3,4]])
>>> torch.gather(t, 1, torch.LongTensor([[0,0],[1,0]]))
 1  1
 4  3
[torch.FloatTensor of size 2x2]

首先可以确定的

  • 输入的维度数与index的维度数量相等[即input为n维[a,b…c],则index必须为n维[x,y…z]](注意没有规定a和x,b和y,…c和z一定要相等)
  • 输出的元素数量与index的元素数量相等[即index [a,b…c],out[a,b…c]]
    确定了以上两点,之后我们要搞清楚的就是index在每个维度的意义
    定义如下tensor
t = torch.randn(2,3,3)
print(t)
index_a=torch.LongTensor([[[1,2]],[[1,2]]])
index_b=torch.LongTensor([[[1,0]],[[1,0]]])
#第三维度即out[i][j][k]为input(t)相对行列的第index[i][j][k]个元素
print(index_a.shape,torch.gather(t, 2, index_a))
#第二维度即out[i][j][k]为input(t)相对i行index[i][j][k]列的第k个元素
print(torch.gather(t,1,index_a))
#第一维度即out[i][j][k]为input(t)相对index[i][j][k]行j列的第k个元素
print(torch.gather(t,0,index_b))
'''
tensor([[[-1.3513, -0.8054, -1.1973],
         [-0.6869,  0.6490,  0.6097],
         [-0.8863,  0.4745,  0.0845]],

        [[ 0.0430, -1.1662, -1.6117],
         [-0.5186,  0.4622, -0.0349],
         [-0.3438,  1.4358, -0.6612]]])
torch.Size([2, 1, 2]) tensor([[[-0.8054, -1.1973]],

        [[-1.1662, -1.6117]]])
tensor([[[-0.6869,  0.4745]],

        [[-0.5186,  1.4358]]])
tensor([[[ 0.0430, -0.8054]],

        [[ 0.0430, -0.8054]]])
'''
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值