在有标签的数据中,输入没有标签的数据后,通过计算数据特征与样本数据进行比较,算法提取样本集中特征最相似的分类标签。一般取前K个最相似的数,这就是k-近邻算法。从K近邻算法、距离度量谈到KD树、SIFT+BBF算法
实验基础
python/numpy中会用到的函数:
shape()
shape是numpy函数库中的方法,用于查看矩阵或者数组的维素
>>>shape(array) 若矩阵有m行n列,则返回(m,n)
>>>array.shape[0] 返回矩阵的行数m,参数为1的话返回列数n
tile()
tile是numpy函数库中的方法,用法如下:
>>>tile(A,(m,n)) 将数组A作为元素构造出m行n列的数组
sum()
sum()是numpy函数库中的方法
>>>array.sum(axis=1)按行累加,axis=0为按列累加
argsort()
argsort()是numpy中的方法,得到矩阵中每个元素的排序序号
>>>A=array.argsort() A[0]表示排序后 排在第一个的那个数在原来数组中的下标
dict.get(key,x)
python中字典的方法,get(key,x)从字典中获取key对应的value,字典中没有key的话返回0
sorted()
python中的方法
min()、max()
numpy中有min()、max()方法,用法如下
>>>array.min(0) 返回一个数组,数组中每个数都是它所在列的所有数的最小值
>>>array.min(1) 返回一个数组,数组中每个数都是它所在行的所有数的最小值
listdir('str')
python的operator中的方法
>>>strlist=listdir('str') 读取目录str下的所有文件名,返回一个字符串列表
split()
python中的方法,切片函数
>>>string.split('str')以字符str为分隔符切片,返回list
k-近邻算法的一般流程
1. 收集数据
2. 准备数据
3. 分析数据
4. 训练算法
5. 测试算法
6. 使用算法:k-近邻 手写识别数字应用
根据流程,首先我们要导入数据,然后对数据分析转换成需要的格式,完成k-近邻分类器的设计,再跑一下数据。
导入数据
# coding=utf-8
from numpy import *
import operator
#导入数据
def createDataSet():
group = array([[1.0, 1.1], [1.0, 1.0], [0, 0], [0, 0.1]])
labels = ['A', 'A', 'B', 'B']
return group, labels
k-近邻分类器
#k-近邻算法
def classify0(inX, dataSet, labels, k):
dataSetSize = dataSet.shape[0]
diffMat = tile(inX,(dataSetSize,1)) - dataSet
sqDiffMat = diffMat**2
sqDistances = sqDiffMat.sum(axis = 1)
distances = sqDistances**0.5
sortedDistIndices = distances.argsort()
classCount = {}
for i in range(k):
voteIlabel = labels[sortedDistIndices[i]]
classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
return sortedClassCount[0][0]
使用k-近邻改进约会网站的配对效果
# 使用k-近邻改进约会网站的配对效果
#将文本记录转换到numpy
def file2matrix(filename):
fr = open(filename)
arrayQlines = fr.readlines()
numberOfLines = len(arrayQlines)
returnMat = zeros((numberOfLines, 3))
classLabelVector = []
index = 0
for line in arrayQlines:
line = line.strip()
listFromLine = line.split('\t')
returnMat[index,:] = listFromLine[0:3]
classLabelVector.append(int(listFromLine[-1]))
index += 1
return returnMat,classLabelVector
# 分析数据: 使用matplotlib 创建散点图 数据可视化
# import matplotlib
# import matplotlib.pyplot as plt
# fig = plt.figure()
# ax = fig.add_subplot(111)
# ax.scatter(datingDataMat[:,1], datingDataMat[:,2])
# plt.show()
#归一化数值
def autoNorm(dataSet):
minVals = dataSet.min(0)
maxVals = dataSet.max(0)
ranges = maxVals - minVals
normDataSet = zeros(shape(dataSet))
m = dataSet.shape[0]
normDataSet = dataSet - tile(minVals, (m,1))
normDataSet = normDataSet/tile(ranges, (m,1))
return normDataSet, ranges, minVals
#分类器针对约会网站测试代码
def datingClassTest():
hoRation = 0.10
datingDataMat,datingLabels = file2matrix('datingTestSet2.txt')
normMat, ranges, minVals = autoNorm(datingDataMat)
m = normMat.shape[0]
numTestVecs = int(m*hoRation)
errorCount = 0.0
for i in range(numTestVecs):
classifierResult = classify0(normMat[i,:],normMat[numTestVecs:m,:], datingLabels[numTestVecs:m],3)
print "the classifier came back with: %d, the real answer is : %d"%(classifierResult,datingLabels[i])
if (classifierResult != datingLabels[i]): errorCount += 1.0
print "the total error rate is : %f" %(errorCount/float(numTestVecs))
#约会完站预测函数
def classifyPerson():
resultList = ['not at all', 'in small doses', 'in large doses']
percentTats = float(raw_input("percentage of time spent playing video games? "))
ffMiles = float(raw_input("frequent flier miles earned per year?"))
iceCream = float(raw_input("liters of ice cream consumed per years?"))
datingDataMat,datingLabels = file2matrix('datingTestSet2.txt')
normMat, ranges, minVals = autoNorm(datingDataMat)
inArr = array([ffMiles, percentTats, iceCream])
classifierResult = classify0((inArr - minVals)/ranges, normMat, datingLabels,3)
print "you will probably like this person: ", resultList[classifierResult - 1]
k-近邻 手写识别系统
这里除了使用之前编写的k-近邻分类器,还提供了sklearn版。详细如下。
# 手写识别系统
def img2vector(filename):
returnVect = zeros((1,1024))
fr = open(filename)
for i in range(32):
lineStr = fr.readline()
for j in range(32):
returnVect[0,32*i + j] = int(lineStr[j])
return returnVect
# 测试算法: 使用k-近邻算法识别手写数字
from os import listdir
def handwritingClassTest():
hwLabels = []
trainingFileList = listdir('trainingDigits')
m = len(trainingFileList)
trainingMat = zeros((m, 1024))
for i in range(m):
fileNameStr = trainingFileList[i]
fileStr = fileNameStr.split('.')[0]
classNumStr = int(fileStr.split('_')[0])
hwLabels.append(classNumStr)
trainingMat[i,:] = img2vector('trainingDigits/%s'%fileNameStr)
testFileList = listdir('testDigits')
errorCount = 0.0
mTest = len(testFileList)
for i in range(mTest):
fileNameStr = testFileList[i]
fileStr = fileNameStr.split('.')[0]
classNumStr = int(fileStr.split('_')[0])
vectorUnderTest = img2vector('testDigits/%s'%fileNameStr)
classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)
print "the classifier came back with %d, the real answer is :%d "%(classifierResult, classNumStr)
if (classifierResult != classNumStr) : errorCount += 1.0
print "\n the total number of errrors is :%d"%errorCount
print "\n the total error rate is : %f"%(errorCount/float(mTest))
# 手写识别sklearn版本
from sklearn.neighbors import KNeighborsClassifier
def sklearnKnnTest():
hwLabels = []
trainingFileList = listdir('trainingDigits')
m = len(trainingFileList)
trainingMat = zeros((m, 1024))
for i in range(m):
fileNameStr = trainingFileList[i]
fileStr = fileNameStr.split('.')[0]
classNumStr = int(fileStr.split('_')[0])
hwLabels.append(classNumStr)
trainingMat[i,:] = img2vector('trainingDigits/%s'%fileNameStr)
neigh = KNeighborsClassifier(n_neighbors = 3) #k=3
neigh.fit(trainingMat, hwLabels) #数据拟合
testFileList = listdir('testDigits')
errorCount = 0.0
mTest = len(testFileList)
for i in range(mTest):
fileNameStr = testFileList[i]
fileStr = fileNameStr.split('.')[0]
classNumStr = int(fileStr.split('_')[0])
vectorUnderTest = img2vector('testDigits/%s'%fileNameStr)
classifierResult = neigh.predict(vectorUnderTest) #数据预测
print "the classifier came back with %d, the real answer is :%d "%(classifierResult, classNumStr)
if (classifierResult != classNumStr) : errorCount += 1.0
print "\n the total number of errrors is :%d"%errorCount
print "\n the total error rate is : %f"%(errorCount/float(mTest))