numpy.where(condition[, x, y])
函数返回输入数组中满足给定条件的元素的索引。
这个函数有两种形式:
np.where(condition)
:返回输入数组中满足条件的元素的索引,作为一个元组,其中第一个元素是满足条件的元素的行索引,第二个元素是满足条件的元素的列索引。np.where(condition, x, y)
:返回满足条件的元素 x 和不满足条件的元素 y。
示例:
import numpy as np
# 一维数组
a = np.array([1, 2, 3, 4])
# 返回输入数组中大于 2 的元素的索引
indexes = np.where(a > 2)
print(indexes) # 输出 (array([2, 3]),)
# 返回满足条件的元素 3 和不满足条件的元素 0
b = np.where(a > 2, a, 0)
print(b) # 输出 [0 0 3 4]
# 二维数组
a = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
# 返回输入数组中大于 5 的元素的索引
indexes = np.where(a > 5)
print(indexes) # 输出 (array([1, 1, 2, 2, 2]), array([1, 2, 0, 1, 2]))
# 返回满足条件的元素 6、8、9 和不满足条件的元素 0
b = np.where(a > 5, a, 0)
print(b)
# 输出
# [[0 0 0]
# [0 6 0]
# [0 8 9]]