经常能看到numpy.take与numpy.argwhere在一起使用,索性将二者的功能一起记录如下。
np.take
np.take(ori_array,indices_array,axis = None,out = None,mode ='raise')
沿轴取数组ori_array(它可以是python数组或者np数组)中的元素,返回结果和indices_array的形状相同。indices_array的每个元素值是指ori_array的所有元素转换为1维元素对应的下标。
示例:
>>a = np.arange(3,12).reshape(3,3)
>>a
array([[ 3, 4, 5],
[ 6, 7, 8],
[ 9, 10, 11]])
>>np.take(a, [1,2,3])
array([4, 5, 6])
>>np.take(a, [[1,2],[3,4]])
array([[4, 5],
[6, 7]])
np.argwhere
np.argwhere(a)
返回a中所有非零元素的下标。
参数:
- a : 一个数组, 可以是python数组或numpy数组。通常就是一个bool值(0和1)的矩阵
返回:
- index_array: 一个numpy的二维数组,形状是(N, a.ndim)。N是a中非零元素的数目, a.ndim就是a的总维度。比如a的形状是(2,3,3),那么index_array的形状就是(N, 3)
示例:
>>x = np.arange(6).reshape(2,3)
>>x
array([[0, 1, 2],
[3, 4, 5]])
>>np.argwhere(x>1)
array([[0, 2],
[1, 0],
[1, 1],
[1, 2]])