gather torch_Pytorch的gather用法理解

先放一张表,可以看成是二维数组

行(列)索引

索引0

索引1

索引2

索引3

索引0

0

1

2

3

索引1

4

5

6

7

索引2

8

9

10

11

索引3

12

13

14

15

看一下下面例子代码:

针对0维(输出为行形式)

>>> import torch as t

>>> a = t.arange(0,16).view(4,4)

>>> a

tensor([[ 0, 1, 2, 3],

[ 4, 5, 6, 7],

[ 8, 9, 10, 11],

[12, 13, 14, 15]])

#选取对角线的元素

>>> index = t.LongTensor([[0,1,2,3]])

>>> a.gather(0,index)

tensor([[ 0, 5, 10, 15]])

如何理解结果呢?其实很简单,就是a.gather(0,index)中第一个0已经表明输出结果是行形式(0维),如果第一个是1说明输出结果是列形式(1维),然后按照index = tensor([[0, 1, 2, 3]])顺序作用在行上索引依次为0,1,2,3:

a[0][0] = 0

a[1][1] = 5

a[2][2] = 10

a[3][3] = 15

针对0维

# 选取反对角线上的元素,注意与上面的不同

>>> index2 = t.LongTensor([[3,2,1,0]])

>>> a.gather(0,index2)

tensor([[12, 9, 6, 3]])

如何理解结果呢?同理,按照index = tensor([[3, 2, 1, 0]])顺序作用在行上索引依次为3,2,1,0:

a[3][0] = 12

a[2][1] = 9

a[1][2] = 6

a[0][3] = 3

针对1维(输出为列形式)

#选取对角线的元素

>>> index3 = t.LongTensor([[0,1,2,3]]).t()

>>> a.gather(1,index3)

tensor([[ 0],

[ 5],

[10],

[15]])

如何理解结果呢?同理,按照index = tensor([[0, 1, 2, 3]])顺序作用在列上索引依次为0,1,2,3:

a[0][0] = 0

a[1][1] = 5

a[2][2] = 10

a[3][3] = 15

针对1维

选取反对角线上的元素

>>> index4 = t.LongTensor([[3,2,1,0]]).t()

>>> a.gather(1,index4)

tensor([[ 3],

[ 6],

[ 9],

[12]])

如何理解结果呢?同理,按照index = tensor([[3, 2, 1, 0]])顺序作用在列上索引依次为3,2,1,0:

a[0][3] = 3

a[1][2] = 6

a[2][1] = 9

a[3][0] = 12

原文:https://www.cnblogs.com/HongjianChen/p/9451526.html

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值