我会提名
a.argmax()
用@ fuglede的测试数组:
In [1]: a = np.array([np.nan if i % 10000 == 9999 else 3 for i in range(100000)])
In [2]: np.isnan(a).argmax()
Out[2]: 9999
In [3]: np.argmax(a)
Out[3]: 9999
In [4]: a.argmax()
Out[4]: 9999
In [5]: timeit a.argmax()
The slowest run took 29.94 ....
10000 loops, best of 3: 20.3 µs per loop
In [6]: timeit np.isnan(a).argmax()
The slowest run took 7.82 ...
1000 loops, best of 3: 462 µs per loop
我没有安装numba,所以可以比较.但是我的加速相对于short比@ fuglede的6倍大.
我在Py3中测试,它接受< np.nan,而Py2引发运行时警告.但代码搜索表明这不是依赖于该比较. /numpy/core/src/multiarray/calculation.c PyArray_ArgMax与轴一起播放(将感兴趣的一个移动到最后),并将该操作委派给arg_func = PyArray_DESCR(ap) – > f-> argmax,一个函数取决于dtype.
在numpy / core / src / multiarray / arraytypes.c.src中,它看起来像BOOL_argmax短路,一遇到True就返回.
for (; i < n; i++) {
if (ip[i]) {
*max_ind = i;
return 0;
}
}
而@ fname @ _argmax也是最大的nan的短路. argn中np.nan是’maximal’.
#if @isfloat@
if (@isnan@(mp)) {
/* nan encountered; it's maximal */
return 0;
}
#endif
欢迎有经验的c编者的评论,但是在我看来,至少对于np.nan来说,一个简单的argmax将会很快,我们可以得到.
使用9999生成一个表示a.argmax时间取决于该值,与短路一致.