python torch 多个矩阵作为矩阵索引

最近看代码遇到多个矩阵作为矩阵索引的情况,网上没找到能理解的资料,有点懵,记录一下
如有错误,敬请指正
预备知识:
数组和矩阵作为矩阵索引:https://blog.csdn.net/yzlh2009/article/details/114118470

1. 定义使用的矩阵

batch_indices=torch.arange(2, dtype=torch.long)[:, None].repeat(1, 4)
row_indices = torch.arange(4, dtype=torch.long)[None, :].repeat(2, 1) 
itself_indices=torch.arange(4, dtype=torch.long)[None, :].repeat(2, 1) 

batch_indices
Out[14]: 
tensor([[0, 0, 0, 0],
        [1, 1, 1, 1]])
row_indices
Out[15]: 
tensor([[0, 1, 2, 3],
        [0, 1, 2, 3]])
itself_indices
Out[16]: 
tensor([[0, 1, 2, 3],
        [0, 1, 2, 3]])

import numpy as np
group_idx=np.array([[[0, 1, 2, 3],
   ...:          [4, 5, 6, 7],
   ...:          [8, 9, 10, 11],
   ...:          [12, 13, 14, 15]],
   ...:         [[16, 17, 18, 19],
   ...:          [20, 21, 22, 23],
   ...:          [24, 25, 26, 27],
   ...:          [28, 29, 30, 31]]])

2. 单个矩阵作为矩阵索引

简单解释就是,batch_indices中的每个元素都作为group_idx中的索引,比如batch_indices中的“0”索引group_idx的第0维,batch_indices中的“1”索引group_idx的第1维。在每个batch_indices元素为“0”的位置用group_idx的第0维代替,每个batch_indices元素为“1”的位置用group_idx的第1维代替,得到输出。

group_idx[batch_indices]
Out[33]: 
array([[[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11],
         [12, 13, 14, 15]],
        [[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11],
         [12, 13, 14, 15]],
        [[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11],
         [12, 13, 14, 15]],
        [[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11],
         [12, 13, 14, 15]]],
       [[[16, 17, 18, 19],
         [20, 21, 22, 23],
         [24, 25, 26, 27],
         [28, 29, 30, 31]],
        [[16, 17, 18, 19],
         [20, 21, 22, 23],
         [24, 25, 26, 27],
         [28, 29, 30, 31]],
        [[16, 17, 18, 19],
         [20, 21, 22, 23],
         [24, 25, 26, 27],
         [28, 29, 30, 31]],
        [[16, 17, 18, 19],
         [20, 21, 22, 23],
         [24, 25, 26, 27],
         [28, 29, 30, 31]]]])

要注意,group_idx的size为2,只有2维,所以作为索引的矩阵里,元素只能为“0”、“1”,否则无效

group_idx[row_indices]
Traceback (most recent call last):
  File "E:\……\IPython\core\interactiveshell.py", line 3331, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-19-3082f36a049a>", line 1, in <module>
    group_idx[row_indices]
IndexError: index 2 is out of bounds for dimension 0 with size 2

3. 多个矩阵作为矩阵索引

简单理解就是,group_idx[batch_indices,row_indices]group_idx[batch_indices] 基础上、group_idx[batch_indices,row_indices,itself_indices]group_idx[batch_indices,row_indices] 基础上按照对应维度选择。
row_indices[0] 中的元素作为索引,在 group_idx[batch_indices][0] 中找到对应数组得到输出,
eg: row_indices[0][0] 对应元素“0” ,在 group_idx[batch_indices][0] 中按照索引“0”找到数组 [ 0, 1, 2, 3]
row_indices[1][0] 对应元素“0” ,在 group_idx[batch_indices][1] 中按照索引“0”找到数组 [16, 17, 18, 19]
注意 row_indices 的维度要和 group_idx[batch_indices] 的size对应;itself_indices 的维度和 group_idx[batch_indices,row_indices] 的size对应

group_idx[batch_indices,row_indices]
Out[34]: 
array([[[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]],
       [[16, 17, 18, 19],
        [20, 21, 22, 23],
        [24, 25, 26, 27],
        [28, 29, 30, 31]]])
group_idx[batch_indices,row_indices,itself_indices]
Out[35]: 
array([[ 0,  5, 10, 15],
       [16, 21, 26, 31]])
  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值