给定ndarray大小(n,3),n大约1000,如何将每行的所有元素相乘,快?下面的(不优雅的)第二种解决方案在大约0.3毫秒内运行,可以改进吗?
# dummy data
n = 999
a = np.random.uniform(low=0, high=10, size=n).reshape(n/3,3)
# two solutions
def prod1(array):
return [np.prod(row) for row in array]
def prod2(array):
return [row[0]*row[1]*row[2] for row in array]
# benchmark
start = time.time()
prod1(a)
print time.time() - start
# 0.0015
start = time.time()
prod2(a)
print time.time() - start
# 0.0003
解决方法:
进一步提高绩效
一般来说,一般的经验法则.您正在使用数值数组,因此使用数组而不是列表.列表可能看起来有点像一般数组,但在后端完全不同,绝对不能用于大多数数值计算.
如果你使用Numpy-Arrays编写一个简单的代码,你可以通过简单地进行匹配来获得性能,如图所示.如果使用列表,则可以或多或少地重写代码.
import numpy as np
import numba as nb
@nb.njit(fastmath=True)
def prod(array):
assert array.shape[1]==3 #Enable SIMD-Vectorization (adding some performance)
res=np.empty(array.shape[0],dtype=array.dtype)
for i in range(array.shape[0]):
res[i]=array[i,0]*array[i,1]*array[i,2]
return res
使用np.prod(a,axis = 1)并不是一个坏主意,但性能并不是很好.对于只有1000×3的数组,函数调用开销非常重要.当在另一个jitted函数中使用jitted prod函数时,可以完全避免这种情况.
基准
# The first call to the jitted function takes about 200ms compilation overhead.
#If you use @nb.njit(fastmath=True,cache=True) you can cache the compilation result for every successive call.
n=999
prod1 = 795 µs
prod2 = 187 µs
np.prod = 7.42 µs
prod 0.85 µs
n=9990
prod1 = 7863 µs
prod2 = 1810 µs
np.prod = 50.5 µs
prod 2.96 µs
标签:python,arrays,performance,numpy,multidimensional-array
来源: https://codeday.me/bug/20191002/1843997.html