目录
KNN算法 KNeighborsClassifier(n_neighbors)
KNN算法 KNeighborsClassifier(n_neighbors)
1.定义:如果一个样本在特征空间中的k个最相似的样本中的大多数属于某一个类别,则该样本也属于这个样本
2.缺陷:k值过小,容易受到异常点的影响 ;k值过大,受到样品不均衡的影响
3.用法:
1.实例化一个estimator 2.estimator.fit(x_train,y_train)进行计算,生成模型 3.模型评估 a.直接比对真实值和预测值: y_predict=estimator.predict(x_test) y_test==y_predict() 2.计算准确率 : accuracy=estimator.score(x_test,y_test)
4.代码实现对鸢尾花数据集进行分类:
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import StandardScaler
iris=load_iris()
#划分数据集
x_train,x_test,y_train,y_test=train_test_split(iris.data,iris.target,random_state=6)
#特征工程:标准化
transfer=StandardScaler()
x_train=transfer.fit_transform(x_train)
x_test=transfer.transform(x_test)
#KNN算法估计
estimator=KNeighborsClassifier(n_neighbors=3)
estimator.fit(x_train,y_train) #调用完毕,模型生成
#模型评估
y_predict=estimator.predict(x_test)
print("y_predict:",y_predict)
print("直接比对真实值和预测值:\n",y_test==y_predict)
accuracy=estimator.score(x_test,y_test)
print("准确率:",accuracy)
模型选择与调优
1.交叉验证 cross validation
定义:将拿到的训练数据,分为训练和验证集,例如:将数据分为4份,其中一份作为验证集,然后进过4组调试,每次更换不同的验证集,取四组模型的结果平均值。
2.超参数搜索-网格搜索 Grid search
定义:每组超参数都采用检查验证来进行评估,最后选出最优参数组合建立模型
3.用法:
4.鸢尾花进行k值调优
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import StandardScaler
#分类并进行网格搜索与交叉验证
iris=load_iris()
#划分数据集
x_train,x_test,y_train,y_test=train_test_split(iris.data,iris.target,random_state=6)
#特征工程:标准化
transfer=StandardScaler()
x_train=transfer.fit_transform(x_train)
x_test=transfer.transform(x_test)
#KNN算法估计
estimator=KNeighborsClassifier()
#加入网格搜索与交叉验证
param_dict={"n_neighbors":[1,3,5,7,9,11]}
estimator=GridSearchCV(estimator,param_grid=param_dict,cv=10)
estimator.fit(x_train,y_train) #调用完毕,模型生成
#模型评估
y_predict=estimator.predict(x_test)
print("y_predict:",y_predict)
print("直接比对真实值和预测值:\n",y_test==y_predict)
accuracy=estimator.score(x_test,y_test)
print("准确率:\n",accuracy)
#结果分析
print("最佳参数:\n",estimator.best_params_)
print("最佳结果: \n",estimator.best_score_)
print("最佳估计器: \n",estimator.best_estimator_)
print("最佳结果: \n",estimator.cv_results_)