读完《机器学习实战》这本书后,我有了一些体会。下面是我对书中KNN算法的理解,希望可以帮助大家稍微加深一下理解。
数据集代码下载
KNN算法其实就是邻近算法,也叫K最近邻(kNN,k-Nearest Neighbor)分类算法,是数据挖掘分类技术中最简单的方法之一。所谓K最近邻,就是k个最近的邻居的意思,即每个样本都可以用与它最接近的k个邻居来代表。
也就是说,在离某个样本最近的k个点中,大多数点所属的类别就是这个样本所属的类别,用俗语说就是物以类聚,人以群分。它属于较为简单的算法。
#coding=UTF8
from numpy import *
import operator
from os import listdir
def classify0(inX, dataset, labels, k):
    """Classify ``inX`` by majority vote among its k nearest training rows.

    Distances are Euclidean, computed between ``inX`` and every row of
    ``dataset``.

    Args:
        inX: The sample to classify; a sequence or array that broadcasts
            against one row of ``dataset`` (shape ``(n,)`` or ``(1, n)``).
        dataset: 2-D array-like of training samples, one sample per row.
        labels: Sequence indexed by row number giving the class of each row
            of ``dataset``.
        k: Number of nearest neighbours that take part in the vote.

    Returns:
        The label with the most votes among the k nearest rows. When several
        labels tie, the one belonging to the closest neighbour wins, so the
        result does not depend on dictionary ordering.

    Raises:
        IndexError: If ``k`` is smaller than 1 or larger than the number of
            rows in ``dataset`` (which includes an empty ``dataset``).
    """
    dataset = asarray(dataset)
    dataSetSize = dataset.shape[0]  # number of training rows
    if not 1 <= k <= dataSetSize:
        # The original indexing failed with IndexError in these cases too;
        # the type is kept so existing callers see the same exception.
        raise IndexError("k must be between 1 and the number of training rows")

    # Broadcasting subtracts inX from every row, replacing the explicit tile().
    diffMat = asarray(inX) - dataset
    distance = (diffMat ** 2).sum(axis=1) ** 0.5
    sortedDistIndicies = distance.argsort()  # row indices, nearest first

    nearestLabels = [labels[idx] for idx in sortedDistIndicies[:k]]
    classCount = {}
    for voteIlabel in nearestLabels:
        classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1

    # Pick the winner with an explicit scan instead of dict.iteritems()
    # (Python 2 only). A bare max() is avoided on purpose: the file's
    # `from numpy import *` may shadow the builtin.
    winner = nearestLabels[0]
    for voteIlabel in nearestLabels[1:]:
        # Strict '>' keeps the earlier (closer) label when counts are equal.
        if classCount[voteIlabel] > classCount[winner]:
            winner = voteIlabel
    return winner
def img2vector(filename):
    """Read a 32x32 text image of digits into a 1x1024 NumPy array.

    The first 32 characters of each of the first 32 lines are converted with
    ``int`` and stored row by row, so character ``j`` of line ``i`` lands in
    column ``32 * i + j``.

    Args:
        filename: Path of the text file to read.

    Returns:
        A NumPy array of shape ``(1, 1024)`` holding the parsed values.

    Raises:
        IndexError: If one of the first 32 lines has fewer than 32 characters.
        ValueError: If a character cannot be converted with ``int``.
    """
    returnVect = zeros((1, 1024))
    # The context manager closes the file even if parsing fails; the
    # original left the handle open.
    with open(filename) as fr:
        for i in range(32):
            lineStr = fr.readline()
            for j in range(32):
                returnVect[0, 32 * i + j] = int(lineStr[j])
    return returnVect
def handwritingClassTest():
    """Evaluate the kNN classifier on the handwritten-digit text files.

    Every file in ``trainingDigits`` is loaded with ``img2vector`` into one
    row of the training matrix. Each file in ``testDigits`` is then
    classified with ``classify0`` using k=3, and the prediction is compared
    with the expected digit. The expected digit is the integer before the
    first ``_`` in the file name.

    Prints one line per test file, followed by the total error count and the
    error rate. Returns nothing.

    Raises:
        ZeroDivisionError: If ``testDigits`` contains no files, because the
            error rate divides by the number of test files.
    """
    hwLabels = []
    # Load the training data: one file per row of trainingMat.
    trainingFileList = listdir('trainingDigits')
    m = len(trainingFileList)
    trainingMat = zeros((m, 1024))
    for i, fileNameStr in enumerate(trainingFileList):
        fileStr = fileNameStr.split('.')[0]  # drop the extension
        classNumStr = int(fileStr.split('_')[0])  # digit encoded in the name
        hwLabels.append(classNumStr)
        trainingMat[i, :] = img2vector('trainingDigits/%s' % fileNameStr)

    # Classify every test file and count the mismatches.
    testFileList = listdir('testDigits')
    errorCount = 0.0
    mTest = len(testFileList)
    for fileNameStr in testFileList:
        fileStr = fileNameStr.split('.')[0]  # drop the extension
        classNumStr = int(fileStr.split('_')[0])
        vectorUnderTest = img2vector('testDigits/%s' % fileNameStr)
        classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)
        # Single-argument print(...) emits the same text on Python 2 and 3.
        print("the classifier came back with: %d, the real answer is: %d, The predict result is: %s" % (classifierResult, classNumStr, classifierResult == classNumStr))
        if classifierResult != classNumStr:
            errorCount += 1.0
    print("\nthe total number of errors is: %d / %d" % (errorCount, mTest))
    print("\nthe total error rate is: %f" % (errorCount / float(mTest)))
# Run the handwritten-digit evaluation only when this file is executed
# directly as a script.
if __name__ == "__main__":
    handwritingClassTest()
#coding=UTF8
from numpy import *
import operator
from os import listdir
def classify0(inX, dataset, labels, k):
    """Classify ``inX`` by majority vote among its k nearest training rows.

    Distances are Euclidean, computed between ``inX`` and every row of
    ``dataset``.

    Args:
        inX: The sample to classify; a sequence or array that broadcasts
            against one row of ``dataset`` (shape ``(n,)`` or ``(1, n)``).
        dataset: 2-D array-like of training samples, one sample per row.
        labels: Sequence indexed by row number giving the class of each row
            of ``dataset``.
        k: Number of nearest neighbours that take part in the vote.

    Returns:
        The label with the most votes among the k nearest rows. When several
        labels tie, the one belonging to the closest neighbour wins, so the
        result does not depend on dictionary ordering.

    Raises:
        IndexError: If ``k`` is smaller than 1 or larger than the number of
            rows in ``dataset`` (which includes an empty ``dataset``).
    """
    dataset = asarray(dataset)
    dataSetSize = dataset.shape[0]  # number of training rows
    if not 1 <= k <= dataSetSize:
        # The original indexing failed with IndexError in these cases too;
        # the type is kept so existing callers see the same exception.
        raise IndexError("k must be between 1 and the number of training rows")

    # Broadcasting subtracts inX from every row, replacing the explicit tile().
    diffMat = asarray(inX) - dataset
    distance = (diffMat ** 2).sum(axis=1) ** 0.5
    sortedDistIndicies = distance.argsort()  # row indices, nearest first

    nearestLabels = [labels[idx] for idx in sortedDistIndicies[:k]]
    classCount = {}
    for voteIlabel in nearestLabels:
        classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1

    # Pick the winner with an explicit scan instead of dict.iteritems()
    # (Python 2 only). A bare max() is avoided on purpose: the file's
    # `from numpy import *` may shadow the builtin.
    winner = nearestLabels[0]
    for voteIlabel in nearestLabels[1:]:
        # Strict '>' keeps the earlier (closer) label when counts are equal.
        if classCount[voteIlabel] > classCount[winner]:
            winner = voteIlabel
    return winner
def img2vector(filename):
    """Read a 32x32 text image of digits into a 1x1024 NumPy array.

    The first 32 characters of each of the first 32 lines are converted with
    ``int`` and stored row by row, so character ``j`` of line ``i`` lands in
    column ``32 * i + j``.

    Args:
        filename: Path of the text file to read.

    Returns:
        A NumPy array of shape ``(1, 1024)`` holding the parsed values.

    Raises:
        IndexError: If one of the first 32 lines has fewer than 32 characters.
        ValueError: If a character cannot be converted with ``int``.
    """
    returnVect = zeros((1, 1024))
    # The context manager closes the file even if parsing fails; the
    # original left the handle open.
    with open(filename) as fr:
        for i in range(32):
            lineStr = fr.readline()
            for j in range(32):
                returnVect[0, 32 * i + j] = int(lineStr[j])
    return returnVect
def handwritingClassTest():
    """Evaluate the kNN classifier on the handwritten-digit text files.

    Every file in ``trainingDigits`` is loaded with ``img2vector`` into one
    row of the training matrix. Each file in ``testDigits`` is then
    classified with ``classify0`` using k=3, and the prediction is compared
    with the expected digit. The expected digit is the integer before the
    first ``_`` in the file name.

    Prints one line per test file, followed by the total error count and the
    error rate. Returns nothing.

    Raises:
        ZeroDivisionError: If ``testDigits`` contains no files, because the
            error rate divides by the number of test files.
    """
    hwLabels = []
    # Load the training data: one file per row of trainingMat.
    trainingFileList = listdir('trainingDigits')
    m = len(trainingFileList)
    trainingMat = zeros((m, 1024))
    for i, fileNameStr in enumerate(trainingFileList):
        fileStr = fileNameStr.split('.')[0]  # drop the extension
        classNumStr = int(fileStr.split('_')[0])  # digit encoded in the name
        hwLabels.append(classNumStr)
        trainingMat[i, :] = img2vector('trainingDigits/%s' % fileNameStr)

    # Classify every test file and count the mismatches.
    testFileList = listdir('testDigits')
    errorCount = 0.0
    mTest = len(testFileList)
    for fileNameStr in testFileList:
        fileStr = fileNameStr.split('.')[0]  # drop the extension
        classNumStr = int(fileStr.split('_')[0])
        vectorUnderTest = img2vector('testDigits/%s' % fileNameStr)
        classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)
        # Single-argument print(...) emits the same text on Python 2 and 3.
        print("the classifier came back with: %d, the real answer is: %d, The predict result is: %s" % (classifierResult, classNumStr, classifierResult == classNumStr))
        if classifierResult != classNumStr:
            errorCount += 1.0
    print("\nthe total number of errors is: %d / %d" % (errorCount, mTest))
    print("\nthe total error rate is: %f" % (errorCount / float(mTest)))
# Run the handwritten-digit evaluation only when this file is executed
# directly as a script.
if __name__ == "__main__":
    handwritingClassTest()