这是numpy数组比较的问题
情况一:返回的是数组
import numpy as np
a = np.array([1, 2, 3])
b = np.array([1, 5, 6])
if a == b:
pass
因为a==b
的结果是[True False False]
。
解决方案是:
.any()
:只要有一个位置的元素True
就True
。.all()
:每个位置的元素都True
才True
。
print((a == b).any()) # True
print((a == b).all()) # False
情况二:list(numpy数组)
import numpy as np
# 纯numpy没问题
a = np.array([1, 2])
b = [1, 2]
print(a == b) # [ True True ]
# 纯numpy没问题
c = np.array([[1, 2], [3, 4]])
d = [[1, 2], [3, 4]]
print(c == d) # [[ True True], [ True True]]
# list(numpy)就不行
n1 = np.array([1, 2])
n2 = np.array([3, 4])
n_list_1 = [n1]
n_list_2 = [n1, n2]
print(n_list_1 == b) # ValueError
print(n_list_2 == d) # ValueError
解法:将list(numpy数组)
转化为纯Numpy
print(np.array(n_list_1) == b) # [ True True ]
print(np.array(n_list_2) == d) # [[ True True], [ True True]]