网格搜索是一种对多个参数组合遍历进行寻优的方法。通过多重for循环可以进行网格。但是当参数的数量增加的时候需要增加for的层数,不利于程序的扩展。
如有两个参数时
# L和M是存放参数的list
for i L:
for j in M:
当有三个参数时
# L和M是存放参数的list
for i L:
for j in M:
for k in P:
通过先对参数进行全排列,然后让生成的全排列转换成numpy数组再逐行遍历的方式进行网格搜索可以方便的自适应参数的数量进行网格搜索和使用numba加速。
1、使用itertools生产全排列迭代器
import itertools
itertools.product(x,y)
2、将迭代器转换为列表进而形成numpy数组
注意:numpy的formiter函数可以直接由迭代器生成数组,但是没有先转list再生成数组方便。
ls=list(itertools.product(x,y))
A_permutation=np.array(ls)
n=len(A_permutation)
3、按行遍历,并进行numba加速
使用numba加速
from numba import njit
@njit()
def grid_search(A_array,n):
for i in range(n):
a=A_array[i]
print(a)
对于大量数据可以使用并行版本,需要进行import prange
from numba import njit,prange
@njit(parallel=True)
def grid_search(A_array,n):
for i in prange(n):
a=A_array[i]
print(a)
4、整体代码
import itertools
import numpy as np
from numba import njit,prange
x=[1,2,3]
y=[4,5,6]
ls=list(itertools.product(x,y))
A_permutation=np.array(ls)
n=len(A_permutation)
@njit(parallel=True)
def grid_search(A_array,n):
for i in prange(n):
a=A_array[i]
print(a)
grid_search(A_permutation,n)