python序列的元素可以相乘吗_python – 有效地将每行的元素相乘

给定ndarray大小(n,3),n大约1000,如何将每行的所有元素相乘,快?下面的(不优雅的)第二种解决方案在大约0.3毫秒内运行,可以改进吗?

# dummy data

n = 999

a = np.random.uniform(low=0, high=10, size=n).reshape(n/3,3)

# two solutions

def prod1(array):

return [np.prod(row) for row in array]

def prod2(array):

return [row[0]*row[1]*row[2] for row in array]

# benchmark

start = time.time()

prod1(a)

print time.time() - start

# 0.0015

start = time.time()

prod2(a)

print time.time() - start

# 0.0003

解决方法:

进一步提高绩效

一般来说,一般的经验法则.您正在使用数值数组,因此使用数组而不是列表.列表可能看起来有点像一般数组,但在后端完全不同,绝对不能用于大多数数值计算.

如果你使用Numpy-Arrays编写一个简单的代码,你可以通过简单地进行匹配来获得性能,如图所示.如果使用列表,则可以或多或少地重写代码.

import numpy as np

import numba as nb

@nb.njit(fastmath=True)

def prod(array):

assert array.shape[1]==3 #Enable SIMD-Vectorization (adding some performance)

res=np.empty(array.shape[0],dtype=array.dtype)

for i in range(array.shape[0]):

res[i]=array[i,0]*array[i,1]*array[i,2]

return res

使用np.prod(a,axis = 1)并不是一个坏主意,但性能并不是很好.对于只有1000×3的数组,函数调用开销非常重要.当在另一个jitted函数中使用jitted prod函数时,可以完全避免这种情况.

基准

# The first call to the jitted function takes about 200ms compilation overhead.

#If you use @nb.njit(fastmath=True,cache=True) you can cache the compilation result for every successive call.

n=999

prod1 = 795 µs

prod2 = 187 µs

np.prod = 7.42 µs

prod 0.85 µs

n=9990

prod1 = 7863 µs

prod2 = 1810 µs

np.prod = 50.5 µs

prod 2.96 µs

标签:python,arrays,performance,numpy,multidimensional-array

来源: https://codeday.me/bug/20191002/1843997.html

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值