前言
说道现在最流行的语言,就不得不提python。可是python虽然容易上手,但速度却有点感人。如何用简单的方法让python加速到近乎可以媲美C的速度呢?今天来就来谈谈numba这个宝贝。对你没看错,不是numpy,就是numba。
目录
1、用函数编程
2、Numba的优势
3、如何使用numba
4、只用1行代码即可加速,对loop有奇效
5、兼容常用的科学计算包,可以创建ufunc
6、会自动调整精度,保证准确性
7、拓展
8、更多numba的加速选项
9、Numba的精度问题
10、附录
用函数编程
在面对一个计算project的时候,我们最容易想到的就是直接码代码,最后写出一个超长的程序。这样一来,一旦出错往往就需要花很多时间定位问题。
有一个简单的办法解决这个问题,就是定义各种各样的函数,把任务分解成很多小部分。因为每个函数都不是特别复杂,并且在写好的时候就可以随时检查,因此简洁的主程序一旦出问题就很容易定位并解决。面向对象编程的思想就是基于函数。
写好函数之后,还可以使用装饰器(decorator)让它变得强大。装饰器本身是一个函数,不过是函数的函数,目的是增加函数的功能。比如首先定义一个输出当前时间的函数,再定义一个规定时间格式的函数,把后一个函数作用在前一个函数上,就是一个装饰器,作用是用特定格式输出当前时间。
Numba的优势
1.简单,往往只要1行代码就有惊喜;
2.对循环(loop)有奇效,而往往在科学计算中限制python速度的就是loop;
3.兼容常用的科学计算包,如numpy、cmath等;
4.可以创建ufunc;
5.会自动调整精度,保证准确性。
如何使用numba
针对上面提到的numba的优势,我来进行逐一介绍。
首先导入numba
import numba as nb
只用1行代码即可加速,对loop有奇效
因为numba内置的函数本身是个装饰器,所以只要在自己定义好的函数前面加个@nb.jit()就行,简单上手。下面以一个求和函数为例
# 用numba加速的求和函数
@nb.jit()
def nb_sum(a):
Sum = 0
for i in range(len(a)):
Sum += a[i]
return Sum
# 没用numba加速的求和函数
def py_sum(a):
Sum = 0
for i in range(len(a)):
Sum += a[i]
return Sum
来测试一下速度
import numpy as np
a = np.linspace(0,100,100) # 创建一个长度为100的数组
%timeit np.sum(a) # numpy自带的求和函数
%timeit sum(a) # python自带的求和函数
%timeit nb_sum(a) # numba加速的求和函数
%timeit py_sum(a) # 没加速的求和函数
结果如下
# np.sum(a)
7.1 µs ± 537 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
# sum(a)
27.7 µs ± 2.64 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
# nb_sum(a)
1.05 µs ± 27.6 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
# py_sum(a)
43.7 µs ± 1.71 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
可以看出,numba甚至比号称最接近C语言速度运行的numpy还要快6倍以上。但大家都知道,numpy往往对大的数组更加友好,那我们来测试一个更长的数组
a = np.linspace(0,100,10**6) # 创建一个长度为100万的数组
测试结果如下
# np.sum(a)
2.51 ms ± 246 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
# sum(a)
249 ms ± 19.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# nb_sum(a)
3.01 ms ± 59.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
# py_sum(a)
592 ms ± 42 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
可见即便是用很长的loop来计算,numba的表现也丝毫不亚于numpy。在这里,我们可以看到numba相对于numpy一个非常明显的优势:numba可以把各种具有很大loop的函数加到很快的速度,但numpy的加速只适用于numpy自带的函数。
但要注意的是,numba对没有循环或者只有非常小循环的函数加速效果并不明显,用不用都一样。(偷偷告诉你,numba的loop甚至常常比numpy的矩阵运算还要快)
兼容常用的科学计算包,可以创建ufunc
上一部分我们比较了numba和numpy的表现,可以说numba非常亮眼了。但numpy还有一个非常强大的功能——ufunc (universal functions),它可以让一个函数同时处理很多数据。比如要求一个数组每一个元素的三角函数,只需要
np.sin(a) # 这里的a仍然是上面有100万个元素的数组
而不需要写个循环一个一个求。可如果不用numpy但又想要很快的速度,那应该怎么求呢?我们可以用从math库里导入sin,然后写个loop再用numba加速。除了这个方法,在这里我还想说numba另一个强大的功能,矢量化(vectorize)。像上面的jit一样,只要添加一行vectorize就可以让普通函数变成ufunc
from math import sin
@nb.vectorize()
def nb_vec_sin(a):
return sin(a)
来比较一下用各种方式写出的三角函数
# 用numba加速的loop
13.5 ms ± 405 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
# nb_vec_sin(a)
14.2 ms ± 55.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
# np.sin(a)
5.75 ms ± 85 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
可以看出,用vectorize改写的三角函数具有了和np.sin()一样的同时处理数组每个元素的功能,而且表现也不必numpy差。当遇到numpy没有的数学函数时(比如sech),用numba矢量化不失为一个好的选择。除了math中的sin,它支持的其他函数列表可以在documentation中找到(链接见附录)。
其实numpy也有矢量化的功能,只不过比numba差远了。
会自动调整精度,保证准确性
上面我们用的测试数组数字范围只是从0到100的。可如果数字很大,那么就很容易出现overflow的问题,比如
a = np.arange(106) # a的最小值为0,最大值为106-1
你猜猜用python自带的sum,我们自己写的py_sum,np.sum和nb_sum给出的结果一不一样呢?你会发现
# np.sum(a)
1783293664
# sum(a)
1783293664
# nb_sum(a)
499999500000
# py_sum(a)
1783293664
numba的结果和其他三个都不一样,肯定错了呀,还用问么?
且慢!其实在运行的时候,我并没有告诉你sum和py_sum都报错了“RuntimeWarning: overflow encountered in long_scalars”。但奇怪的是np.sum并没有报错。
在上面的四个函数里,其实numba表现的最好,因为它自动调整了整数类型。如果你用nb.typeof()查看,你会发现numba给出的结果是int64,而其他三个都是int32。不得不说,numba不仅快还在精度方面表现很好!
拓展
在本文的最后一部分,我想谈两个问题。
更多numba的加速选项
除了上面提到的jit和vectorize,其实numba还支持很多加速类型。常见的比如
@nb.jit(nopython=True,fastmath=True) 牺牲一丢丢数学精度来提高速度
@nb.jit(nopython=True,parallel=True) 自动进行并行计算
切记一定要用nopython。默认都是True的,但有时候如果定义的函数中遇到numba支持不良好的部分,它就会自动关闭nopython模式。没有nopython的numba就好像没有武器的士兵,虽然好过没兵,但确实没什么战斗力。因此,在使用jit时候要明确写出nopython=True。如果遇到问题,就找到这些支持不良好的部分,然后改写。毕竟numba对loop非常友好,改写这些部分应当是非常容易的。
其实如何选择这些模式会对函数有最佳的加速效果,是一个玄学。我前段时间向一位精通numba的prof请教,他给我的建议是,多试试就知道有没有用了。。。另外,numba还支持多个用GPU加速的包,比如CUDA。
Numba的精度问题
精度方面,在上面我也谈到numba会自动转换数据类型以适应计算。但是在个别时候,这种自动转变类型可能会引起一些计算误差。通常这个误差是非常小的,几乎不会造成任何影响。但如果你所处理的问题会积累误差,比如求解非线性方程,那么在非常多的计算之后误差可能就是肉眼可见了。如果你发现有这样的问题,记得在jit中指定输入输出的数据类型。numba具有C所有的数据类型,比如对上面的求和函数,只需要把@nb.jit()改为@nb.jit(nb.int64(nb.int32[:]))即可。nb.int64是说输出的数字为int64类型,nb.int32是说输入的数据类型为int32,而[:]是说输入的是数组。