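There are a couple of related approaches, which fall into two camps. You can use a vectorized approach that computes a single Boolean array covering every row. Or you can find the index of the first row containing only 0 elements, either with a `for` loop or with `next` and a generator expression.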
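On a small array, the two camps look roughly like this. The vectorized version evaluates every row before looking for the first match. The generator version stops as soon as it finds one, so loop-based approaches can come out ahead when the first all-zero row appears early in a large array.

```python
import numpy as np

A = np.array([[4, 2, 9],
              [1, 0, 3],
              [0, 0, 0],
              [0, 0, 0]])

# Vectorized camp: one Boolean value per row, computed for every row
zero_rows = (A == 0).all(axis=1)        # [False, False, True, True]
idx_vect = np.where(zero_rows)[0][0]    # position of the first True

# Iterative camp: next() pulls from the generator only until the first match
idx_gen = next(i for i, row in enumerate(A) if (row == 0).all())

print(idx_vect, idx_gen)   # 2 2
print(A[:idx_vect])        # the two rows before the first all-zero row
```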
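For performance, I recommend using numba with a manual `for` loop. Here's one example; see `trim_enum_nb2` in the test code below for a more efficient variant:

```python
from numba import jit

@jit(nopython=True)
def trim_enum_nb(A):
    # Return the rows that come before the first all-zero row
    for idx in range(A.shape[0]):
        if (A[idx] == 0).all():
            return A[:idx]
    return A  # no all-zero row: keep every row
```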
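A usage sketch on a small array, built around a hypothetical helper `first_zero_row` that returns the index rather than the slice. It uses `numba.njit` (numba's shorthand for `jit(nopython=True)`) when numba is installed, and otherwise falls back to running the same loop as plain Python. The second call covers the edge case where no row is all-zero: the index equals the row count, so nothing is trimmed.

```python
import numpy as np

try:
    from numba import njit  # shorthand for jit(nopython=True)
except ImportError:
    # Fallback so the sketch still runs without numba (plain Python, slower)
    def njit(func):
        return func

@njit
def first_zero_row(A):
    """Hypothetical helper: index of the first all-zero row, or A.shape[0]."""
    for i in range(A.shape[0]):
        if (A[i] == 0).all():
            return i
    return A.shape[0]

A = np.array([[3, 1, 2],
              [0, 5, 0],
              [0, 0, 0],
              [7, 0, 0]])

print(A[:first_zero_row(A)])   # rows 0 and 1
print(first_zero_row(A[:2]))   # 2: no all-zero row, so nothing is trimmed
```

Returning an index keeps the compiled function's return type a plain integer and leaves the slicing to the caller.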
Performance benchmarking
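A minimal sketch for timing the variants yourself with the standard-library `timeit` module. It re-creates the data from the setup below and includes condensed copies of two pure-NumPy functions so that it runs on its own. The numba functions can be added to the list the same way, but call each one once beforehand so JIT compilation is not included in the timing. Relative results will depend on your machine and on where the first all-zero row sits: the generator scans only the first `k` rows, whereas `trim_vect` always processes all `n` rows.

```python
import timeit
import numpy as np

# Same data as the setup: k non-zero rows followed by all-zero rows
np.random.seed(0)
n, k = 120_000, 1_500
A = np.random.randint(1, 10, (n, 3))
A[k:, :] = 0

def trim_vect(A):
    return A[:np.where((A == 0).all(axis=1))[0][0]]

def trim_enum_gen(A):
    return A[:next(i for i, row in enumerate(A) if (row == 0).all())]

results = {}
for name, func in [("trim_vect", trim_vect), ("trim_enum_gen", trim_enum_gen)]:
    # Best of 5 repeats of 10 calls each, converted to milliseconds per call
    best = min(timeit.repeat(lambda: func(A), number=10, repeat=5)) / 10
    results[name] = best * 1e3
    print(f"{name:>15}: {results[name]:.3f} ms per call")
```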
Test code
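Setup

```python
import numpy as np
from numba import jit

np.random.seed(0)

n = 120000   # total number of rows
k = 1500     # index of the first all-zero row
A = np.random.randint(1, 10, (n, 3))  # values 1-9, so no zero rows yet
A[k:, :] = 0                          # rows k onwards are all zero
```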
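Functions

```python
def trim_enum_loop(A):
    # Plain Python loop: stop at the first all-zero row
    for idx, row in enumerate(A):
        if (row == 0).all():
            return A[:idx]
    return A  # no all-zero row: nothing to trim

@jit(nopython=True)
def trim_enum_nb(A):
    # Same loop, compiled with numba
    for idx in range(A.shape[0]):
        if (A[idx] == 0).all():
            return A[:idx]
    return A

@jit(nopython=True)
def trim_enum_nb2(A):
    # Check elements one at a time and leave a row at its first non-zero
    # value, so no temporary Boolean array is created per row
    for idx in range(A.shape[0]):
        all_zero = True
        for col in range(A.shape[1]):
            if A[idx, col] != 0:
                all_zero = False
                break
        if all_zero:
            return A[:idx]
    return A

def trim_enum_gen(A):
    # next() stops at the first match; the default keeps every row
    # if none is all-zero
    idx = next((idx for idx, row in enumerate(A) if (row == 0).all()), A.shape[0])
    return A[:idx]

def trim_vect(A):
    # One Boolean array over all rows, then the index of the first True
    # (raises IndexError if no row is all-zero)
    idx = np.where((A == 0).all(axis=1))[0][0]
    return A[:idx]

def trim_searchsorted(A):
    # View each row as one byte string; an all-zero row becomes b''.
    # Assumes the all-zero rows form a block at the end of A.
    row_bytes = A.dtype.itemsize * A.shape[1]
    B = np.frombuffer(np.ascontiguousarray(A), dtype=f'S{row_bytes}')
    idx = A.shape[0] - np.searchsorted(B[::-1], B[-1:], 'right')[0]
    return A[:idx]
```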
Checks

```python
# check all results are the same (np.array_equal also compares shapes)
assert np.array_equal(trim_vect(A), trim_enum_loop(A))
assert np.array_equal(trim_vect(A), trim_enum_nb(A))
assert np.array_equal(trim_vect(A), trim_enum_nb2(A))
assert np.array_equal(trim_vect(A), trim_enum_gen(A))
assert np.array_equal(trim_vect(A), trim_searchsorted(A))
```
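The `trim_searchsorted` variant is the least obvious of the functions above. The sketch below unpacks its byte-string trick on a small `int32` array, where each row is 3 × 4 = 12 bytes, so `'S12'` holds exactly one row. With 64-bit integers, which many platforms use by default, a row is 24 bytes instead. NumPy strips trailing NUL bytes from `S` values, so an all-zero row becomes `b''`, which sorts before every other row. The trick assumes the all-zero rows form a block at the end of the array; if they do not, the reversed array is not sorted and `np.searchsorted` returns a meaningless position.

```python
import numpy as np

A = np.array([[5, 1, 2],
              [3, 0, 7],
              [0, 0, 0],
              [0, 0, 0]], dtype=np.int32)

# One fixed-width byte string per row: 3 columns * 4 bytes = 'S12'
row_bytes = A.dtype.itemsize * A.shape[1]
B = np.frombuffer(A, dtype=f"S{row_bytes}")
print(B[2], B[3])   # b'' b''  (trailing NUL bytes are stripped)

# Reversed, the array starts with the run of b'' entries; searching for b''
# with side='right' returns the length of that run
n_zero_tail = np.searchsorted(B[::-1], B[-1:], side="right")[0]
print(n_zero_tail)                      # 2
print(A[:A.shape[0] - n_zero_tail])     # the two non-zero rows
```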