pytorch中gather用法


gather其实是对input进行一种映射,index必须是 LongTensor格式。

2维度tensor进行映射:

b = torch.Tensor([[1,2,3],[4,5,6]])
print b
index_1 = torch.LongTensor([[0,1],[2,0]])
index_2 = torch.LongTensor([[0,1,1],[0,0,0]])
print(torch.gather(b, dim=1, index=index_1)
print(torch.gather(b, dim=0, index=index_2)

输出:

 1 2 3
 4 5 6
[torch.FloatTensor of size 2x3]
 
 1 2
 6 4
[torch.FloatTensor of size 2x2]
 
 1 5 6
 1 2 3
[torch.FloatTensor of size 2x3]

dim = 1时表示跨越列维度来对行维度进行映射。index_1中的[0, 1]代表选取b中的第一行列标为0, 1的元素。index_1中的[2,0]代表选取b中第二行列标为2,0的元素进行映射。
同理dim = 0时表示跨越行对列维度进行映射,index_2中第0列表示对b第0列的下标为0的行进行映射,index_2第1列表示对b第1列第1行和第0行进行映射。index_2第2列表示对b第2列的1行和0行进行映射。

3维度tensor进行映射:

import torch
import numpy as np
a = torch.LongTensor(np.arange(24)).view(2,3,4)
print(a)
'''
tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]],

        [[12, 13, 14, 15],
         [16, 17, 18, 19],
         [20, 21, 22, 23]]])
'''

index = torch.LongTensor([[[0 ,1 ,2 ,0],
                          [0, 0, 0 ,0],
                          [1, 1, 1, 1]],
                        [[2, 2, 2, 2],
                         [1, 1, 1, 1],
                         [0, 0, 0, 0]]])

b = torch.gather(a, 1, index)
print(b)
"""
tensor([[[ 0,  5, 10,  3],
         [ 0,  1,  2,  3],
         [ 4,  5,  6,  7]],

        [[20, 21, 22, 23],
         [16, 17, 18, 19],
         [12, 13, 14, 15]]])
"""

我们设定dim=1,这时候映射的情况下,改变的只有第1个维度。对于例子中这个三维的张量a,对a[0]和a[1]两块,我们分别取出a[0]和a[1]的每一列,也就是[0,4,8]、[1,5,9]……[15,19,23]。每一列值之间第0个维度的索引和第2个维度的索引都是固定不变的,而改变的,只有dim=1的维度。这也是我们设定的。

同样的,取出index的这些列,也就是[0,0,1]、[1,0,1]……[2,1,0],这些值恰巧就是 映射后的张量 每一列值的对映射前每一列值的索引(表示不太清楚)。
具体的,我们拿出张量b的这些列(b是映射后的张量),也就是[0,0,4]、[5,1,5]……[23,19,15]
我们拿出a,index,b,的第一列。分别是[0,4,8],[0,0,1],[0,0,4],可以发现,index的列对a列的索引,刚好就是b的列。其余列也是一一对应。

3维部分引用CSDN博主「山坡上幼稚狗」
原文链接:https://blog.csdn.net/ccwlisha/article/details/101524389

  • 2
    点赞
  • 16
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值