python numba_Numba代码比纯python慢

我一直在加速粒子过滤器的重采样计算。由于python有很多加速的方法,我想我会尝试所有的方法。不幸的是,numba版本非常慢。由于Numba会导致加速,我认为这是我的一个错误。

我尝试了4种不同的版本:麻木

Python

纽比

赛松

其代码如下:import numpy as np

import scipy as sp

import numba as nb

from cython_resample import cython_resample

@nb.autojit

def numba_resample(qs, xs, rands):

n = qs.shape[0]

lookup = np.cumsum(qs)

results = np.empty(n)

for j in range(n):

for i in range(n):

if rands[j] < lookup[i]:

results[j] = xs[i]

break

return results

def python_resample(qs, xs, rands):

n = qs.shape[0]

lookup = np.cumsum(qs)

results = np.empty(n)

for j in range(n):

for i in range(n):

if rands[j] < lookup[i]:

results[j] = xs[i]

break

return results

def numpy_resample(qs, xs, rands):

results = np.empty_like(qs)

lookup = sp.cumsum(qs)

for j, key in enumerate(rands):

i = sp.argmax(lookup>key)

results[j] = xs[i]

return results

#The following is the code for the cython module. It was compiled in a

#separate file, but is included here to aid in the question.

"""

import numpy as np

cimport numpy as np

cimport cython

DTYPE = np.float64

ctypedef np.float64_t DTYPE_t

@cython.boundscheck(False)

def cython_resample(np.ndarray[DTYPE_t, ndim=1] qs,

np.ndarray[DTYPE_t, ndim=1] xs,

np.ndarray[DTYPE_t, ndim=1] rands):

if qs.shape[0] != xs.shape[0] or qs.shape[0] != rands.shape[0]:

raise ValueError("Arrays must have same shape")

assert qs.dtype == xs.dtype == rands.dtype == DTYPE

cdef unsigned int n = qs.shape[0]

cdef unsigned int i, j

cdef np.ndarray[DTYPE_t, ndim=1] lookup = np.cumsum(qs)

cdef np.ndarray[DTYPE_t, ndim=1] results = np.zeros(n, dtype=DTYPE)

for j in range(n):

for i in range(n):

if rands[j] < lookup[i]:

results[j] = xs[i]

break

return results

"""

if __name__ == '__main__':

n = 100

xs = np.arange(n, dtype=np.float64)

qs = np.array([1.0/n,]*n)

rands = np.random.rand(n)

print "Timing Numba Function:"

%timeit numba_resample(qs, xs, rands)

print "Timing Python Function:"

%timeit python_resample(qs, xs, rands)

print "Timing Numpy Function:"

%timeit numpy_resample(qs, xs, rands)

print "Timing Cython Function:"

%timeit cython_resample(qs, xs, rands)

这将产生以下输出:Timing Numba Function:

1 loops, best of 3: 8.23 ms per loop

Timing Python Function:

100 loops, best of 3: 2.48 ms per loop

Timing Numpy Function:

1000 loops, best of 3: 793 µs per loop

Timing Cython Function:

10000 loops, best of 3: 25 µs per loop

知道为什么numba代码这么慢吗?我认为它至少可以与纽比相提并论。

注意:如果有人对如何加快Numpy或Cython代码示例有任何想法,那也不错:)不过,我的主要问题是Numba。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值