一、概念:
最近邻 (k-Nearest Neighbors, KNN) 算法是一种分类算法, 1968年由 Cover和 Hart 提出, 应用场景有字符识别、 文本分类、 图像识别等领域。
核心思想: 一个样本与数据集中的k个样本最相似, 如果这k个样本中的大多数属于某一个类别, 则该样本也属于这个类别。
——距离度量:
在选择两个实例相似性时,一般使用的欧式距离,又称之为欧几里得度量,它定义于欧几里得空间中。n维空间中两个点x1(x11,x12,…,x1n)与 x2(x21,x22,…,x2n)间的欧氏距离:
——k值选择
如果选择较小的K值,就相当于用较小的邻域中的训练实例进行预测,学习的近似误差会减小,只有与输入实例较近的训练实例才会对预测结果起作用,但学习的估计误差会增大,预测结果会对近邻的实例点分成敏感。如果邻近的实例点恰巧是噪声,预测就会出错。K值减小就意味着整体模型变复杂,分的不清楚,就容易发生过拟合。
如果选择较大K值,就相当于用较大邻域中的训练实例进行预测,其优点是可以减少学习的估计误差,但近似误差会增大,也就是对输入实例预测不准确。
简而言之,K值减小就意味着整体模型变复杂,分的不清楚,就容易发生过拟合。K值得增大就意味着整体模型变的简单。
在实际应用中,K值一般取一个比较小的数值,通常采用交叉验证法来选取最优的K值。
——流程:
1) 计算已知类别数据集中的点与当前点之间的距离
2) 按距离递增次序排序
3) 选取与当前点距离最小的k个点
4) 统计前k个点所在的类别出现的频率
5) 返回前k个点出现频率最高的类别作为当前点的预测分类
——优点:
1、简单有效
2、重新训练代价低
3、算法复杂度低
4、适合类域交叉样本
5、适用大样本自动分类
——缺点:
1、惰性学习
2、类别分类不标准化
3、输出可解释性不强
4、不均衡性
5、计算量较大
概念参考:https://blog.csdn.net/sinat_30353259/article/details/80901746
二、代码实现(丁香花数据集):
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
%matplotlib inline
lilac_data = pd.read_csv("Desktop/ML/syringa.csv")
lilac_data.head()
#绘制丁香花子图特征
fig,axes = plt.subplots(2,3,figsize=(20,10))
fig.subplots_adjust(hspace=0.3,wspace=0.2)
axes[0,0].set_xlabel("sepal_length")
axes[0,0].set_ylabel("sepal_width")
axes[0,0].scatter(lilac_data.sepal_length[:50],lilac_data.sepal_width[:50],c="b")
axes[0,0].scatter(lilac_data.sepal_length[50:100],lilac_data.sepal_width[50:100],c="g")
axes[0,0].scatter(lilac_data.sepal_length[100:],lilac_data.sepal_width[100:],c="r")
axes[0,0].legend(["daphne","syinga","willow"],loc=2)
axes[0,1].set_xlabel("sepal_length")
axes[0,1].set_ylabel("petal_length")
axes[0,1].scatter(lilac_data.sepal_length[:50],lilac_data.petal_length[:50],c="b")
axes[0,1].scatter(lilac_data.sepal_length[50:100],lilac_data.petal_length[50:100],c="g")
axes[0,1].scatter(lilac_data.sepal_length[100:],lilac_data.petal_length[100:],c="r")
axes[0,2].set_xlabel("sepal_length")
axes[0,2].set_ylabel("petal_width")
axes[0,2].scatter(lilac_data.sepal_length[:50],lilac_data.petal_width[:50],c="b")
axes[0,2].scatter(lilac_data.sepal_length[50:100],lilac_data.petal_width[50:100],c="g")
axes[0,2].scatter(lilac_data.sepal_length[100:],lilac_data.petal_width[100:],c="r")
axes[1,0].set_xlabel("sepal_width")
axes[1,0].set_ylabel("petal_width")
axes[1,0].scatter(lilac_data.sepal_width[:50],lilac_data.petal_width[:50],c="b")
axes[1,0].scatter(lilac_data.sepal_width[50:100],lilac_data.petal_width[50:100],c="g")
axes[1,0].scatter(lilac_data.sepal_width[100:],lilac_data.petal_width[100:],c="r")
axes[1,1].set_xlabel("sepal_width")
axes[1,1].set_ylabel("petal_length")
axes[1,1].scatter(lilac_data.sepal_width[:50],lilac_data.petal_length[:50],c="b")
axes[1,1].scatter(lilac_data.sepal_width[50:100],lilac_data.petal_length[50:100],c="g")
axes[1,1].scatter(lilac_data.sepal_width[100:],lilac_data.petal_length[100:],c="r")
axes[1,2].set_xlabel("petal_length")
axes[1,2].set_ylabel("petal_width")
axes[1,2].scatter(lilac_data.petal_length[:50],lilac_data.petal_width[:50],c="b")
axes[1,2].scatter(lilac_data.petal_length[50:100],lilac_data.petal_width[50:100],c="g")
axes[1,2].scatter(lilac_data.petal_length[100:],lilac_data.petal_width[100:],c="r")
#切分数据集
from sklearn.model_selection import train_test_split
feature_data = lilac_data.iloc[:,:-1]
label_data = lilac_data["labels"]
X_train,X_test,y_train,y_test = train_test_split(feature_data,label_data,test_size=0.3,random_state=2)
X_test.head()
#构建KNN模型
from sklearn.neighbors import KNeighborsClassifier
def sklearn_classify(train_data,label_data,test_data,k_num):
knn = KNeighborsClassifier(n_neighbors=k_num)
knn.fit(train_data,label_data)
predict_label = knn.predict(test_data)
return predict_label
y_predict = sklearn_classify(X_train,y_train,X_test,3)
print(y_predict)
[out]:
array(['daphne', 'daphne', 'willow ', 'daphne', 'daphne', 'willow ',
'daphne', 'syringa', 'willow ', 'daphne', 'daphne', 'daphne',
'daphne', 'daphne', 'syringa', 'syringa', 'syringa', 'willow ',
'syringa', 'willow ', 'syringa', 'willow ', 'willow ', 'syringa',
'syringa', 'daphne', 'daphne', 'willow ', 'daphne', 'willow ',
'willow ', 'daphne', 'syringa', 'willow ', 'willow ', 'daphne',
'willow ', 'willow ', 'syringa', 'willow ', 'willow ', 'willow ',
'willow ', 'syringa', 'daphne'], dtype=object)
#计算准确率
def get_accuracy(test_labels,pred_labels):
correct = np.sum(test_labels == pred_labels)
accur = correct/len(test_labels)
return accur
print(get_accuracy(y_test,y_predict))
[out]:
0.7777777777777778
#测试k在2~20内的准确率
normal_accuracy=[]
k_value=range(1,10)
for k in k_value:
y_predict = sklearn_classify(X_train,y_train,X_test,k)
accuracy = get_accuracy(y_test,y_predict)
normal_accuracy.append(accuracy)
plt.xlabel("k")
plt.ylabel("arruracy")
plt.yticks(np.linspace(0.6,1,10))
plt.plot(k_value,normal_accuracy,"r")
plt.grid(True) #增加网格画布
K等于4、6时,accuracy最高为0.8667,为了节约计算资源,K取局部最优为4。