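I recently ran into code that indexes a matrix with several matrices at once. I could not find an explanation online that I could follow and was a bit lost, so I am writing down what I worked out.
If you spot any mistakes, corrections are welcome.
Background:
Using arrays and matrices as matrix indices: https://blog.csdn.net/yzlh2009/article/details/114118470
1. Define the matrices used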
import torch
batch_indices = torch.arange(2, dtype=torch.long)[:, None].repeat(1, 4)
row_indices = torch.arange(4, dtype=torch.long)[None, :].repeat(2, 1)
itself_indices = torch.arange(4, dtype=torch.long)[None, :].repeat(2, 1)
batch_indices
Out[14]:
tensor([[0, 0, 0, 0],
[1, 1, 1, 1]])
row_indices
Out[15]:
tensor([[0, 1, 2, 3],
[0, 1, 2, 3]])
itself_indices
Out[16]:
tensor([[0, 1, 2, 3],
[0, 1, 2, 3]])
import numpy as np
group_idx = np.array([[[0, 1, 2, 3],
                       [4, 5, 6, 7],
                       [8, 9, 10, 11],
                       [12, 13, 14, 15]],
                      [[16, 17, 18, 19],
                       [20, 21, 22, 23],
                       [24, 25, 26, 27],
                       [28, 29, 30, 31]]])
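As a side note, the same index grids can probably be built in one call. The sketch below assumes plain NumPy arrays instead of torch tensors (so it runs without PyTorch) and uses np.indices; the variable names simply mirror the ones above.

```python
import numpy as np

# np.indices((2, 4)) returns one grid per axis of a (2, 4) array:
# the first grid holds the axis-0 index of each cell, the second the axis-1 index.
batch_indices, row_indices = np.indices((2, 4))
itself_indices = row_indices.copy()

print(batch_indices)
# [[0 0 0 0]
#  [1 1 1 1]]
print(row_indices)
# [[0 1 2 3]
#  [0 1 2 3]]
```

These grids have the same values as the arange/repeat versions, only as NumPy arrays.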
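2. A single matrix as the index
Put simply, every element of batch_indices is used as an index into the first axis (axis 0) of group_idx: a 0 selects group_idx[0] and a 1 selects group_idx[1]. In the output, each position where batch_indices holds 0 is replaced by the 4x4 block group_idx[0], and each position where it holds 1 is replaced by group_idx[1], so the result has shape (2, 4, 4, 4).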
group_idx[batch_indices]
Out[33]:
array([[[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]],
[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]],
[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]],
[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]]],
[[[16, 17, 18, 19],
[20, 21, 22, 23],
[24, 25, 26, 27],
[28, 29, 30, 31]],
[[16, 17, 18, 19],
[20, 21, 22, 23],
[24, 25, 26, 27],
[28, 29, 30, 31]],
[[16, 17, 18, 19],
[20, 21, 22, 23],
[24, 25, 26, 27],
[28, 29, 30, 31]],
[[16, 17, 18, 19],
[20, 21, 22, 23],
[24, 25, 26, 27],
[28, 29, 30, 31]]]])
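Note that axis 0 of group_idx has size 2 (its full shape is (2, 4, 4)), so every element of the index matrix must be a valid index for that axis, here 0 or 1 (or the negative equivalents -2 and -1). Anything else raises an IndexError: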
group_idx[row_indices]
Traceback (most recent call last):
File "E:\……\IPython\core\interactiveshell.py", line 3331, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "<ipython-input-19-3082f36a049a>", line 1, in <module>
group_idx[row_indices]
IndexError: index 2 is out of bounds for dimension 0 with size 2
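To make the rule for a single index matrix concrete, here is a minimal sketch. It assumes NumPy index arrays rather than torch tensors, and builds group_idx with np.arange(32).reshape(2, 4, 4), which gives the same values as the array defined earlier.

```python
import numpy as np

group_idx = np.arange(32).reshape(2, 4, 4)
batch_indices = np.repeat(np.arange(2)[:, None], 4, axis=1)  # shape (2, 4)
row_indices = np.repeat(np.arange(4)[None, :], 2, axis=0)    # shape (2, 4)

out = group_idx[batch_indices]
# Result shape = shape of the index array + remaining axes of group_idx.
print(out.shape)  # (2, 4, 4, 4)

# Each entry of the index array selects one whole (4, 4) block along axis 0.
for i in range(2):
    for j in range(4):
        assert np.array_equal(out[i, j], group_idx[batch_indices[i, j]])

# row_indices contains 2 and 3, which are out of range for axis 0 (size 2).
raised = False
try:
    group_idx[row_indices]
except IndexError as exc:
    raised = True
    print("IndexError:", exc)
```

The exact wording of the error message differs between NumPy and PyTorch, but both reject the out-of-range index.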
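3. Multiple matrices as indices
Put simply, group_idx[batch_indices,row_indices] selects along the next axis on top of what group_idx[batch_indices] selected, and group_idx[batch_indices,row_indices,itself_indices] does the same on top of group_idx[batch_indices,row_indices]. More precisely, the index matrices are paired up position by position: output[i][j] = group_idx[batch_indices[i][j], row_indices[i][j]].
That is, each element of row_indices[0] is used as an index into the block that the matching element of batch_indices[0] selected, and the row found there goes into the output.
E.g. row_indices[0][0] is 0, so index 0 into the block group_idx[batch_indices][0][0] finds the row [ 0, 1, 2, 3];
row_indices[1][0] is 0, so index 0 into the block group_idx[batch_indices][1][0] finds the row [16, 17, 18, 19].
Note that all the index matrices must have the same shape (or be broadcastable to a common shape); here batch_indices, row_indices and itself_indices all have shape (2, 4). The output takes that shape, followed by any axes of group_idx that were not indexed.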
group_idx[batch_indices,row_indices]
Out[34]:
array([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]],
[[16, 17, 18, 19],
[20, 21, 22, 23],
[24, 25, 26, 27],
[28, 29, 30, 31]]])
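The two-index case can be checked against an explicit loop. This is a sketch under the same assumptions as before (NumPy index arrays, group_idx rebuilt with np.arange); the `expected` array is only an illustration helper.

```python
import numpy as np

group_idx = np.arange(32).reshape(2, 4, 4)
batch_indices = np.repeat(np.arange(2)[:, None], 4, axis=1)  # shape (2, 4)
row_indices = np.repeat(np.arange(4)[None, :], 2, axis=0)    # shape (2, 4)

out = group_idx[batch_indices, row_indices]  # shape (2, 4, 4)

# Position (i, j) of the output is the row picked by the pair
# (batch_indices[i, j], row_indices[i, j]).
expected = np.empty((2, 4, 4), dtype=group_idx.dtype)
for i in range(2):
    for j in range(4):
        expected[i, j] = group_idx[batch_indices[i, j], row_indices[i, j]]

assert np.array_equal(out, expected)
# These particular index pairs enumerate every (batch, row) combination in order,
# so the result is simply group_idx again.
assert np.array_equal(out, group_idx)
```

That is why Out[34] looks identical to the original array.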
group_idx[batch_indices,row_indices,itself_indices]
Out[35]:
array([[ 0, 5, 10, 15],
[16, 21, 26, 31]])