k近邻算法(kNN)是监督学习的一种。其原理非常简单:存在一个样本数据集,也称作训练样本集。样本集中的每个数据都存在标签,即知道数据与对应分类的关系。输入新的没有标签的数据,将新的数据的每个特征与样本中的数据特征进行对比,然后利用算法提取出样本集中特征最相似的数据(最邻近)分类标签。一般来说我们只选取样本集中前k个最相似的数据。
k近邻算法一般流程:
1.选择一种距离计算方式,通过数据所有的特征计算新的数据与已知数据集中的距离
2.按照距离递增的顺序进行排序,选取与当前距离最近的k个点
3.返回k个点出现频率最多的类作为预测类。如果是回归问题,则需要加上权值。
需要注意的是kNN是不需要训练的。
kNN算法的关键
1.k的选择
举一个简单的例子,绿色圆表示要被赋值的点,是红色三角形还是蓝色四边形?如果k为3,则红色三角形出现的比例是2/3,所以预测结果为红色三角形。若k=5,则蓝色四边形出现的比例为3/5,所以预测结果为蓝色四边形。
因此k的选择对于结果的影响非常大。如果k值太小,受到噪声干扰比较明显,容易受到出现过拟合现象。而k值过大则会导致其分界不明显。一种选择K值得方法是使用 cross-validate(交叉验证)误差统计选择法。也就是将数据样本的一部分作为训练样本,一部分作为测试样本。一般来说选择90%的作为训练数据集,剩下的作为测试集。选择不同的k值计算误差,最后选出误差最小的k值。
2.需要对数据所有特征做可比较量化
如果数据中存在非数值类型,必须对其进行数值化。例如样本中包含颜色(红绿蓝)。颜色本身时没有距离的,但是我们可以将其转换成灰度值在进行计算。另外,样本有多个参数(特征),每个参数都有自己的定义域和取值范围。它们对于距离计算的影响也就不一样。为了公平起见,我们必须将特征进行归一化。
下面是python的实现代码
def createDataSet():
group = array([[1.0,1.1],[1.0,1.0],[0,0],[0,0.1]])
labels = ['A','A','B','B']
return group, labels
创建数据和对应的标签
def classify0(inX, dataSet, labels, k):
dataSetSize=dataSet.shape[0]#返回dataset的第一维的长度
diffMat = tile(inX, (dataSetSize,1)) - dataSet#计算各个点到原点的x轴,y轴的距离。
#计算出各点离原点的距离
#表示diffMat的平方
sqDiffMat = diffMat**2#平方只针对数组有效
sqDistances=sqDiffMat.sum(axis = 1)
distances=sqDistances**0.5
sortedDistIndices = distances.argsort()#返回从小到大的引索
classCount = {}
for i in range(k):
voteLabel = labels[sortedDistIndices[i]]#找到对应的从小到大的标签
classCount[voteLabel] = classCount.get(voteLabel,0)+1
print(classCount.get(voteLabel,0)+1)
print(classCount)
sortedClassCount = sorted(classCount.items(),key=operator.itemgetter(1),reverse=True)
return sortedClassCount[0][0]
调用函数
group,labels = createDataSet()
classer=classify0([0,0],group,labels,3)
group表示的是一个4*2的数组,也就是说dataSet是一个4*2的矩阵。dataSet.shape[0]返回的值为4。tile表示将数组进行复制。然后计算出各个点距离原点的欧氏距离。最后开根号排序,选择距离最近的k各点。
第二步:将数据从文本中导入,并将文本记录转换成Numpy的解析程序。
def file2matrix(filename):
fr=open(filename)#打开文件
arrayOLines=fr.readlines()#读取所有行的数据,直到遇到结束符
numberOfLines=len(arrayOLines)
returnMat=zeros((numberOfLines,3))
classLabelVector=[]
index = 0
for lines in arrayOLines:
lines = lines.strip()#截取掉后面的换行符
listFromLine = lines.split('\t')#
returnMat[index,:]=listFromLine[0:3]
classLabelVector.append(int(listFromLine[-1]))
index += 1
return returnMat,classLabelVector
第三部:分析数据,将数据用散点图表示出来
def show(datingDataMat,datingLabels):
fig = plt.figure()
ax = fig.add_subplot(111)
ax.scatter(datingDataMat[:, 1], datingDataMat[:, 2],15.0*array(datingLabels),15.0*array(datingLabels))
plt.show()
最后利用kNN实现手写识别程序的完整代码
from numpy import *
import operator
import matplotlib
import matplotlib.pyplot as plt
from os import listdir
def classify0(inX, dataSet, labels, k):
dataSetSize=dataSet.shape[0]#返回dataset的第一维的长度
print(dataSetSize)
diffMat = tile(inX, (dataSetSize,1)) - dataSet
#计算出各点离原点的距离
#表示diffMat的平方
sqDiffMat = diffMat**2#平方只针对数组有效
sqDistances=sqDiffMat.sum(axis = 1)
distances=sqDistances**0.5
sortedDistIndices = distances.argsort()#返回从小到大的引索
classCount = {}
for i in range(k):
voteLabel = labels[sortedDistIndices[i]]#找到对应的从小到大的标签
classCount[voteLabel] = classCount.get(voteLabel,0)+1
sortedClassCount = sorted(classCount.items(),key=operator.itemgetter(1),reverse=True)
return sortedClassCount[0][0]
def createDataSet():
group=array([[1.0,1.1],[1.0,1.0],[0,0],[0,0.1]])#numpy里面的数组,注意和list的区别
labels=['A','A','B','B']
return group,labels
def file2matrix(filename):
fr=open(filename)
arrayOLines=fr.readlines()
numberOfLines=len(arrayOLines)
print(numberOfLines)
returnMat=zeros((numberOfLines,3))
classLabelVector=[]
index = 0
for lines in arrayOLines:
lines = lines.strip()
listFromLine = lines.split('\t')
returnMat[index,:]=listFromLine[0:3]
classLabelVector.append(int(listFromLine[-1]))
index += 1
return returnMat,classLabelVector
def show(datingDataMat,datingLabels):
fig = plt.figure()
ax = fig.add_subplot(111)
ax.scatter(datingDataMat[:, 1], datingDataMat[:, 2],15.0*array(datingLabels),15.0*array(datingLabels))
plt.show()
def autoNorm(dataSet):#将特征值归一化
minVals=dataSet.min(0)#选择数据集中最小的
maxVals=dataSet.max(0)
ranges = maxVals - minVals
normDataSet=zeros(shape(dataSet))
m = dataSet.shape[0]
normDataSet = dataSet-tile(minVals,(m,1))
normDataSet = normDataSet/tile(ranges,(m,1))
return normDataSet,ranges,minVals
def datingClassTest():
hoRatio = 0.50 # hold out 10%
datingDataMat, datingLabels = file2matrix('datingTestSet2.txt') # load data setfrom file
normMat, ranges, minVals = autoNorm(datingDataMat)
m = normMat.shape[0]
numTestVecs = int(m * hoRatio)
errorCount = 0.0
for i in range(numTestVecs):
classifierResult = classify0(normMat[i, :], normMat[numTestVecs:m, :], datingLabels[numTestVecs:m], 3)
print("the classifier came back with: %d, the real answer is: %d" % (classifierResult, datingLabels[i]))
if (classifierResult != datingLabels[i]):
errorCount += 1.0
print( "the total error rate is: %f" % (errorCount / float(numTestVecs)))
# print(errorCount)
def img2vector(filename):
returnVect = zeros((1, 1024))
fr = open(filename)
for i in range(32):
lineStr = fr.readline()
for j in range(32):
returnVect[0, 32 * i + j] = int(lineStr[j])
return returnVect
def handwritingClassTest():
hwLabels = []
trainingFileList = listdir('trainingDigits') # load the training set
m = len(trainingFileList)
trainingMat = zeros((m, 1024))
for i in range(m):
fileNameStr = trainingFileList[i]
fileStr = fileNameStr.split('.')[0] # take off .txt
classNumStr = int(fileStr.split('_')[0])
hwLabels.append(classNumStr)
trainingMat[i, :] = img2vector('trainingDigits/%s' % fileNameStr)
testFileList = listdir('testDigits') # iterate through the test set
errorCount = 0.0
mTest = len(testFileList)
for i in range(mTest):
fileNameStr = testFileList[i]
fileStr = fileNameStr.split('.')[0] # take off .txt
classNumStr = int(fileStr.split('_')[0])
vectorUnderTest = img2vector('testDigits/%s' % fileNameStr)
classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)
print("the classifier came back with: %d, the real answer is: %d" % (classifierResult, classNumStr))
if (classifierResult != classNumStr): errorCount += 1.0
print("\nthe total number of errors is: %d" % errorCount)
print("\nthe total error rate is: %f" % (errorCount / float(mTest)))
if __name__ == "__main__":
group,labels = createDataSet()
classer=classify0([0,0],group,labels,3)
# handwritingClassTest()
datingDataMat, datingLabels=file2matrix('datingTestSet2.txt')
show(datingDataMat,datingLabels)