读代码时候遇到numpy.where(),费了半天劲,终于理解了,分享一下。
格式
numpy.where(condition[, x, y])
参数
condition : array_like, bool
if conditon == True:
取当前位置的x的值
else:
取当前位置的y的值
x, y : array_like, optional,x 和y 与condition尺寸相同
返回值
返回一个数组,或者由数组组成的元组。
根据定义条件返回元素,这些元素或者从x中获得,或者从y中获得。
如果只给出条件,没有给出[,x, y],返回条件中非零(True)元素的坐标。
实例理解
>>> np.where([[True, False], [True, True]],
... [[1, 2], [3, 4]],
... [[9, 8], [7, 6]])
array([[1, 8],
[3, 4]]) # True时从x取值,False时从y取值
>>> np.where([[0, 1], [1, 0]]) # 只有condition,返回非零值的坐标
(array([0, 1]), array([1, 0]))
>>> np.where([[0,1],[1,2]])
(array([0,1,1]), array([1,0,1]))
>>> x = np.arange(9.).reshape(3, 3)
>>> np.where( x > 5 )
(array([2, 2, 2]), array([0, 1, 2]))
>>> x[np.where( x > 3.0 )] # Note: result is 1D.
array([ 4., 5., 6., 7., 8.])
>>> np.where(x < 5, x, -1) # 值替换
array([[ 0., 1., 2.],
[ 3., 4., -1.],
[-1., -1., -1.]])
参考
https://docs.scipy.org/doc/numpy/reference/generated/numpy.where.html
http://www.bubuko.com/infodetail-1858534.html
Bonne courage!