import numpy as np
from sklearn.neighbors import KNeighborsClassifier
from sklearn import datasets
from sklearn.model_selection import cross_val_score #交叉验证
# Demonstrates how to use cross-validation: load the iris data and score a
# default k-NN classifier with 6-fold cross-validation.
# `return_X_y` is passed by keyword because current scikit-learn releases
# declare it keyword-only and reject the positional form `load_iris(True)`.
X, y = datasets.load_iris(return_X_y=True)
knn = KNeighborsClassifier()
# `score` holds one accuracy value per fold; report their mean.
score = cross_val_score(knn, X, y, scoring='accuracy', cv=6)
print(score.mean())
# Sweep n_neighbors over 1..13 and record the 6-fold cross-validation error
# (1 - mean accuracy) for each value, then plot error against k.
k_values = range(1, 14)  # shared by the sweep and the plot's x-axis
error = []
for k in k_values:
    knn = KNeighborsClassifier(n_neighbors=k)
    score = cross_val_score(knn, X, y, scoring='accuracy', cv=6).mean()
    # The smaller the error, the better the choice of k.
    error.append(1 - score)

# Imported here because the original snippet imports matplotlib at this point.
import matplotlib.pyplot as plt

plt.plot(np.asarray(k_values), error)
plt.show()
# Compare the two k-NN weighting schemes at a fixed n_neighbors=11 by printing
# the mean 6-fold cross-validation accuracy of each.
weights = ['uniform', 'distance']
for w in weights:
    knn = KNeighborsClassifier(n_neighbors=11, weights=w)
    print(cross_val_score(knn, X, y, scoring='accuracy', cv=6).mean())
# Score every (n_neighbors, weights) combination with 6-fold cross-validation.
# Keys are the weighting name followed by k, e.g. 'uniform1', 'distance13'.
result = {}
for k in range(1, 14):
    for w in weights:
        knn = KNeighborsClassifier(n_neighbors=k, weights=w)
        sm = cross_val_score(knn, X, y, scoring='accuracy', cv=6).mean()
        result[w + str(k)] = sm
# Bare expression: displays the dict in a notebook cell; it has no effect when
# the file is run as a script.
result
# Find the most suitable parameters: print the position (in insertion order)
# of the highest mean accuracy. The matching key is list(result)[position].
print(np.array(list(result.values())).argmax())
# Part 2. Methods learned from worked examples (fragment)
import numpy as np
import pandas as pd
from pandas import Series,DataFrame
from sklearn.neighbors import KNeighborsClassifier
# Read the tab-separated data file.
cancer = pd.read_csv('./cancer.csv', sep='\t')
# axis=1 drops the 'ID' column (a column, not a row), modifying `cancer` in place.
cancer.drop('ID', axis=1, inplace=True)

# Convert str-typed columns to int codes.
# NOTE(review): `X` is not defined anywhere in this fragment; it is presumably
# a DataFrame of features prepared elsewhere -- confirm before running.
cols = ['relationship', 'race']
for col in cols:
    # Distinct values of the column; a value's position in `u` becomes its code.
    u = X[col].unique()

    def convert(x):
        """Return the index of `x` within the current column's unique values `u`."""
        return np.argwhere(u == x)[0, 0]

    X[col] = X[col].map(convert)
# Data split
# Manual 10-fold cross-validation: fit k-NN on each training split, score it on
# the held-out split, and print the accuracy averaged over all folds.
# The splitter class is `KFold`; the original called an undefined `kFold`.
from sklearn.model_selection import KFold

n_splits = 10
kFold = KFold(n_splits)
knn = KNeighborsClassifier()
# NOTE(review): `X` and `y` are not defined in this fragment; `X` is assumed to
# be a DataFrame and `y` an array-like of labels -- confirm against the caller.
# The fold indices are used positionally below, so X is indexed with .iloc and
# y is viewed as an array; this does not rely on a default 0..n-1 index.
y_values = np.asarray(y)
accuracy = 0
for train, test in kFold.split(X, y):
    knn.fit(X.iloc[train], y_values[train])
    acc = knn.score(X.iloc[test], y_values[test])
    # Each fold contributes an equal share of the final mean.
    accuracy += acc / n_splits
print(accuracy)
# Part 3. Converting str-typed data to int (the sklearn.preprocessing approach)
import numpy as np
import pandas as pd
from pandas import Series,DataFrame
from sklearn.preprocessing import OrdinalEncoder ,OneHotEncoder,LabelEncoder
from sklearn.neighbors import KNeighborsClassifier
# LabelEncoder (applied to a Series) and OrdinalEncoder (applied to a DataFrame)
# are similar.
labelEncode = LabelEncoder()
# NOTE(review): `salary` is not defined in this fragment; it is presumably a
# DataFrame loaded elsewhere that has a 'salary' column -- confirm.
salary_label = labelEncode.fit_transform(salary['salary'])
# Replace every column with its integer codes. The same encoder object is
# refit on each column, so after the loop it only holds the last column's fit.
for col in salary.columns:
    salary[col] = labelEncode.fit_transform(salary[col])
print(salary.head())