Python的加速模块numba

关于numba的介绍有很多,就是一个可以把大量重复代码即时编译为机器码来加快程序运行速度的库。优点是快,方便,但缺陷也很明显,如很多类型不兼容,使用时不太灵活,必须把f方法内包含的所有方法加上装饰器,在数值计算量小时反而会减慢速度等等...

import librosa
from numba import njit
import time

path = "/Users/birenjianmo/Desktop/learn/librosa/input/1你好.wav"
y, sr = librosa.load(path)
yy = y.tolist()

def logtime(f):
    def setf(*args, **kw):
        number = 100
        start = time.time()
        for i in range(number):
            r = f(*args, **kw)
        end = time.time()
        print( "%s 耗时:%s s" %  (f.__name__, end - start/number) )
        return r
    return setf

 
def gettopdata(y):
    data = []
    for i in range(1,len(y)-1):
        if y[i]>0 and y[i-1]<y[i]<y[i+1]:
            data.append(i)
    return data
 

@logtime
@njit(fastmath=True)
def findall_numba(a):
    result = []
    data = []
    for i in range(1,len(a)):
        data.append( a[i]-a[i-1] )
    
    for pos in range(5, len(data)//3):
        for i in range(len(data)-pos):
            if abs(data[i]  - data[i+pos]) < abs(data[i]) * 0.05:
                getp = True
                p = [a[i]]
                r = 1
    
                while r < (len(data)-i)//pos:
                    
                    for j in range(pos):
                        if abs(data[i+j] - data[i+pos*r+j]) > abs(data[i+j]) * 0.05:
                            getp = False
                            break
    
                    if not getp:
                        break
                    else:
                        p.append(a[i+pos*r])
                        r += 1
    
                if len(p) > 4:
                    result.append(p)
 
    return result

@logtime
def findall(a):
    result = []
    data = []
    for i in range(1,len(a)):
        data.append( a[i]-a[i-1] )
    
    for pos in range(5, len(data)//3):
        for i in range(len(data)-pos):
            if abs(data[i]  - data[i+pos]) < abs(data[i]) * 0.05:
                getp = True
                p = [a[i]]
                r = 1
    
                while r < (len(data)-i)//pos:
                    
                    for j in range(pos):
                        if abs(data[i+j] - data[i+pos*r+j]) > abs(data[i+j]) * 0.05:
                            getp = False
                            break
    
                    if not getp:
                        break
                    else:
                        p.append(a[i+pos*r])
                        r += 1
    
                if len(p) > 4:
                    result.append(p)
 
    return result


if __name__ == '__main__':
    a = gettopdata(yy)
    findall(a)
    findall_numba(a)

运行100次的平均时间 ,使用numba可以提升差不多7倍的速度

findall 耗时:1.180213589668274 s
findall_numba 耗时:0.17544831991195678 s

 

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值