torch.gather()的理解

函数定义

torch.gather(input, dim, index, out=None) 

其中,

input(tensor):   待操作数,源张量
dim(int):   维度。
index(LongTensor):   索引下标。
out:  输出的张量

注意,输出的张量和index的size是一致的。

举例

dim=0

输入:

a = torch.arange(0,16).view(4,4)
print(a)
index = torch.LongTensor([[0,1,2,3]]) # index的元素必须是int型数据,如果使用torch.Tensor会报错
print(a.gather(0, index))

输出:

tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]])
tensor([[ 0,  5, 10, 15]])

解释:dim=0代表针对0维,即输出为行形式。不难看出,此时index的[0,1,2,3]分别指向源张量a的第0,1,2,3行的第0,1,2,3列的数据。被指向的数据分别为0,5,10,15,又因为dim=0代表“行”,故输出tensor([[ 0,  5, 10, 15]])

 

dim=1

输入:

index3 = torch.LongTensor([[0,1,2,3]]).t() # 添加上t.()将index3变成一列的数据
print(a.gather(1,index3))

输出:

tensor([[ 0],
        [ 5],
        [10],
        [15]])

解释:dim=0代表针对0维,即输出为行形式。不难看出,此时index的[0,1,2,3]分别指向源张量a的第0,1,2,3列的第0,1,2,3行的数据。被指向的数据分别为0,5,10,15,又因为dim=1代表“列”,故有上述输出。

 

dim=2

输入:

a = torch.randint(0, 30, (2, 3, 5))
print("------a------")
print(a)
index = torch.LongTensor([[[0,1,2,0,2],
                          [0,0,0,0,0],
                          [1,1,1,1,1]],
                        [[1,2,2,2,2],
                         [0,0,0,0,0],
                         [2,2,2,2,2]]])
print("----index----")
print(index)
c = torch.gather(a, 2,index) # 除了使用a.gather(1,index)外,也可以torch.gather(a,1,index)
print("------c------")
print(c)

输出:

------a------
tensor([[[ 2,  0, 28, 28, 12],
         [28, 28, 13, 26, 13],
         [16,  7, 26, 21,  3]],

        [[19, 16, 29, 20, 23],
         [10, 26,  4, 24, 26],
         [26, 14, 28,  3, 25]]])
----index----
tensor([[[0, 1, 2, 0, 2],
         [0, 0, 0, 0, 0],
         [1, 1, 1, 1, 1]],

        [[1, 2, 2, 2, 2],
         [0, 0, 0, 0, 0],
         [2, 2, 2, 2, 2]]])
------c------
tensor([[[ 2,  0, 28,  2, 28],
         [28, 28, 28, 28, 28],
         [ 7,  7,  7,  7,  7]],

        [[16, 29, 29, 29, 29],
         [10, 10, 10, 10, 10],
         [28, 28, 28, 28, 28]]])

解释:

我们可以看index[0][0],为[0,1,2,0,2],其中的“0 1 2 0 2”五个数字分别指向源张量第0页第0行的第“0 1 2 0 2”列数据,故输出张量c[0][0]=[2,0,28,2,28]。类推即可理解dim=2的输出情况。

 

参考链接:

https://www.cnblogs.com/HongjianChen/p/9451526.html

https://www.jianshu.com/p/5d1f8cd5fe31

  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
【优质项目推荐】 1、项目代码均经过严格本地测试,运行OK,确保功能稳定后才上传平台。可放心下载并立即投入使用,若遇到任何使用问题,随时欢迎私信反馈与沟通,博主会第一时间回复。 2、项目适用于计算机相关专业(如计科、信息安全、数据科学、人工智能、通信、物联网、自动化、电子信息等)的在校学生、专业教师,或企业员工,小白入门等都适用。 3、该项目不仅具有很高的学习借鉴价值,对于初学者来说,也是入门进阶的绝佳选择;当然也可以直接用于 毕设、课设、期末大作业或项目初期立项演示等。 3、开放创新:如果您有一定基础,且热爱探索钻研,可以在此代码基础上二次开发,进行修改、扩展,创造出属于自己的独特应用。 欢迎下载使用优质资源!欢迎借鉴使用,并欢迎学习交流,共同探索编程的无穷魅力! 基于业务逻辑生成特征变量python实现源码+数据集+超详细注释.zip基于业务逻辑生成特征变量python实现源码+数据集+超详细注释.zip基于业务逻辑生成特征变量python实现源码+数据集+超详细注释.zip基于业务逻辑生成特征变量python实现源码+数据集+超详细注释.zip基于业务逻辑生成特征变量python实现源码+数据集+超详细注释.zip基于业务逻辑生成特征变量python实现源码+数据集+超详细注释.zip基于业务逻辑生成特征变量python实现源码+数据集+超详细注释.zip 基于业务逻辑生成特征变量python实现源码+数据集+超详细注释.zip 基于业务逻辑生成特征变量python实现源码+数据集+超详细注释.zip

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值