Python实现K近邻算法
一、题目介绍
邻近算法,或者说K最近邻(kNN,k-NearestNeighbor)分类算法是数据挖掘分类技术中最简单的方法之一。所谓K最近邻,就是k个最近的邻居的意思,说的是每个样本都可以用它最接近的k个邻居来代表。
kNN算法的核心思想是如果一个样本在特征空间中的k个最相邻的样本中的大多数属于某一个类别,则该样本也属于这个类别,并具有这个类别上样本的特性。该方法在确定分类决策上只依据最邻近的一个或者几个样本的类别来决定待分样本所属的类别。 kNN方法在类别决策时,只与极少量的相邻样本有关。由于kNN方法主要靠周围有限的邻近的样本,而不是靠判别类域的方法来确定所属类别的,因此对于类域的交叉或重叠较多的待分样本集来说,kNN方法较其他方法更为适合。
二、算法设计
knn.py是KNN算法内部实现的函数,train.py首先生成模拟数据,再调用knn.py实现分类,最后完成精度评价。
1.生成鸢尾花数据集
本例中需要用到鸢尾花数据集,它包含在scikit-learn的datasets模块中,可以调用load_iris函数来加载函数。load_iris返回的iris对象是一个Bunch对象,与字典非常相似,里面包含键和值。
2.划分数据集和测试集train_test_split()方法用来直接分割数据,利用伪随机数生成器将数据集打乱,random_state为种子数,其中,random_state=0,表示每次调用train_test_split返回的输出都是不变的,即随机数生成器的种子是相同的。
3.先报训练集转换为DataFrame形式,方便画图。要提前观察下数据集,观察最好的方法就是看图,pandas为我们提供了一个绘制散点图矩阵的函数,叫做scatter_matrix。参数解释: frame:数据的dataframe,本例为4150的矩阵; c是颜色,本例中按照y_train的不同来分配不同的颜色; figsize设置图片的尺寸; marker是散点的形状,‘o’是圆形,’'是星形 ; hist_kwds是直方图的相关参数,{‘bins’:20}是生成包含20个长条的直方图;s是大图的尺寸 ; alpha是图的透明度; cmap是colourmap,就是颜色板。
4.可视化展示数据,绘制散点图
5.KNN里面去进行推测
6.计算推测值的精度
三、源代码(有注释)
1.train.py
import matplotlib.pyplot as plt
from knn import *
#本例中需要用到鸢尾花数据集,它包含在scikit-learn的datasets模块中。可以调用load_iris函数来加载函数
#load_iris返回的iris对象是一个Bunch对象,与字典非常相似,里面包含键和值
from sklearn.datasets import load_iris
iris_dataset = load_iris()
print("key of iris_dataset:\n{}".format(iris_dataset.keys()))
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
#train_test_split()方法用来直接分割数据,利用伪随机数生成器将数据集打乱,random_state为种子数
X_train,X_test,y_train,y_test = train_test_split(iris_dataset['data'],iris_dataset['target'],random_state=0)
#一般画图使用scatter plot 散点图,但是有一个缺点:只能观察2维的数据情况;如果想观察多个特征之间的数据情况,scatter plot并不可行;
#用pair plot 可以观察到任意两个特征之间的关系图(对角线为直方图);恰巧:pandas的 scatter_matrix函数能画pair plots
#所以,我们先把训练集转换成DataFrame形式,方便画图
iris_dataframe = pd.DataFrame(X_train,columns=iris_dataset.feature_names)
#数据的好坏直接影响你模型构建成功与否,现实中我们的数据可能存在许多问题(单位不统一,部分数据缺失等)
#所以我们要提前观察下数据集,观察最好的方法就是看图,pandas为我们提供了一个绘制散点图矩阵的函数,叫做scatter_matrix
#参数解释: frame:数据的dataframe,本例为4*150的矩阵; c是颜色,本例中按照y_train的不同来分配不同的颜色;figsize设置图片的尺寸; marker是散点的形状,'o'是圆形,'*'是星形 ;hist_kwds是直方图的相关参数,{'bins':20}是生成包含20个长条的直方图;
#s是大图的尺寸 ; alpha是图的透明度; cmap是colourmap,就是颜色板
grr = pd.plotting.scatter_matrix(iris_dataframe,c=y_train,marker='o',figsize=(10,10),hist_kwds={'bins':20},s=60,alpha=0.8,cmap='viridis')
#plt.show()
#visualize data
plt.scatter(X_train[:,0], X_train[:,1], c=y_train, marker='.') #绘制散点图
plt.show()
plt.scatter(X_test[:,0], X_test[:,1], c=y_test, marker='.')
plt.show()
#knn classifier
#调用KNN分类器
clf = KNN(k=3)
clf.fit(X_train, y_train)
print('train accuracy: {:.3}'.format(clf.score()))
y_test_pred = clf.predict(X_test)
print('test accuracy: {:.3}'.format(clf.score(y_test, y_test_pred)))
2.knn.py
import numpy as np
import operator
class KNN(object):
def __init__(self, k=3):
self.k = k
def fit(self, x, y): #fit()函数将x和y传进去
self.x = x
self.y = y
def _square_distance(self, v1, v2): #计算任意两点之间的距离平方
return np.sum(np.square(v1 - v2))
def _vote(self, ys): #投票
ys_unique = np.unique(ys) #ys取唯一值
vote_dict = {} #用字典进行操作
for y in ys:
if y not in vote_dict.keys(): #y不在当前字典的键里
vote_dict[y] = 1 #建立k=0的键,值为1
else:
vote_dict[y] += 1 #y在当前字典的键里,值加1
#第一个参数为可迭代的参数,reverse为从大到小排序
sorted_vote_dict = sorted(vote_dict.items(), key=operator.itemgetter(1), reverse=True)
return sorted_vote_dict[0][0]
def predict(self, x): #接收x参数,多行的点数据,每行是一个二维的向量
y_pred = []
for i in range(len(x)):
#得到当前的x[i]和所有的训练样点之间的平方距离,保存于数组当中
dist_arr = [self._square_distance(x[i], self.x[j]) for j in range(len(self.x))] #循环内部训练数据方法计算
sorted_index = np.argsort(dist_arr) #从小到大排序距离,返回索引
top_k_index = sorted_index[:self.k]
y_pred.append(self._vote(ys=self.y[top_k_index])) #添加当前x和y的预测值
return np.array(y_pred)
def score(self, y_true=None, y_pred=None): #计算推测值的精度
if y_true is None or y_pred is None:
y_pred = self.predict(self.x)
y_true = self.y
score=0.0
for i in range(len(y_true)):
if y_true[i] == y_pred[i]:
score += 1
score /= len(y_true) #得到正确率
return score
四、运行结果
1.lris数据集的散点图矩阵,按类别标签着色。如图1;
图1
2.得到最终结果,显示推测值精度。如图2;
图2
五、总结
在利用Python第三方库时,需要了解到install操作。在利用鸢尾花数据集的时候训练数据与测试数据是衡量模型是否成功的重要因素。在构建机器学习模型之前,通常要检查一下数据,看看如果不用机器学习能不能轻松完成任务,或者需要的信息又没有包含在数据中。检查数据也是发现异常值和特殊值的好办法。检查数据的最佳方法之一就是将其可视化,可以绘制散点图矩阵,这比绘制散点图要好得多。
KNN的主要优点有:
1) 理论成熟,思想简单,既可以用来做分类也可以用来做回归,
2) 可用于非线性分类,
3) 训练时间复杂度比支持向量机之类的算法低,仅为O(n),
4) 和朴素贝叶斯之类的算法比,对数据没有假设,准确度高,对异常点不敏感,
5) 由于KNN方法主要靠周围有限的邻近的样本,而不是靠判别类域的方法来确定所属类别的,因此对于类域的交叉或重叠较多的待分样本集来说,KNN方法较其他方法更为适合,
6)该算法比较适用于样本容量比较大的类域的自动分类,而那些样本容量较小的类域采用这种算法比较容易产生误分。
KNN的主要缺点有:
1)计算量大,尤其是特征数非常多的时候,
2)样本不平衡的时候,对稀有类别的预测准确率低,
3)KD树,球树之类的模型建立需要大量的内存,
4)使用懒散学习方法,基本上不学习,导致预测时速度比起逻辑回归之类的算法慢,
5)相比决策树模型,KNN模型可解释性不强。