torch.gater()方法自己理解

import torch as t
a = t.arange(0, 16).view(4, 4)
print(a)
index = t.LongTensor([[0,2,2,3],[1,1,0,3]])
b=a.gather(0, index)
print('-----------------------')
print(b)
b=a.gather(1, index)
print('........................')
print(b)

刚看到gather这个方法,这里面比较懵的是 dim=0,dim=1,其实只要这么理解就好,记住:

dim=0代表行
dim=1代表列


接下来:

a =tensor([[ 0,  1,  2,  3],
                 [ 4,  5,  6,  7],
                 [ 8,  9, 10, 11],
                 [12, 13, 14, 15]])

index = t.LongTensor([[0, 2, 2, 3],
                                    [1, 1, 0, 3]])

二维的数组取值:一般取值是先找第几行,再找第几列,好的,

b=a.gather(0, index) # 这里的dim=0,那么意思就是 index 里面的值代表行

比如 [0, 2, 2, 3] 就代表依次取第0行,第2行,第2行,第3行,那么列呢?它们取第几个,这里的列就根据其在第几列就取第几列,

[0, 2, 2, 3]中

0是第0列,则取第0列,

2是在第一列,则取第一列,

后面一个2是在第二列,则取第二列,

3是在第三列,则取第三列

所以[0,2,2,3]取值是:
                                a[0][0],
                                a[2][1],
                                a[2][2],
                                a[3][3]

dim=0,意思就是里面的值代表第几行,列就看对应的值在第几列,就是取第几列

接下来看dim=1

dim=1就是代表里面的值是取第几列,但是取第几行呢?这个行就默认其在第几行就取第几行,官网要求 input.size() == index.size(),这里要求的维度相同,所以
index = t.LongTensor([[0,2,2,3],[1,1,0,3]])
d=a.gather(1, index)  

这里[[0,2,2,3],[1,1,0,3]]中[0,2,2,3]在第0行,所以都是在第0行取,第几列就按照里面的值,0代表第0列,2代表第2列,3代表第3列;[1,1,0,3]代表都是在第1行取值,
1代表第1列,0代表第0列,3代表第3列

所以:
        dim=0,就代表index中的值是行,至于列,就看其在第几列就取第几列         

        dim=1,就代表index中的值是列,至于行,就看其在第几行就取第几行


上下两行对照看就比较容易理解了,要么有行,列就按照其所在的列,要么有列,行就按照整体在哪一行;以上解释是由结果推到的,看了很多人的博客和官网,都没有很好理解。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值