一、背景
写了个算法 需要多次计算两个字符串的编辑距离。这部分耗时比较多,想要进行优化。编辑距离代码见最后的附录。编辑距离本质上是双层for loop,而且是有依赖关系的for loop,因此无法并行。然后发现 numba ,介绍如下:
Numba is a just-in-time compiler for Python that works best on code that uses NumPy arrays and functions, and loops. The most common way to use Numba is through its collection of decorators that can be applied to your functions to instruct Numba to compile them. When a call is made to a Numba decorated function it is compiled to machine code “just-in-time” for execution and all or part of your code can subsequently run at native machine code speed!
本文主要介绍 nopython
模式和 parallel
模式
二、用法及其说明
2.1 @numba.njit()
用法非常简单,在函数上面使用装饰器
import numba
@numba.jit(nopython=True) # 等价于 @numba.njit()
def editing_distance(word1: str, word2: str):
'''
'''
pass
nopython=True
意味着numba对装饰器 装饰的函数进行编译成机器码,完全不使用python 解释器。 注意
- 使用
nopython=True
意味着着最佳的性能 - 如果
nopython=True
失败,可以切换另外一个模式object
。这个模式会把可以把可编译为机器码的for loop 编译,无法成功编译的则使用python 解释器运行。这个加速效果就会比较差,建议如果发现nopython=True
的跑不通,就尽量去改变代码,让代码变得更贴近pure python,进而让代码跑通。 - 使用numba加速的函数,第一次被调用的时候会进行初次编译,这个时间会比较久。计算耗时的时候不应该被记入。第一次之后的执行都会感受到numba的加速
- numba 不会对所有的for loop都有明显的加速效果。具体的对于什么for loop有比较好的加速效果需要看代码。一般来说对于 pure python的数据结构和numpy 都会有比较好的加速效果
2.2 @numba.njit(parallel=True)
对并行的代码进行加速。
level 1 : 原始
def ident_parallel(x):
return np.cos(x) ** 2 + np.sin(x) ** 2
level 2 : @numba.njit()
把python编译成机器码,加速 for loop
@numba.njit()
def ident_parallel(x):
return np.cos(x) ** 2 + np.sin(x) ** 2
level 3 : @numba.njit(parallel=True)
把python编译成机器码加速for loop ,且同时利用并行进行优化
@numba.njit(parallel=True)
def ident_parallel(x):
return np.cos(x) ** 2 + np.sin(x) ** 2
测试函数如下
if __name__=='__main__':
a = np.zeros((20000, 20000))
a_time = time.time()
ident_parallel(a)
b_time = time.time()
print(f' consuming time is {b_time-a_time}')
耗时统计如下
level 1 约 4.3 s
level 2 约 0.9 s
level 3 约 0.29 s
可以依次看到 njit
的速度提升以及 njit+parallel
的速度提升
三、numba为什么快
https://numba.readthedocs.io/en/stable/user/5minguide.html#how-does-numba-work
Numba reads the Python bytecode for a decorated function and combines this with information about the types of the input arguments to the function. It analyzes and optimizes your code, and finally uses the LLVM compiler library to generate a machine code version of your function, tailored to your CPU capabilities. This compiled version is then used every time your function is called.
提前获取Python的输入参数类型,把Python的字节码转化成机器码,转化的过程有针对性的优化。每次使用这个函数的时候 直接使用的是机器码,而无需重新编译
附录
编辑距离
def editing_distance(word1: str, word2: str):
'''
两个字符串的编辑距离.
'''
if len(word1)==0:
return len(word2)
if len(word2)==0:
return len(word1)
size1 = len(word1)
size2 = len(word2)
last = 0
tmp = list(range(size2 + 1))
value = None
for i in range(size1):
tmp[0] = i + 1
last = i
for j in range(size2):
if word1[i] == word2[j]:
value = last
else:
value = 1 + min(last, tmp[j], tmp[j + 1])
last = tmp[j+1]
tmp[j+1] = value
return value