gather() 的函数功能

参考文章:Pytorch中的torch.gather函数的含义

demo

b = torch.Tensor([[1,2,3,4],[5,6,7,8],[9,10,11,12]])
print(b)
index_1 = torch.LongTensor([[0,1],[2,0],[1,1]])
index_2 = torch.LongTensor([[0,1,1],[0,0,0]])
print(b.gather(dim=1, index=index_1))#print(torch.gather(b, dim=1, index=index_1))
print(b.gather(dim=0, index=index_2))
    

gather函数的功能可以解释为根据 index 参数(即是索引)返回数组里面对应位置的值
这里的b.gather()写法和torch.gather(b)的写法都可以,重点是两个参数,dim和index

低维的理解方式

dim=0表示按行来索引,也就是说index的值表示的是第几行
dim=1表示按列来索引,也就是指index的值表示的是第几列

b.gather(dim=1, index=index_1)可以看到index_1 = torch.LongTensor([[0,1],[2,0],[1,1]])是一个3行2列的矩阵,根据dim=1,index_1里面的值表示的就是第几列,第几行就由index_1决定(共3行),那么[0,1]表示的就是【第0行第0列,第0行第1列】;[2,0]表示【第1行第2列,第1行第0列】, [1,1]表示【第3行第1列,第3行第1列】

b.gather(dim=0, index=index_2)可以看到index_2 = torch.LongTensor([[0,1,1],[0,0,0]])是一个2行3列的矩阵,根据dim=0,index_2里面的值表示的就是第几行,第几列就由index_2决定(共3列),那么[0,1,1]表示的就是【第0行第0列,第1行第1列,第1行第2列】;[0,0,0]表示【第0行第0列,第0行第1列,第0行第2列】

运行结果

'''b'''
######### 0列  1列  2列  3列
tensor([[ 1.,  2.,  3.,  4.],#   0行
        [ 5.,  6.,  7.,  8.],#   1行
        [ 9., 10., 11., 12.]])#  2行
'''b.gather(dim=1, index=index_1)'''
tensor([[ 1.,  2.],
        [ 7.,  5.],
        [10., 10.]])
'''b.gather(dim=0, index=index_2)'''
tensor([[1., 6., 7.],
        [1., 2., 3.]])

高维的理解方式

b.gather(dim=1, index=index_1)可以看到index_1 = torch.LongTensor([[0,1],[2,0],[1,1]])是一个3行2列的矩阵,index_1的[0,1]中的0的索引是(0,0),1的索引是(0,1);[2,0]中的2索引是(1,0),0的索引是(1,1);[1,1]中左边1的索引是(2,0),右边1的索引是(2,1)。然后根据dim=1,需要把这些索引的dim=1维度的值全部替换成对应index_1中的值,操作如下:

[0,1]中的0的索引是(0,0)转变为(0,0),1的索引是(0,1)转变为(0,1)
[2,0]中的2索引是(1,0)转变为(1,2),0的索引是(1,1)转变为(1,0)
[1,1]中左边1的索引是(2,0)转变为(2,1),右边1的索引是(2,1)转变为(2,1)

转变之后的索引对应到b上,把对应索引的数值取出来即可

  • 15
    点赞
  • 48
    收藏
    觉得还不错? 一键收藏
  • 4
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值