给定训练数据样本和标签,对于某测试的一个样本数据,选择距离其最近的k个训练样本,这k个训练样本中所属类别最多的类即为该测试样本的预测标签。简称kNN。通常k是不大于20的整数,这里的距离一般是欧式距离。
K最近邻(k-Nearest Neighbour,KNN)分类算法,是一个理论上比较成熟的方法,也是最简单的机器学习算法之一。该方法的思路是:如果一个样本在特征空间中的k个最相似(即特征空间中最邻近)的样本中的大多数属于某一个类别,则该样本也属于这个类别。
用官方的话来说,所谓K近邻算法,即是给定一个训练数据集,对新的输入实例,在训练数据集中找到与该实例最邻近的K个实例(也就是上面所说的K个邻居), 这K个实例的多数属于某个类,就把该输入实例分类到这个类中。
下面是机器学习实战书中一些代码的实现:
其中包含使用k-邻近算法改进约会网站配对效果代码和手写识别系统的代码:
#coding=UTF8
from numpy import *
import matplotlib
import matplotlib.pyplot as plt
import operator
from os import listdir
def createDataSet():
group = array([[3,104],[2,100],[1,81],[101,10],[99,5]]) #训练集
labels = ['affectional film','affectional film','affectional film','action movie',"action movie"]
return group,labels
def classify0(inX,dataSet,labels,k):
dataSetSize = dataSet.shape[0]
diffMat = tile(inX, (dataSetSize,1)) - dataSet
sqDiffMat = diffMat ** 2; #diffMat ^ 2
sqDistances = sqDiffMat.sum(axis=1) #将矩阵的每一行相加比如[[2,1,3],[1,1,1]]结果为[6,3]
distances = sqDistances ** 0.5 #sqDistances ^ (1/2)
sortedDistIndicies = distances.argsort()
classCount = {}
for i in range(k):
voteIlabel = labels[sortedDistIndicies[i]]
classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1
sortedClassCount = sorted(classCount.iteritems(),key=operator.itemgetter(1),reverse=True) #获取对象的第1个域的值
return sortedClassCount[0][0]
def file2matrix(filename):
fr = open(filename)
numberOfLines = len(fr.readlines()) #get the number of lines in the file
returnMat = zeros((numberOfLines,3)) #prepare matrix to return
classLabelVector = [] #prepare labels return
fr = open(filename)
index = 0
for line in fr.readlines():
line = line.strip()
listFromLine = line.split('\t')
returnMat[index,:] = listFromLine[0:3]
classLabelVector.append(int(listFromLine[-1]))
index += 1
return returnMat,classLabelVector
def autoNorm(dataSet):
minVals = dataSet.min(0) #参数0从当前列选取最小值
maxVals = dataSet.max(0) #同上
ranges = maxVals - minVals
normDataSet = zeros(shape(dataSet)) #创建规模为dataSet的零矩阵
m = dataSet.shape[0] #行
normDataSet = dataSet - tile(minVals, (m,1)) #m行1列的minVals
normDataSet = normDataSet / tile(ranges, (m,1))
return normDataSet,ranges,minVals
def datingClassTest():
hoRatio = 0.10
datingDataMat,datingLabels = file2matrix('datingTestSet2.txt')
normMat,ranges,minVals = autoNorm(datingDataMat)
m = normMat.shape[0] #行
numTestVecs = int(m * hoRatio) #算出测试数据
errorCount = 0.0
for i in range(numTestVecs):
classifierResult = classify0(normMat[i,:], normMat[numTestVecs:m,:], datingLabels[numTestVecs:m], 3) #@param1:读入每行的数据,@param2:样本数据,因为测试数据是0-numTestVecs
print "the classifier came back with: %d,the real answer is : %d" % (classifierResult,datingLabels[i])
if(classifierResult != datingLabels[i]):errorCount += 1.0
print "the total error rate is: %f" % (errorCount / float(numTestVecs))
def classifyPerson():
percentTats = float(raw_input("输入玩视频游戏所消耗时间的百分比:"))
ffMiles = float(raw_input("输入每年获得的飞行常客里程数:"))
iceCream = float(raw_input("输入每周消费的冰淇淋公升数:"))
datingDataMat,datingLabels = file2matrix('datingTestSet2.txt')
normMat,ranges,minVals = autoNorm(datingDataMat)
inArr = array([ffMiles,percentTats,iceCream])
classifierResult = classify0((inArr-minVals)/ranges, normMat ,datingLabels , 3)
print(classifierResult)
temp = classifierResult - 1;
if temp == 0:
print("一点都不喜欢这个人")
elif temp == 1:
print("一般般")
else:
print("非常喜欢")
def img2vector(filename):
returnVect = zeros((1,1024))
fr = open(filename)
for i in range(32):
lineStr = fr.readline()
for j in range(32):
returnVect[0,32*i+j] = int(lineStr[j])
return returnVect #returnVect为1*1024的数组
def handwritingClassTest():
hwLabels = []
trainingFileList = listdir("trainingDigits")
m = len(trainingFileList)
trainingMat = zeros((m,1024))
for i in range(m): #这是训练集
fileNameStr = trainingFileList[i]
fileStr = fileNameStr.split('.')[0]
classNumStr = int(fileStr.split('_')[0])
hwLabels.append(classNumStr)
trainingMat[i,:] = img2vector('trainingDigits/%s' % fileNameStr)
testFileList = listdir('testDigits')
errorCount = 0.0
mTest = len(testFileList)
for i in range(mTest): #这是测试集
fileNameStr = testFileList[i]
fileStr = fileNameStr.split('.')[0]
classNumStr = int(fileStr.split('_')[0])
vectorUnderTest = img2vector('testDigits/%s' % fileNameStr)
classifierResult = classify0(vectorUnderTest,trainingMat,hwLabels,3)
print "the classifier came back with: %d,the real answer is : %d" % (classifierResult,classNumStr)
if(classifierResult != classNumStr) : errorCount += 1.0
print "\nthe total number of errors is: %d" % errorCount
print "\nthe total error rate is:%f" % (errorCount / float(mTest))
'''
datingDataMat,datingLabels = file2matrix('datingTestSet2.txt')
print(datingDataMat)
print(datingLabels[0:20])
normMat,ranges,minVals = autoNorm(datingDataMat)
print('\n')
print(normMat)
print(ranges)
print(minVals)
fig = plt.figure()
ax = fig.add_subplot(111)
ax.scatter(datingDataMat[:,0],datingDataMat[:,1],15.0*array(datingLabels),15.0 * array(datingLabels))
plt.show()
fig1 = plt.figure()
ax1 = fig1.add_subplot(111)
ax1.scatter(normMat[:,0],datingDataMat[:,1],15.0*array(datingLabels),15.0 * array(datingLabels))
plt.show()
datingClassTest()
'''
#classifyPerson()
#testVector = img2vector('0_13.txt')
#print testVector[0,0:31]
#print testVector[0,32:63]
handwritingClassTest()