布尔索引的工作原理是,NumPy将接收的布尔数组中的True当做有效的索引位置,并从原数组中选取这些位置对应的元素。
arr = np.arange(12).reshape(3, 4) # arr = [[ 0 1 2 3] # [ 4 5 6 7] # [ 8 9 10 11]] bool_idx = np.array([[True, False, True, False], [True, False, True, False] [True, False, True, False]]) print(arr[bool_idx]) #arr = [[0 2] # [4 6] # [8 10]]
也就是说,布尔索引就是当数组接收到一个布尔数组时,会将内部的布尔数组中的元素和自己的内部元素一一对应。
但在当数组大小不匹配时,NumPy有自动广播(broadcasting)的机制来处理这种情况。
具体来说,如果布尔数组的shape比原数组的shape小,NumPy会自动对布尔数组的shape进行扩展,使其与原数组的shape相匹配。扩展规则是:在布尔数组的shape前面补1,直到其shape与原数组的shape一致为止。
import numpy as np arr = np.arange(12).reshape(3, 4) # arr = [[0 1 2 3] # [4 5 6 7] # [8 9 10 11]] # 一维布尔索引 bool_idx = np.array([True, False, True]) #[3, 0] print(arr[bool_idx]) # 广播:[3,0] --> [3,4] # 广播后: #[[True, True, True, True], # [False, False, False, False], # [True, True, True, True]] # 输出: [[0 1 2 3] # [8 9 10 11]] # 二维布尔索引 bool_idx_2d = np.array([[True, False], # [3,2] [False, True], [True, True]]) print(arr[bool_idx_2d]) # 广播:[3,2] --> [3,4] 将第二维*2 # 广播后: #[[True, False, True, False], # [False, True, False, True], # [True, True, True, True]] # 输出: [[ 0 2] # [ 4 7] # [ 8 9 10 11]]