【Pytorch】torch.gather函数理解

【参考:torch.gather — PyTorch 1.12 documentation

out = torch.gather(input, dim, index)

out[i][j][k] = input[ index[i][j][k] ] [j][k]   # if dim == 0
out[i][j][k] = input[i] [ index[i][j][k] ] [k]  # if dim == 1
out[i][j][k] = input[i][j] [ index[i][j][k] ]   # if dim == 2

规则:需要构造一个和index一样shape的索引数组,然后把第dim维换成index[i][j]所得到的值,得到新的索引数组,然后根据该索引数组去input取值。

【参考:Pytorch中torch.gather函数祥解 - 简书

用二维tensor举例

#按照dim = 0,  取一个2*2tensor的对角线上的数值
import torch
a = torch.Tensor([[1, 2],
                  [3, 4]])

b = torch.gather(a, dim = 0, index=torch.LongTensor([[0, 1]]))

print('a = ', a)
print('b = ', b)
"""
a =  tensor([[1., 2.],
        [3., 4.]])
b =  tensor([[1., 4.]])
"""

index的shape是[1,2],所以i=0,j=0,1

i,j=0,0
b[0][0]=a[ index[0][0] ][0] =a[0][0]=1
i,j=0,1
b[0][0]=a[ index[0][1] ][1] =a[1][1]=4

案例二

#按照dim = 1,  取一个2*2 tensor的对角线上的数值
import torch
a = torch.Tensor([[1, 2],
                  [3, 4]])

c = torch.gather(a, dim = 1, index=torch.LongTensor([[0], 
                                                     [1]]))
print('a = ', a)
print('c = ', c)
"""
a =  tensor([[1., 2.],
        [3., 4.]])
c =  tensor([[1.],
        [4.]])
"""
index的shape是[2,1],所以i=0,1,j=0

i,j=0,0
c[0][0]=a[ index[0][0] ][0] =a[0][0]=1
i,j=1,0
c[0][0]=a[ index[1][0] ][0] =a[1][1]=4

案例三


import torch
a = torch.Tensor([[1, 2],
                  [3, 4]])
d = torch.gather(a, dim= 0, index=torch.LongTensor([[0, 0],
                                                   [1, 0]]))
print('a = ', a)
print('d = ', d)
"""
a =  tensor([[1., 2.],
             [3., 4.]])

d =  tensor([[1., 2.],
             [3., 2.]])
"""
index的shape是[2,2],所以i=0,1,j=0,1

i,j=0,0
d[0][0]=a[ index[0][0] ][0] =a[0][0]=1
i,j=0,1
d[0][1]=a[ index[0][1] ][1] =a[0][1]=2
i,j=1,0
d[1][0]=a[ index[1][0] ][0] =a[1][0]=3
i,j=1,1
d[1][1]=a[ index[1][1] ][1] =a[0][1]=2

实际中的一个例子

引言:在多分类中,torch.gather常用来取出标签所对应的概率

有三个标签[0, 1, 2],即三个类别。现在知道两个样本(A 和 B)所得到的三个标签的概率分别为[0.1, 0.3, 0.6]和[0.3, 0.2, 0.5], 用my_pred表示, 这两个样本的真实标签分别为0和2, 那么我们很容易知道A所预测的真实标签的概率为0.1, B所预测的真实标签的概率为0.5,A分类错误,B正确分类。那么用程序这么获得标签对应的概率呢,这里就可以用gather函数。

import torch

my_pred = torch.tensor([[0.1, 0.3, 0.6], 
						[0.3, 0.2, 0.5]])
my = torch.LongTensor([[0],
                       [2]])
print(torch.gather(input=my_pred, dim=1, index=torch.LongTensor([[0],
                                                                [2]])))
"""
tensor([[0.1000],
        [0.5000]])
"""                                                                
index的shape是[2,1],所以i=0,1,j=0

i,j=0,0
d[0][0]=a[0][ index[0][0] ] =a[0][0]=0.1
i,j=1,0
d[1][0]=a[1][ index[1][0] ] =a[1][2]=0.5
  • 2
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值