kNN算法是k近邻分类(k-nearest neighbor classification)算法的简称。基本流程是从训练集中找到和新数据最接近的k条记录,然后根据他们的主要分类来决定新数据的类别。该算法涉及3个主要因素:训练集、距离或相似的衡量、k的大小。
一、算法的基本步骤如下所示:
输入: 训练数据T;近邻数目k;待分类的元组t。
输出: 输出类别c。
(1)N=O;
(2)FOR each d∈T DO BEGIN
(3) IF |N|≤k THEN
(4) N=N∪{d};
(5) ELSE
(6) IF $u∈N such that sim(t,u)<sim(t,d)THEN BEGIN
(7) N=N-{u};
(8) N=N∪{d};
(9) END
(10)END
(11)c=class to which the most u ∈N.
二、举例说明
对于给定的如下训练集,若“高度”用于计算距离,k=5,则对新样本<Pat,女,1.6>该如何分类?
表1.训练集数据
姓名
性别
身高(米)
类别
Kristina
女
1.6
矮
Jim
男
2
高
Maggie
女
1.9
中等
Martha
女
1.83
中等
Stephanie
女
1.7
矮
Bob
男
1.85
中等
Kathy
女
1.6
矮
Dave
男
1.7
矮
Worth
男
2.2
高
Steven
男
2.1
高
Debbie
女
1.8
中等
Todd
男
1.95
中等
Kim
女
1.9
中等
Amy
女
1.8
中等
Wynette
女
1.75
中等
•对T前k=5个记录,N={<Kristina,女, 1.6>、<Jim,男,2>、<Maggie,女,1.9>、<Martha,女,1.83>和<Stephanie,女,1.7>}。•对第6个记录d=<Bob,男,1.85>,得到N={<Kristina,女,1.6>、<Bob,男,1.85>、<Maggie,女,1.9>、<Martha,女,1.83>和<Stephanie,女,1.7>}。•对第7个记录d=<Kathy,女,1.6>,得到N={<Kristina,女, 1.6>、<Bob,男,1.85>、<Kathy,女,1.6>、<Martha,女,1.83>和<Stephanie,女,1.7>}。•对第8个记录d=<Dave,男,1.7>,得到N={<Kristina,女,1.6>、<Dave,男,1.7>、<Kathy,女,1.6>、<Martha,女,1.83>和<Stephanie,女,1.7>}。•对第9和10个记录,没变化。•对第11个记录d=<Debbie,女,1.8>,得到N={<Kristina,女,1.6>、<Dave,男,1.7>、<Kathy,女,1.6>、<Debbie,女,1.8>和<Stephanie,女,1.7>}。•对第12到14个记录,没变化。•对第15个记录d=<Wynette,女,1.75>,得到N={<Kristina,女,1.6>、<Dave,男,1.7>、<Kathy,女,1.6>、<Wynette,女,1.75>和<Stephanie,女,1.7>}。
最后的输出元组是<Kristina,女,1.6>、<Kathy,女,1.6>、<Stephanie,女,1.7>、<Dave,男,1.7>和<Wynette,女,1.75>。在这五项中,四个属于矮个、一个属于中等。最终kNN方法认为Pat为矮个。
三、算法的优缺点
1、优点
用距离(常用欧式距离、曼哈顿距离等)来表征相似性,距离越近,相似性越大,距离越远,相似性越小。简单直观,易于理解,易于实现,无需估计参数,无需训练、特别适合于多分类问题(对象具有多个类别标签)。
2、缺点
懒惰算法,模型构造很简单,但对测试样本分类时的计算量大,内存开销大,评分慢。而且可解释性较差,无法给出决策树那样的规则。
四、常见问题
对超参数k很敏感、k值设定为多大才合适?k太小,分类结果易受噪声点影响;k太大,近邻中又可能包含太多的其它类别的点。而且不同的k值可能会得到不一样的结果。如何消除此影响:可以对距离加权,距离待分样本较近的训练样本具有较大的权重、距离待分样本较远的训练其权重较小。对于k的选择,通常是采用交叉检验来确定。经验规则:k一般低于训练样本数的平方根。
图1.有一个未知形状X(图中绿色的点),如何判断X是什么形状?k=6 or k=11获得的结果是否相同?
五、实验代码
scikit-learn(简称sklearn)是目前最受欢迎、功能最强大的一个机器学习Python库。它广泛地支持各种分类、聚类以及回归分析方法,比如kNN、kmeans、支持向量机、随机森林、DBSCAN等。由于其强大的功能、优异的拓展性以及易用性,已受到了很多数据科学从业者的喜爱,也是业界相当著名的一个开源项目之一。为简化起见,本系列博客中大部分算法均直接基于scikit-learn来验证。若你电脑还没有安装此库,可直接通过pip
来自动安装:
pip install scikit-learn。
# -*- coding: utf-8 -*-
from sklearn.datasets import load_iris
from sklearn import neighbors
import sklearn
# 查看iris数据集
irisdata = load_iris()
# 调用kNN算法,将参数k设置为5
knn = neighbors.KNeighborsClassifier(n_neighbors=5)
# 训练数据集
knn.fit(irisdata.data, irisdata.target)
# 预测
predict = knn.predict([[0.1, 0.2, 0.3, 0.4]])
print predict,irisdata.target_names[predict]