python: numba 加速python的 for loop

一、背景

写了个算法 需要多次计算两个字符串的编辑距离。这部分耗时比较多,想要进行优化。编辑距离代码见最后的附录。编辑距离本质上是双层for loop,而且是有依赖关系的for loop,因此无法并行。然后发现 numba ,介绍如下:
Numba is a just-in-time compiler for Python that works best on code that uses NumPy arrays and functions, and loops. The most common way to use Numba is through its collection of decorators that can be applied to your functions to instruct Numba to compile them. When a call is made to a Numba decorated function it is compiled to machine code “just-in-time” for execution and all or part of your code can subsequently run at native machine code speed!

本文主要介绍 nopython模式和 parallel模式

二、用法及其说明

2.1 @numba.njit()

用法非常简单,在函数上面使用装饰器

import numba

@numba.jit(nopython=True) # 等价于 @numba.njit()
def editing_distance(word1: str, word2: str):
	'''
	
	'''
	pass
	

nopython=True 意味着numba对装饰器 装饰的函数进行编译成机器码,完全不使用python 解释器。 注意

  1. 使用nopython=True 意味着着最佳的性能
  2. 如果nopython=True 失败,可以切换另外一个模式object 。这个模式会把可以把可编译为机器码的for loop 编译,无法成功编译的则使用python 解释器运行。这个加速效果就会比较差,建议如果发现nopython=True 的跑不通,就尽量去改变代码,让代码变得更贴近pure python,进而让代码跑通。
  3. 使用numba加速的函数,第一次被调用的时候会进行初次编译,这个时间会比较久。计算耗时的时候不应该被记入。第一次之后的执行都会感受到numba的加速
  4. numba 不会对所有的for loop都有明显的加速效果。具体的对于什么for loop有比较好的加速效果需要看代码。一般来说对于 pure python的数据结构和numpy 都会有比较好的加速效果

2.2 @numba.njit(parallel=True)

对并行的代码进行加速。
level 1 : 原始

def ident_parallel(x):
    return np.cos(x) ** 2 + np.sin(x) ** 2

level 2 : @numba.njit() 把python编译成机器码,加速 for loop

@numba.njit()
def ident_parallel(x):
    return np.cos(x) ** 2 + np.sin(x) ** 2

level 3 : @numba.njit(parallel=True) 把python编译成机器码加速for loop ,且同时利用并行进行优化

@numba.njit(parallel=True)
def ident_parallel(x):
    return np.cos(x) ** 2 + np.sin(x) ** 2

测试函数如下

if __name__=='__main__':
    a = np.zeros((20000, 20000))
    a_time = time.time()
    ident_parallel(a)
    b_time = time.time()
    print(f' consuming time is {b_time-a_time}')

耗时统计如下

level 14.3 s
level 20.9 s
level 30.29 s

可以依次看到 njit 的速度提升以及 njit+parallel 的速度提升

三、numba为什么快

https://numba.readthedocs.io/en/stable/user/5minguide.html#how-does-numba-work

Numba reads the Python bytecode for a decorated function and combines this with information about the types of the input arguments to the function. It analyzes and optimizes your code, and finally uses the LLVM compiler library to generate a machine code version of your function, tailored to your CPU capabilities. This compiled version is then used every time your function is called.

提前获取Python的输入参数类型,把Python的字节码转化成机器码,转化的过程有针对性的优化。每次使用这个函数的时候 直接使用的是机器码,而无需重新编译

附录

编辑距离

def editing_distance(word1: str, word2: str):
    '''
        两个字符串的编辑距离. 
    '''
    if len(word1)==0:
        return len(word2)

    if len(word2)==0:
        return len(word1)

    size1 = len(word1)
    size2 = len(word2)

    last = 0
    tmp = list(range(size2 + 1))
    value = None

    for i in range(size1):
        tmp[0] = i + 1

        last = i

        for j in range(size2):
            if word1[i] == word2[j]:
                value = last
            else:
                value = 1 + min(last, tmp[j], tmp[j + 1])

            last = tmp[j+1]
            tmp[j+1] = value

    return value
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值