一、KNN算法介绍
在现实生活中我们会遇到要把大量的数据进行分类的问题,这个时候一些常见的分类算法就可以对数据进行分类了。一般有1.KNN算法 2.贝克斯算法 3.决策树 4.人工神经网络 5.支持向量机(SVM)等很多分类算法。我们这里讲解一下KNN算法的原理及其编程方法。
我们通过一个简单的物品销售数据分类来讲解KNN算法的原理,首先我们拿出三类数据:三种零食(a,b,c),三种金饰(d,e,f),三种跑车(g,h,i)。建立一个坐标轴,横轴是商品价格,纵轴是商品销量,我们可以发现零食一般价格低、销量高。金饰一般价格中等、销量低。跑车一般价格高、销量高。所以在上图可以将他们分为3类,这时候想知道一个数据m是什么,就可以通过KNN算法来实现:
- 假设m(xm,ym),a(xa,ya)…..
- 计算m点与a,b,c,d,e,f,g,h,i各点的距离,并按照从小到大顺序排列。
- KNN算法中的k便是取上面排列好的距离中的前k个距离,假设我们取k=3,最后就取出来3个离m点最近的点,假设为d,e,b。
- 我们发现d,e,b中金饰占了大部分,所以可以近似认为m是金饰,这样就可以对数据m进行分类了。.
使用KNN算法,先利用总数据的70%进行训练,再用30%进行测试,看看最后是否达到预期。
二、KNN算法编程
from numpy import * #导入numpy模块下的所有用法,用来计算矩阵
import operator #这个模块主要用来对距离进行排序
def knn(k,testdata,traindata,labels): #指定k值、测试数据、训练数据、类别名(零食,金饰...)
''