目录
在Python编程中,NumPy库因其高效的数组操作(也可以理解为矩阵操作)而广受数据科学家和程序员的喜爱。NumPy数组不仅提供了丰富的数学运算功能,还允许我们以数组作为索引来访问和操作数据。这种索引方式在处理复杂数据结构时显得尤为强大和灵活,在深度学习编程中这种方式效率提升更明显。本文将由浅入深地探讨如何使用NumPy数组作为索引,并展示一些直观的示例。
以单个数组为索引
一维数组索引一维数组
比较常见是数组作为数组的索引,如下:
a=np.array([1,2,3])
b=np.array([1,1,0])
a[b]
结果为
array([2, 2, 1])
一维数组索引二维数组
这里只谈论二维数组
a=np.array([[1,2,3],
[4,5,6]])
b=np.array([1,1,0])
a[b]
结果为
array([[4, 5, 6],
[4, 5, 6],
[1, 2, 3]])
b中每个元素对应a中一行,如果需要对应a单个元素,需要b的元素本身二维列表、数组或者元组。
二维数组索引二维数组
这里只谈论二维数组
a=np.array([[1,2,3],
[4,5,6]])
b=np.array([[1,1,0],
[0,0,0],
[0,0,1]])
a[b]
结果为:
array([[[4, 5, 6],
[4, 5, 6],
[1, 2, 3]],
[[1, 2, 3],
[1, 2, 3],
[1, 2, 3]],
[[1, 2, 3],
[1, 2, 3],
[4, 5, 6]]])
相当于在上一种情况中把b变成了三个数组,所以一共得到了三个数组。
多维数组索引多维数组
因为显示的内容过多,以下只显示数组的维度,先看一下代码:
import numpy as np
A, B, C, D = 6, 7, 8, 9 # 被索引数组的维度
R, S, T = 3, 4, 5 # 索引数组的维度
d = np.arange(D)
c = np.array([d]*C)
b = np.array([c]*B)
a = np.array([b]*A)
print('AXBXCXD ', a.shape) # (6, 7, 8, 9)
t = np.arange(T)
s = np.array([t]*S)
r = np.array([s]*R)
print('RXSXT ', r.shape) # (3, 4, 5)
print('RXSXTXBXCXD ', a[r].shape) # (3, 4, 5, 7, 8, 9) A被RXSXT替代
# print(r[a].shape) # 索引数组中的元素数值超过A,出错!
直观的理解就是把被索引数组的第一维拆解,按索引数组中数值在第一维中选取元素(维度为BXCXD),再按索引数组布局方式组合起来。比如被索引数组维度AXBXCXD,索引数组RXSXT,第一维A消失了,BXCXD作为元素的维度保留下来了,A被RXSXT替代重新组合,但是索引数组中元素数值不能超过A。
多个数组作为索引(花式索引)
索引数组不止一个,情况如何?
两个一维数组索引二维数组
import numpy as np
a=np.array([[1,2,3],
[4,5,6]])
b=([1,1,0],[0,0,0])
print(a[b])
结果为
array([4,4,1])
相当于
a([1,1,0],[0,0,0])
等同于使用两个数组(不是一个数组),相当于
array([a[1,0],a[1,0],a[1,0]])
从第一个数组和第二个数组分别取出对应的数作为索引的第一维和第二维,这个有点复杂,但很有用,特别是在深度学习中。
N个一维数组索引M维数组(M ≥ \ge ≥N)
以下代码中,N设为3,M设为2和3
import numpy as np
a = np.array([[[111, 112, 113],
[121, 122, 123]],
[[211, 212, 213],
[221, 222, 223]]
])
b = ([1, 1, 0], [0, 0, 1])
c = ([[1, 1, 0], [1, 0, 1]],
[[0, 0, 1], [1, 0, 1]],
[[0, 1, 2], [2, 1, 2]])
print('两个一维数组索引三维数组:')
print(a[b])
print('三维个二维数组索引三维数组:')
print(a[c])
# 结果为
# 两个一维数组索引三维数组:
# [[211 212 213]
# [211 212 213]
# [121 122 123]]
# 三维个二维数组索引三维数组:
# [[211 212 123]
# [223 112 223]]
使用的索引数组越多,得到数组维度越低,但至少等于索引数组的维度。
N个多维数组索引M维数组(M ≥ \ge ≥N)
根据多维数组索引多维数组的内容推测,多维数组索引只是丰富了输出结果的组合形式,而且此维度没有上限,看代码:
import numpy as np
a = np.array([[[111, 112, 113],
[121, 122, 123]],
[[211, 212, 213],
[221, 222, 223]]
])
b = ([1, 1, 0], [0, 0, 1])
c = ([[1, 1, 0], [1, 0, 1]],
[[0, 0, 1], [1, 0, 1]],
[[0, 1, 2], [2, 1, 2]])
print('两个一维数组索引三维数组:')
print(a[b])
print('三维个二维数组索引三维数组:')
print(a[c])
结论
如果被索引数组维度AXBXCXD,索引数组RXSXT,一个索引数组得到结果数组为度为RXSXTXBXCXD,两个索引数组得到结果数组为度为RXSXTXCXD,三个索引数组得到结果数组为度为RXSXTXD,,四个索引数组得到结果数组为度为RXSXT,被索引数组的维度AXBXCXD完全消失了。注意,多个索引数组shape要一致(维度tuple要一致),元素数值不能超过被索引数组的对应维度,具体来说就是,第一个索引数组中的元素数值不能超过A,第一个索引数组中的元素数值不能超过B,第三个索引数组中的元素数值不能超过C,以此类推。