一、knn算法介绍
1. 介绍
邻近算法,或者说K最近邻(kNN,k-NearestNeighbor)分类算法是数据挖掘分类技术中最简单的方法之一。所谓K最近邻,就是k个最近的邻居的意思,说的是每个样本都可以用它最接近的k个邻居来代表。
kNN算法的核心思想是如果一个样本在特征空间中的k个最相邻的样本中的大多数属于某一个类别,则该样本也属于这个类别,并具有这个类别上样本的特性。该方法在确定分类决策上只依据最邻近的一个或者几个样本的类别来决定待分样本所属的类别。 kNN方法在类别决策时,只与极少量的相邻样本有关。由于kNN方法主要靠周围有限的邻近的样本,而不是靠判别类域的方法来确定所属类别的,因此对于类域的交叉或重叠较多的待分样本集来说,kNN方法较其他方法更为适合。
2. 核心概括
主要的思想是计算待分类样本与训练样本之间的差异性,并将差异按照由小到大排序,选出前面K个差异最小的类别,并统计在K个中类别出现次数最多的类别为最相似的类,最终将待分类样本分到最相似的训练样本的类中。与投票(Vote)的机制类似。
3. 样本间差异性度量
比较常见的距离度量方式有欧氏距离、曼哈顿距离等。以欧氏距离为例:例如训练样本与待测样本之间的欧式距离为:
二、knn算法流程
1.求训练样本与待预测样本间的相似性,例如计算欧式距离;
2.依据相似性(如计算的欧式距离)排序
3.选择前K个最为相似的样本对应的类别
4.类别出现最多的即为最终的分类结果
三、python实现knn算法以及预测
数据来源以及介绍:
数据来源于uci机器学习库:http://archive.ics.uci.edu/ml/datasets/Iris
数据介绍
主要依据以下4个花的特征来预测花的等级
特征如下:
1.萼片长度(cm)
2.萼片宽度(cm)
3.花瓣长度(cm)
4.花瓣宽度(cm)
等级:
- Iris Setosa
- Iris Versicolour
- Iris Virginica
部分数据样本示例
6.3,2.7,4.9,1.8,Iris-virginica
6.7,3.3,5.7,2.1,Iris-virginica
7.2,3.2,6.0,1.8,Iris-virginica
6.2,2.8,4.8,1.8,Iris-virginica
6.1,3.0,4.9,1.8,Iris-virginica
5.1,3.5,1.4,0.2,Iris-setosa
4.9,3.0,1.4,0.2,Iris-setosa
4.7,3.2,1.3,0.2,Iris-setosa
4.6,3.1,1.5,0.2,Iris-setosa
5.0,3.6,1.4,0.2,Iris-setosa
5.4,3.9,1.7,0.4,Iris-setosa
4.6,3.4,1.4,0.3,Iris-setosa
5.0,3.4,1.5,0.2,Iris-setosa
python实现knn以及预测
#!/usr/bin/python
# -*- coding: UTF-8 -*-
import numpy as py
import operator
#加载数据
def loadData(filePath):
dataSet=[];
lable=[];
with open(filePath) as f:
for line in f.readlines():
lines=line.split(",");
dataSet.append([float(lines[0]),float(lines[1]),float(lines[2]),float(lines[3])]);#获得每一行数据的前4列
lable.append(lines[4]);#当前数据的标签
return py.array(dataSet),lable #dataSet转为数组
#knn算法
def knn(trainSet,label,testSet,k):
distance=(trainSet-testSet)**2;#求差的平方和---注意:数组可以做加减,此处均为数组
distanceLine=distance.sum(axis=1);#对数组的每一行求和,axis=1为对行求和,axis=0为对每列求和
finalDistance=distanceLine**0.5;#对和开方
sortedIndex=finalDistance.argsort();#获得排序后原始下角标
index=sortedIndex[:k];#获得距离最小的前k个下角标
labelCount={};#字典 key为标签,value为标签出现的次数
for i in index:
tempLabel=label[i];
labelCount[tempLabel]=labelCount.get(tempLabel,0)+1;
sortedCount=sorted(labelCount.items(),key=operator.itemgetter(1),reverse=True);#operator.itemgetter(1)意思是按照value值排序,即按照欧氏距离排序
return sortedCount[0][0];#输出标签出现最多的那个
#预测正确率
def predict(trainSet,trainLabel,testSet,k):
total=len(testSet);#测试样本总数(本次测试数据为1/5)
trueCount=0;
for i in range(len(testSet)):
label=knn(trainSet,trainLabel,testSet[i],k);
if label in testLabel[i]:
trueCount=trueCount+1;
return float(trueCount)/float(total)
if __name__=='__main__':
trainSet,trainLabel=loadData("D:\\data\iris_train.txt");#训练数据以及标签
testSet, testLabel = loadData("D:\\data\iris_test.txt");#测试数据以及标签
print predict(trainSet,trainLabel,testSet,3)
预测正确率:0.955
本文所用数据下载:https://download.csdn.net/download/wickedvalley/10330973