简介: k-近邻算法是一种分类算法,无需训练
算法思想:对未知类别的数据集合中的每一个点执行以下操作:
- 计算已知类别数据集(训练集)中的点与当前点的距离;
- 将距离按照递增顺序排列;
- 选取距离最小的k个点;
- 确定这k个点的所属类别,统计各类别出现的频率;
- 将出现频率最高的类别作为当前点的类别;
原始kNN的python实现:
def KNN(inX, dataSet, labels, k):
    """Classify ``inX`` by majority vote among its k nearest training samples.

    Distances are Euclidean. The implementation runs unchanged on Python 2
    and Python 3 (the original relied on ``dict.iteritems``, which only
    exists on Python 2).

    Args:
        inX: Feature vector to classify; a flat sequence or a 1 x d array
            that broadcasts against the rows of ``dataSet``.
        dataSet: 2-D NumPy array of training samples, one sample per row.
        labels: Sequence of class labels; ``labels[i]`` is the label of row
            ``i`` of ``dataSet``.
        k: Number of nearest neighbours that vote. When ``k`` exceeds the
            number of training samples, every sample votes.

    Returns:
        The label that received the most votes. Ties are resolved by the
        iteration order of the vote dict (insertion order on Python 3.7+,
        i.e. the tied label whose first vote came from the closest sample).

    Raises:
        ValueError: If ``k`` is smaller than 1 or ``dataSet`` has no rows.
    """
    dataSetSize = dataSet.shape[0]
    if k < 1 or dataSetSize == 0:
        raise ValueError("k must be >= 1 and dataSet must not be empty")

    # Broadcasting subtracts inX from every row, so no explicit tiling is
    # needed; the sign of the difference is irrelevant once it is squared.
    diffMat = inX - dataSet
    distances = (diffMat ** 2).sum(axis=1) ** 0.5
    sortedDistIndicies = distances.argsort()

    # Tally the labels of the k closest samples; slicing clamps k to the
    # number of available samples instead of indexing past the end.
    classCount = {}
    for sampleIndex in sortedDistIndicies[:k]:
        voteIlabel = labels[sampleIndex]
        classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1

    # max() returns the first label with the highest count, which is the same
    # element a stable descending sort would place first.
    return max(classCount, key=classCount.get)
kNN用于手写数字识别:
手写数字数据以如下图格式存储,每一个数字由32*32的矩阵构成。
存储数据的文件命名方式如下图(这些文件中存储的数字都为8,格式为n_xx.txt,n为数字本身,xx为数据的编号),数据分为训练集和测试集:
我们首先把图像转换成向量,即把32*32的图像转换成1*1024的向量,使用img2vector函数。完整的python代码如下,如果需要,数据和代码可访问我的github获取:
#!/bin/python
#coding=utf-8
#手写数字识别
#当k=3时分类结果最好,错误率为12%
from numpy import *
from os import listdir
import operator
#将图像转换成向量
def img2vector(filename):
    """Read a 32 x 32 text image of digit characters into a 1 x 1024 vector.

    Only the first 32 lines of the file and the first 32 characters of each
    line are read. The character at row ``i``, column ``j`` is converted with
    ``int`` and stored at position ``32 * i + j`` (row-major order).

    Args:
        filename: Path of the text file holding the image.

    Returns:
        A NumPy array of shape (1, 1024) holding the converted values.

    Raises:
        IndexError: If one of the first 32 lines has fewer than 32 characters.
        ValueError: If a character that is read is not a decimal digit.
    """
    returnVector = zeros((1, 1024))
    # The context manager closes the file even when a malformed line raises;
    # the original left the handle open.
    with open(filename) as fr:
        for i in range(32):
            lineStr = fr.readline()
            for j in range(32):
                returnVector[0, 32 * i + j] = int(lineStr[j])
    return returnVector
#kNN算法用于确定类别
def classify(inX, dataSet, labels, k):
    """Return the majority label among the k training samples nearest ``inX``.

    Distances are Euclidean. The implementation runs unchanged on Python 2
    and Python 3 (the original relied on ``dict.iteritems``, which only
    exists on Python 2).

    Args:
        inX: Feature vector to classify; a flat sequence or a 1 x d array
            that broadcasts against the rows of ``dataSet``.
        dataSet: 2-D NumPy array of training samples, one sample per row.
        labels: Sequence of class labels; ``labels[i]`` is the label of row
            ``i`` of ``dataSet``.
        k: Number of nearest neighbours that vote. When ``k`` exceeds the
            number of training samples, every sample votes.

    Returns:
        The label that received the most votes. Ties are resolved by the
        iteration order of the vote dict (insertion order on Python 3.7+,
        i.e. the tied label whose first vote came from the closest sample).

    Raises:
        ValueError: If ``k`` is smaller than 1 or ``dataSet`` has no rows.
    """
    dataSetSize = dataSet.shape[0]
    if k < 1 or dataSetSize == 0:
        raise ValueError("k must be >= 1 and dataSet must not be empty")

    # Broadcasting subtracts inX from every row, so no explicit tiling is
    # needed; the sign of the difference is irrelevant once it is squared.
    diffMat = inX - dataSet
    distances = (diffMat ** 2).sum(axis=1) ** 0.5
    sortedDistIndicies = distances.argsort()

    # Tally the labels of the k closest samples; slicing clamps k to the
    # number of available samples instead of indexing past the end.
    classCount = {}
    for sampleIndex in sortedDistIndicies[:k]:
        voteIlabel = labels[sampleIndex]
        classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1

    # max() returns the first label with the highest count, which is the same
    # element a stable descending sort would place first.
    return max(classCount, key=classCount.get)
#手写数字识别
def handwritingClassTest():
    """Evaluate the kNN digit classifier on the test set and print the results.

    Every file in ``digits/trainingDigits`` is loaded with ``img2vector`` as
    one training sample. Each file in ``digits/testDigits`` is then classified
    with ``classify`` using k = 3; the predicted and actual digit are printed
    per file, followed by the total error count and the error rate.

    The true label of a file is the integer before the first ``_`` in its
    name (after dropping everything from the first ``.`` onwards).

    Raises:
        ValueError: If a file name in either directory does not start with an
            integer followed by ``_``.
        ZeroDivisionError: If ``digits/testDigits`` contains no files.
    """
    # Build the training matrix (one row per file) and its label list.
    hwLabels = []
    trainingFileList = listdir('digits/trainingDigits')
    m = len(trainingFileList)
    trainingMat = zeros((m, 1024))
    for i, fileNameStr in enumerate(trainingFileList):
        fileStr = fileNameStr.split('.')[0]
        hwLabels.append(int(fileStr.split('_')[0]))
        trainingMat[i, :] = img2vector('digits/trainingDigits/%s' % fileNameStr)

    # Classify every test file and count the misclassifications.
    testFileList = listdir('digits/testDigits')
    errorCount = 0.0
    mTest = len(testFileList)
    for fileNameStr in testFileList:
        fileStr = fileNameStr.split('.')[0]
        classNumStr = int(fileStr.split('_')[0])
        vectorUnderTest = img2vector('digits/testDigits/%s' % fileNameStr)
        classifierResults = classify(vectorUnderTest, trainingMat, hwLabels, 3)
        # print(...) with one pre-formatted string behaves the same under the
        # Python 2 print statement and the Python 3 print function.
        print("分类结果为:%d, 实际类别为:%d" % (classifierResults, classNumStr))
        if classifierResults != classNumStr:
            errorCount += 1.0
    print("总误差为:%d" % errorCount)
    print("误差率为:%f" % (errorCount/float(mTest)))
# Script entry point: run the digit-recognition evaluation only when this file
# is executed directly, not when it is imported as a module.
if __name__ == "__main__":
    handwritingClassTest()