网格搜索
Grid Search
网格搜索是一种调参手段;穷举搜索:在所有候选的参数选择中,通过循环遍历,尝试每一种可能性,表现最好的参数就是最终的结果。其原理就像是在数组里找最大值。(为什么叫网格搜索?以有两个参数的模型为例,参数a有3种可能,参数b有4种可能,把所有可能性列出来,可以表示成一个3*4的表格,其中每个cell就是一个网格,循环过程就像是在每个网格里遍历、搜索,所以叫grid search)
本章我使用的还是digits数据集,调用方法接上一章。
#网格搜索,具体的格式可以上http://scikit-learn.org自行寻找。
# generates candidates
grid_param = [
{
'weights':['uniform'],
'n_neighbors':[i for i in range(1,11)]
},
{
'weights':['distance'],
'p':[i for i in range(1,6)],
'n_neighbors':[i for i in range(1,11)]
}
]
# 先new一个默认的Classifier对象
knn_clf = KNeighborsClassifier()# 调用GridSearchCV创建网格搜索对象,传入参数为Classifier对象以及参数列表
from sklearn.model_selection import GridSearchCV
grid_clf = GridSearchCV(knn_clf,grid_param)# 调用fit方法执行网格搜索
grid_clf.fit(X_train,y_train)
# 不是用户传入的参数,而是根据用户传入的参数计算出来的结果,以_结尾
# 最好的评估结果,返回的是KNeighborsClassifier对象
gri