用sklearn实现基于KNN算法的手写数字识别
首先介绍KNN(k-nearest-neighor)
一、KNN算法概述#
邻近算法,或者说K最近邻(kNN,k-NearestNeighbor)分类算法是数据挖掘分类技术中最简单的方法之一。所谓K最近邻,就是k个最近的邻居的意思,说的是每个样本都可以用它最接近的k个邻居来代表。Cover和Hart在1968年提出了最初的邻近算法。KNN是一种分类(classification)算法,它输入基于实例的学习(instance-based learning),属于懒惰学习(lazy learning)即KNN没有显式的学习过程,也就是说没有训练阶段,数据集事先已有了分类和特征值,待收到新样本后直接进行处理。与急切学习(eager learning)相对应。
KNN是通过测量不同特征值之间的距离进行分类。
思路是:如果一个样本在特征空间中的k个最邻近的样本中的大多数属于某一个类别,则该样本也划分为这个类别。KNN算法中,所选择的邻居都是已经正确分类的对象。该方法在定类决策上只依据最邻近的一个或者几个样本的类别来决定待分样本所属的类别。
提到KNN,网上最常见的就是下面这个图,可以帮助大家理解。
我们要确定绿点属于哪个颜色(红色或者蓝色),要做的就是选出距离目标点距离最近的k个点,看这k个点的大多数颜色是什么颜色。当k取3的时候,我们可以看出距离最近的三个,分别是红色、红色、蓝色,因此得到目标点为红色。
算法的描述:#
1)计算测试数据与各个训练数据之间的距离;
2)按照距离的递增关系进行排序;
3)选取距离最小的K个点;
4)确定前K个点所在类别的出现频率;
5)返回前K个点中出现频率最高的类别作为测试数据的预测分类
二、关于K的取值#
K:临近数,即在预测目标点时取几个临近的点来预测。
K值得选取非常重要,因为:
如果当K的取值过小时,一旦有噪声得成分存在们将会对预测产生比较大影响,例如取K值为1时,一旦最近的一个点是噪声,那么就会出现偏差,K值的减小就意味着整体模型变得复杂,容易发生过拟合;
如果K的值取的过大时,就相当于用较大邻域中的训练实例进行预测,学习的近似误差会增大。这时与输入目标点较远实例也会对预测起作用,使预测发生错误。K值的增大就意味着整体的模型变得简单;
如果K==N的时候,那么就是取全部的实例,即为取实例中某分类下最多的点,就对预测没有什么实际的意义了;
K的取值尽量要取奇数,以保证在计算结果最后会产生一个较多的类别,如果取偶数可能会产生相等的情况,不利于预测。
K的取法:#
常用的方法是从k=1开始,使用检验集估计分类器的误差率。重复该过程,每次K增值1,允许增加一个近邻。选取产生最小误差率的K。
一般k的取值不超过20,上限是n的开方,随着数据集的增大,K的值也要增大。
三、关于距离的选取#
距离就是平面上两个点的直线距离
关于距离的度量方法,常用的有:欧几里得距离、余弦值(cos), 相关度 (correlation), 曼哈顿距离 (Manhattan distance)或其他。
Euclidean Distance 定义:#
两个点或元组P1=(x1,y1)和P2=(x2,y2)的欧几里得距离是
曼哈顿距离为
python源码
KNN.py
#coding:utf-8
from numpy import *
import operator
##给出训练数据以及对应的类别
def createDataSet():
group = array([[1.0,2.0],[1.2,0.1],[0.1,1.4],[0.3,3.5]])
labels = ['A','A','B','B']
return group,labels
###通过KNN进行分类
def classify(input,dataSe t,label,k):
dataSize = dataSet.shape[0]
####计算欧式距离
diff = tile(input,(dataSize,1)) - dataSet
sqdiff = diff ** 2
squareDist = sum(sqdiff,axis = 1)###行向量分别相加,从而得到新的一个行向量
dist = squareDist ** 0.5
##对距离进行排序
sortedDistIndex = argsort(dist)##argsort()根据元素的值从大到小对元素进行排序,返回下标
classCount={}
for i in range(k):
voteLabel = label[sortedDistIndex[i]]
###对选取的K个样本所属的类别个数进行统计
classCount[voteLabel] = classCount.get(voteLabel,0) + 1
###选取出现的类别次数最多的类别
maxCount = 0
for key,value in classCount.items():
if value > maxCount:
maxCount = value
classes = key
return classes
在命令行输入
#-*-coding:utf-8 -*-
import sys
sys.path.append("...文件路径...")
import KNN
from numpy import *
dataSet,labels = KNN.createDataSet()
input = array([1.1,0.3])
K = 3
output = KNN.classify(input,dataSet,labels,K)
print("测试数据为:",input,"分类结果为:",output)
直接调用sklearn里的KNN算法的代码
训练集和测试集以32×32矩阵的方式存放在txt文件里
import numpy as np
from os import listdir
from sklearn.neighbors import KNeighborsClassifier as KNN
"""
函数说明:将32x32的二进制图像转换为1x1024向量
"""
def img2vector(filename):
#创建1x1024零向量
returnVect = np.zeros((1, 1024))
#打开文件
fr = open(filename)
#按行读取
for i in range(32):
#读一行数据
lineStr = fr.readline()
#每一行的前32个元素依次添加到returnVect中
for j in range(32):
returnVect[0, 32*i+j] = int(lineStr[j])
#返回转换后的1x1024向量
return returnVect
"""
函数说明:手写数字分类测试
"""
def handwritingClassTest():
#训练集的Labels
hwLabels = []
#返回trainingDigits目录下的文件名
trainingFileList = listdir('trainingDigits')
#返回文件夹下文件的个数
m = len(trainingFileList)
#初始化训练的Mat矩阵,训练集
trainingMat = np.zeros((m, 1024))
#从文件名中解析出训练集的类别
for i in range(m):
#获得文件的名字
fileNameStr = trainingFileList[i]
#获得分类的数字
classNumber = int(fileNameStr.split('_')[0])
#将获得的类别添加到hwLabels中
hwLabels.append(classNumber)
#将每一个文件的1x1024数据存储到trainingMat矩阵中
trainingMat[i,:] = img2vector('trainingDigits/%s' % (fileNameStr))
#构建kNN分类器
neigh =KNN(n_neighbors = 3, algorithm = 'auto')
#拟合模型, trainingMat为训练矩阵,hwLabels为对应的标签
neigh.fit(trainingMat, hwLabels)
#返回testDigits目录下的文件列表
testFileList = listdir('testDigits')
#错误检测计数
errorCount = 0.0
#测试数据的数量
mTest = len(testFileList)
#从文件中解析出测试集的类别并进行 分类测试
for i in range(mTest):
#获得文件的名字
fileNameStr = testFileList[i]
#获得分类的数字
classNumber = int(fileNameStr.split('_')[0])
#获得测试集的1x1024向量,用于训练
vectorUnderTest = img2vector('testDigits/%s' % (fileNameStr))
#获得预测结果
classifierResult = neigh.predict(vectorUnderTest)
print("分类返回结果为%d\t真实结果为%d" % (classifierResult, classNumber))
if(classifierResult != classNumber):
errorCount += 1.0
print("总共错了%d个数据\n错误率为%f%%" % (errorCount, errorCount/mTest * 100))
"""
函数说明:main函数
"""
if __name__=='__main__':
handwritingClassTest()
因为不同项目中训练集和测试集的存储形式各种各样,因此数据预处理占代码开头很大一部分,也比较繁琐。预处理结束后用KNN模型进行拟合和预测就只需要简单几行代码就能完成了。
通过更改K值得到K=3时错误率最小