算法伪代码:
1、计算输入向量intX距训练集中各点的距离
2、将距离按从小到大排序
3、选取最小的k个值
4、统计其中各类标签数目
5、选取标签出现最多的,作出预测
代码如下:
#intX为输入向量,k为选取最邻近的k个点
def classify0(intX,dataset,labels,k):
diff=dataset-intX #运用了广播机制,使dataset的每一行都减去输入向量
diff2=diff**2
sum1=sum(diff2,axis=1) #按行相加
distance=sum1**0.5
sortedDis=distance.argsort() #argsort函数返回排序索引
dic={}
#选取据输入向量最近的k个点,统计标签数目
for i in range(k):
voteLabel=labels[sortedDis[i]]
dic[voteLabel]=dic.get(voteLabel,0)+1 #字典的get函数:Get an element with a default
sortedDic=sorted(dic.iteritems(),key=operator.itemgetter(1),reverse=True)
return sortedDic[0][0]
sorted函数:排序后原序列顺序不变
sorted(iterable, cmp=None, key=None, reverse=False)
operator.itemgetter(1)通过比较第二个数据成员来排序
示例:使用k近邻算法改进约会网站的配对效果
from numpy import *
import operator
import matplotlib
import matplotlib.pyplot as plt
#intX为输入向量,k为选取最邻近的k个点
def classify0(intX,dataset,labels,k):
diff=dataset-intX #运用了广播机制,使dataset的每一行都减去输入向量
diff2=diff**2 #取平方
sum1=sum(diff2,axis=1) #按行相加
distance=sum1**0.5 #开方
sortedDis=distance.argsort() #argsort函数返回排序索引
dic={} #字典
#选取据输入向量最近的k个点,统计标签数目
for i in range(k):
voteLabel=labels[sortedDis[i]]
dic[voteLabel]=dic.get(voteLabel,0)+1 #字典的get函数:Get an element with a default
sortedDic=sorted(dic.iteritems(),key=operator.itemgetter(1),reverse=True)
return sortedDic[0][0]
def file2matrix(filename):
fr=open(filename)
fileLines=fr.readlines() #将文件的每一行读取后,作为List中的一个元素。每个元素为一个字符串
numberOfLines=len(fileLines) #文件总共行数(List的长度)
returnMat=zeros((numberOfLines,3)) #用0初始化一个m*n的数组(numpy数组)
classLabelVector=[]
index=0
for line in fileLines:
line=line.strip() #Python strip() 方法用于移除字符串头尾指定的字符(默认为空格)。
splitLine=line.split('\t') #split()通过指定分隔符对字符串进行切片,如果参数num 有指定值,则仅分隔 num 个子字符串
#此为用'\t'将字符串分隔成数组(List)形式 形如:['38343', '7.241614', '1.661627', '3\n']
returnMat[index,:]=splitLine[0:3]
classLabelVector.append(int(splitLine[-1]))
index+=1
return returnMat,classLabelVector
def autoNorm(dataset):
maxOfCol=dataset.max(0) #计算列最大值
minOfCol=dataset.min(0)
max_minOfCol=maxOfCol-minOfCol
normDataSet=(dataset-minOfCol)/max_minOfCol
return normDataSet,minOfCol,max_minOfCol
def draw():
datingDataMat,datingLabels=file2matrix('datingTestSet2.txt')
#调用figure创建一个绘图对象,并且使它成为当前的绘图对象
fig = plt.figure()
#绘制的图像在1*1的网格中占第一部分(即占满)
ax = fig.add_subplot(111)
#datingDataMat矩阵中第二列为x,第三列为y
ax.scatter(datingDataMat[:,0], datingDataMat[:,1],15.0*array(datingLabels), 15.0*array(datingLabels))
plt.xlabel("frequent flier miles")
plt.ylabel("percentage of time spent playing video games")
plt.show()
def datingClassTest():
testRatio=0.1
datingDataMat,datingLabels=file2matrix('datingTestSet2.txt')
normDataSet,Min,ranges=autoNorm(datingDataMat)
tNum=int(testRatio*(normDataSet.shape[0]))
trainingSet=normDataSet[tNum:,:]
trainingLabels=datingLabels[tNum:]
error=0.0
k=4
for i in range(tNum):
testResult=classify0(normDataSet[i,:],trainingSet,trainingLabels,k)
print "the test result is %d, the real is %d" %(testResult,datingLabels[i])
if testResult != datingLabels[i]:
error+=1.0
print "error number is %d, error rate is %f" %(error,error/float(tNum))
def classifyPerson():
datingDataMat,datingLabels=file2matrix('datingTestSet2.txt')
normDataSet,Min,ranges=autoNorm(datingDataMat)
testRatio=0.1
tNum=int(testRatio*(normDataSet.shape[0]))
trainingSet=normDataSet[tNum:,:]
trainingLabels=datingLabels[tNum:]
arg1=input("frequent flier miles earned per year?")
arg2=input("percentage of time spent playing video games?")
arg3=input("liters of ice cream consumed per year?")
intX=array([arg1,arg2,arg3])
testResult=classify0((intX-Min)/ranges,trainingSet,trainingLabels,3)
labelTable=['not at all','in small doses','in large doses']
print "You will probably like this person: ",labelTable[testResult-1]
#draw()
#datingClassTest()
classifyPerson()
手写数字识别系统的测试代码
from numpy import *
from os import listdir
def img2vector(filename):
fr=open(filename)
lines=fr.readlines()
mat=zeros((1,1024)) #numpy数组
k=0
for line in lines:
for j in range(32):
mat[0][k]=int(line[j]) #numpy数组应将一维显示地表示成m*n形式
k+=1
return mat
#intX为输入向量,k为选取最邻近的k个点
def classify0(intX,dataset,labels,k):
diff=dataset-intX #运用了广播机制,使dataset的每一行都减去输入向量
diff2=diff**2 #取平方
sum1=sum(diff2,axis=1) #按行相加
distance=sum1**0.5 #开方
sortedDis=distance.argsort() #argsort函数返回排序索引
dic={} #字典
#选取据输入向量最近的k个点,统计标签数目
for i in range(k):
voteLabel=labels[sortedDis[i]]
dic[voteLabel]=dic.get(voteLabel,0)+1 #字典的get函数:Get an element with a default
sortedDic=sorted(dic.iteritems(),key=operator.itemgetter(1),reverse=True)
return sortedDic[0][0]
def handwritingClassTest():
trainFileList=listdir('trainingDigits') #将目录下所有文件的文件名用List表示,每一个文件名为List中一个元素
testFileList=listdir('testDigits')
m=len(trainFileList)
n=len(testFileList)
error=0.0
trainMat=zeros((m,1024))
realResult=[]
for i in range(m):
mat=img2vector('trainingDigits/'+trainFileList[i])
trainMat[i]=mat
realResult.append(int(trainFileList[i][0]))
for i in range(n):
testMat=img2vector('testDigits/'+testFileList[i])
testResult=classify0(testMat,trainMat,realResult,3)
testRealResult=int(testFileList[i][0])
print 'test result is %d, real result is %d' %(testResult,testRealResult)
if testResult != testRealResult:
error+=1
print 'there is %d error, the error rate is %f' %(error,error/float(n))
handwritingClassTest()