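Don't try to vectorize loops that are not simple to vectorize. Instead, use a JIT compiler such as Numba, or use Cython. Vectorized solutions are good when the resulting code is more readable, but in terms of performance a compiled solution is usually faster or, in the worst case, as fast as a vectorized one (BLAS routines excepted).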
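Single-threaded example

import numba as nb
import numpy as np

# Library calls to min and max may be costly for only 3 values,
# so the comparisons are written out by hand.
@nb.njit()
def max_min_3(A, B, C):
    max_of_min = -np.inf
    for i in range(A.shape[0]):
        # element-wise minimum of the three vectors at position i
        loc_min = A[i]
        if (B[i] < loc_min):
            loc_min = B[i]
        if (C[i] < loc_min):
            loc_min = C[i]

        # keep the largest of these minima
        if (max_of_min < loc_min):
            max_of_min = loc_min

    return max_of_min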

@nb.njit()
def your_func(A):
    n = A.shape[0]
    save_rows = np.zeros(3, dtype=np.uint64)
    global_best = np.inf

    for i in range(n):
        for j in range(i+1, n):
            for k in range(j+1, n):
                # find the maximum of the element-wise minimum of the three vectors
                local_best = max_min_3(A[i,:], A[j,:], A[k,:])
                # if local_best is lower than global_best, update global_best
                if (local_best < global_best):
                    global_best = local_best
                    save_rows[0] = i
                    save_rows[1] = j
                    save_rows[2] = k

    return global_best, save_rows
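
To check that a compiled version returns the right answer, it can help to compare it against a slow but obvious reference on a small input. The following is only a sketch under that assumption: `reference_best_triple` is a hypothetical helper name, not part of the original answer, and it uses plain NumPy and `itertools.combinations` rather than Numba.

```python
from itertools import combinations

import numpy as np


def reference_best_triple(A):
    """Brute-force reference (hypothetical helper, for validation only).

    Returns the smallest value of max(element-wise min) over all
    triples of rows i < j < k, together with that triple.
    """
    best_value = np.inf
    best_rows = (-1, -1, -1)
    # combinations() yields triples in the same order as the nested loops
    for i, j, k in combinations(range(A.shape[0]), 3):
        # element-wise minimum of the three rows, then its maximum
        value = np.minimum(np.minimum(A[i], A[j]), A[k]).max()
        if value < best_value:
            best_value = value
            best_rows = (i, j, k)
    return best_value, best_rows
```

Because the triples are visited in the same order and the comparison is strict, ties should resolve to the same rows as in the loop-based function, so the two results can be compared directly on a small array.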

Performance of the single-threaded version
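The first call has a constant overhead of about 0.3-1s, because Numba compiles the function on first use. To measure the computation time itself, call the function once and then measure the performance.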
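
A minimal sketch of that measuring procedure, assuming a hypothetical helper named `time_after_warmup` (it is not part of Numba or of the original answer) and using only the standard library:

```python
import time


def time_after_warmup(func, *args, repeats=5):
    """Run func once as a warm-up, then return the best of `repeats` timings in seconds."""
    func(*args)  # warm-up call: any one-time compilation cost is paid here
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        func(*args)
        timings.append(time.perf_counter() - start)
    return min(timings)
```

With the functions from this answer it would be called as `time_after_warmup(your_func, A)`. Taking the minimum over several runs is one common choice for reducing noise; a mean or median would work as well.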
With a few code changes this task can also be parallelized.

Multi-threaded example

@nb.njit(parallel=True)
def your_func(A):
    n = A.shape[0]
    all_global_best = np.inf
    rows = np.empty((3), dtype=np.uint64)

    # per-j results, so that parallel iterations do not write to shared scalars
    save_rows = np.empty((n,3), dtype=np.uint64)
    global_best_Temp = np.empty((n), dtype=A.dtype)
    global_best_Temp[:] = np.inf

    for i in range(n):
        for j in nb.prange(i+1, n):
            row_1 = 0
            row_2 = 0
            row_3 = 0
            global_best = np.inf
            for k in range(j+1, n):
                # find the maximum of the element-wise minimum of the three vectors
                local_best = max_min_3(A[i,:], A[j,:], A[k,:])
                # if local_best is lower than global_best, update global_best
                if (local_best < global_best):
                    global_best = local_best
                    row_1 = i
                    row_2 = j
                    row_3 = k

            save_rows[j,0] = row_1
            save_rows[j,1] = row_2
            save_rows[j,2] = row_3
            global_best_Temp[j] = global_best

        # reduce the per-j results after the parallel loop has finished
        ind = np.argmin(global_best_Temp)
        if (global_best_Temp[ind] < all_global_best):
            rows[0] = save_rows[ind,0]
            rows[1] = save_rows[ind,1]
            rows[2] = save_rows[ind,2]
            all_global_best = global_best_Temp[ind]

    return all_global_best, rows

Performance of the multi-threaded version

n=100
your_version: 1.56s
compiled_version: 0.0078s (200x speedup)
n=150
your_version: 5.41s
compiled_version: 0.0282s (191x speedup)
n=500
your_version: 283s
compiled_version: 2.95s (96x speedup)

Edit
In a newer Numba version (installed through the Anaconda Python distribution) I have to manually install ^{} to get working parallelization.