from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier
from sklearn import datasets
from matplotlib import pyplot as plt
# Load the iris dataset bundled with scikit-learn.
iris = datasets.load_iris()
x = iris.data  # feature matrix
y = iris.target  # class labels

# Candidate values of n_neighbors to evaluate (1 through 29).
r_range = range(1, 30)
r_list = []  # mean cross-validated accuracy per candidate, in r_range order
for r in r_range:
    knn = KNeighborsClassifier(n_neighbors=r)
    # 5-fold cross-validation scored by accuracy.
    scores = cross_val_score(knn, x, y, cv=5, scoring="accuracy")
    r_list.append(scores.mean())  # average over the 5 folds

# Plot mean accuracy against n_neighbors so a suitable value can be read off.
plt.plot(r_range, r_list)
plt.show()
# Result: the plot shows that n_neighbors values between 6 and 12 are the most suitable.