numba加速python
使用numba库加速python的for循环,时间确实快很多
安装
pip install numba
需要注意numba的版本与numpy进行匹配
numba原理
网上有很多解释的博文,不在详细介绍 ,比如链接
应用和对比
主要任务是有两组bboxs,进行匹配,计算iou阈值,如果直接用for循环,需要写两个for循环,时间复杂度还是很大的
因此尝试使用numba对循环进行加速,具体代码如下:
from numba import jit
import numpy as np
from scipy.spatial import KDTree
import time
from tqdm import tqdm
# 计算iou
@jit(nopython=True)
def bbox_iou(box1, box2):
x1, y1, x2, y2 = box1
x1_, y1_, x2_, y2_ = box2
inter_x1 = max(x1, x1_)
inter_y1 = max(y1, y1_)
inter_x2 = min(x2, x2_)
inter_y2 = min(y2, y2_)
inter_w = max(0, inter_x2 - inter_x1)
inter_h = max(0, inter_y2 - inter_y1)
union = (x2 - x1) * (y2 - y1) + (x2_ - x1_) * (y2_ - y1_) - inter_w * inter_h
iou = inter_w * inter_h / union
return iou
# 定义两组bbox
bbox1 = np.random.rand(10000, 4)
bbox2 = np.random.rand(5000, 4)
# 利用for循环计算匹配
start_time = time.time()
matches = []
for i in tqdm(range(bbox1.shape[0])):
max_iou = 0
match_idx = -1
for j in range(bbox2.shape[0]):
iou = bbox_iou(bbox1[i], bbox2[j])
if iou > max_iou:
max_iou = iou
match_idx = j
matches.append(match_idx)
print("Matching with for loop takes: %.3f s" % (time.time() - start_time))
@jit(nopython=True) #
def match_loop(bbox1,bbox2):
matches = []
for i in range(bbox1.shape[0]):
max_iou = 0
match_idx = -1
for j in range(bbox2.shape[0]):
iou = bbox_iou(bbox1[i], bbox2[j])
if iou > max_iou:
max_iou = iou
match_idx = j
matches.append(match_idx)
start_time = time.time()
match_loop(bbox1,bbox2)
print("Matching with for loop takes: %.3f s" % (time.time() - start_time))
需要注意的是: numba加速的函数,即@jit()修饰的函数内,如果调用了其他函数,那么这个函数也应该被@jit()修饰,比如bbox_iou()这个函数,也应该被@jit()修饰。原来的tqdm()就不能出现在函数内部了,不然会报错
对比结果:
Matching with for loop takes: 312.427 s
Matching with for loop takes: 1.537 s
第一个是for循环计算的结果,第二个是numba加速后的结果,确实加速了很多,很不错