整理一下前阶段复习的关于网格搜索的知识:
程序及数据 请到github 上 下载 GridSearch练习
网格搜索是将训练集训练的一堆模型中,选取超参数的所有值(或者代表性的几个值),将这些选取的参数及值全部列出一个表格,并分别将其进行模拟,选出最优模型。
上面是数据集的可视化分布图,具体代码如下:
%matplotlib inline
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
data = pd.read_csv('data/grid.csv',header=None)
X = data[[0,1]]
y = data[2]
#print(y)
X_blue = data[data[2]== 0]
X_red = data[data[2]== 1]
plt.scatter(X_blue[0],X_blue[1],c='blue',edgecolor='k',s=50)
plt.scatter(X_red[0],X_red[1],c='red',edgecolor='k',s=50)
plt.xlim(-2.05,2.05)
plt.ylim(-2.05,2.05)
采用决策树来训练数据
from sklearn.tree import DecisionTreeClassifier
clf = DecisionTreeClassifier(random_state=42)
clf.fit(X_train,y_train)
train_predictions = clf.predict(X_train)
test_predictions = clf.predict(X_test)
print(clf.get_params())
数据分类可视化自定义函数的代码如下:
def plot_model(X, y, clf):
plt.scatter(X_blue[0],X_blue[1],c='blue',edgecolor='k',s=50)
plt.scatter(X_red[0],X_red[1],c='red',edgecolor='k',s=50)