在使用Numba JIT编译器加速代码计算时遇到一个大坑!
具体demo如下:
不使用JIT加速时:
输入:
from numba import jit
def fun1(**worda):
print(worda)
fun2(**worda)
def fun2(**wordb):
print(wordb)
fun3(**wordb)
def fun3(**wordc):
print(wordc)
if __name__=="__main__":
word = {'a':1, 'b':2, 'c':3}
fun1(**word)
输出:
>>>
{'a': 1, 'b': 2, 'c': 3}
{'a': 1, 'b': 2, 'c': 3}
{'a': 1, 'b': 2, 'c': 3}
使用JIT加速时:
输入:
from numba import jit
def fun1(**worda):
print(worda)
fun2(**worda)
def fun2(**wordb):
print(wordb)
fun3(**wordb)
@jit
def fun3(**wordc):
print(wordc)
if __name__=="__main__":
word = {'a':1, 'b':2, 'c':3}
fun1(**word)
报错:
>>>
Traceback (most recent call last):
File "xx.py", line 18, in <module>
fun1(**word)
File "xx.py", line 5, in fun1
fun2(**worda)
File "xx.py", line 9, in fun2
fun3(**wordb)
TypeError: too many arguments: expected 1, got 3
错误原因:
找了一万年也没想到时JIT加速器出错,最后把具体问题发给ChatGPT,以下是他给出的原因和解决方法:
Numba不支持在JIT编译函数中使用可变关键字参数,因此需要通过其他方式传递关键字参数。可以将关键字参数打包为字典,然后将字典作为参数传递。
以下是修改后的代码:
from numba import jit def fun1(worda): print(worda) fun2(worda) def fun2(wordb): print(wordb) fun3(wordb) @jit def fun3(wordc): print(wordc) if __name__=="__main__": word = {'a':1, 'b':2, 'c':3} fun1(word)