"""k-nearest-neighbours (kNN) examples from "Machine Learning in Action".

Three demos share the ``classify0`` kNN classifier:

* a four-point toy data set (``createDataSet``);
* dating-site match classification (``file2matrix``, ``Normalize``,
  ``datingClassTest``, ``classifyPerson``, ``plotDatingData``), which by
  default reads ``datingTestSet2.txt`` from the working directory;
* handwritten-digit recognition (``img2vector``, ``handWriteClassTest``),
  which by default reads the ``trainingDigits`` and ``testDigits``
  directories from the working directory.

Example::

    >>> group, labels = createDataSet()
    >>> classify0([0, 0], group, labels, 3)
    'B'
"""
import operator
import os
from collections import Counter

import numpy
from numpy import array
from numpy import tile

try:
    import matplotlib.pyplot as plt
except ImportError:  # plotting is optional; only plotDatingData() needs it
    plt = None


# Toy example data
def createDataSet():
    """Return a tiny 2-D data set and its labels for trying out classify0.

    Returns:
        group: (4, 2) float array of sample points.
        labels: list of class labels, one per row of ``group``.
    """
    group = array([[1.0, 1.1], [1.0, 1.0], [0, 0], [0, 0.1]])
    labels = ['A', 'A', 'B', 'B']  # one label per point, in the same order
    return group, labels


# ****************************** Core classifier ******************************

def classify0(point, dataArray, labels, k):
    """Classify ``point`` by majority vote among its ``k`` nearest neighbours.

    Args:
        point: the sample to classify; anything that broadcasts against a
            single row of ``dataArray`` (a 1-D sequence or a (1, n) array).
        dataArray: 2-D array of reference samples, one per row.
        labels: class labels aligned with the rows of ``dataArray``.
        k: number of neighbours that vote. Values larger than the number of
            reference samples are clamped so every sample votes.

    Returns:
        The most frequent label among the ``k`` nearest rows by Euclidean
        distance. Ties go to the label whose first vote came from the nearer
        neighbour.

    Raises:
        ValueError: if ``k`` < 1 or ``dataArray`` has no rows.

    >>> group, labels = createDataSet()
    >>> classify0([0, 0], group, labels, 3)
    'B'
    """
    if k < 1:
        raise ValueError('k must be at least 1')
    data = numpy.asarray(dataArray, dtype=float)
    if data.shape[0] == 0:
        raise ValueError('dataArray must contain at least one sample')
    # Squared distances rank neighbours exactly like true Euclidean distances,
    # so the square root is unnecessary. Broadcasting replaces tile().
    sqDistances = ((numpy.asarray(point, dtype=float) - data) ** 2).sum(axis=1)
    # Stable sort: equal distances keep dataset order, so results are
    # deterministic. Slicing clamps k to the number of samples.
    nearest = sqDistances.argsort(kind='stable')[:k]
    votes = Counter(labels[i] for i in nearest)
    # most_common() lists equal counts in first-encountered (i.e. nearest) order.
    return votes.most_common(1)[0][0]


# ******************************************************************************
# ************************ Dating-site match classification *******************
# ******************************************************************************

def file2matrix(filename, num_features=3):
    """Load a tab-separated data file into a feature matrix and a label list.

    On every non-blank line the first ``num_features`` columns are read as
    floats and the last column as an integer class label.

    Args:
        filename: path of the data file.
        num_features: number of leading columns to read as features.

    Returns:
        returnMat: (n_rows, num_features) float array.
        labels: list of int labels, one per row of ``returnMat``.
    """
    rows = []
    labels = []
    with open(filename) as fh:
        for line in fh:
            line = line.strip()  # drop the newline and surrounding whitespace
            if not line:
                continue  # tolerate blank lines, e.g. extra newlines at EOF
            fields = line.split('\t')
            rows.append([float(v) for v in fields[:num_features]])
            labels.append(int(fields[-1]))
    returnMat = numpy.zeros((len(rows), num_features))
    for index, row in enumerate(rows):
        # Row assignment raises if a line has too few feature columns.
        returnMat[index, :] = row
    return returnMat, labels


def Normalize(dataMat):
    """Min-max scale every column of ``dataMat`` into [0, 1].

    Args:
        dataMat: 2-D numeric array, one sample per row.

    Returns:
        normMat: the scaled copy of ``dataMat``.
        range_value: per-column divisor, ``max - min``. A constant column gets
            1 so that it maps to 0 instead of NaN.
        min_value: per-column minimum.

    New samples ``x`` are scaled consistently via ``(x - min_value) / range_value``.
    """
    min_value = dataMat.min(0)
    max_value = dataMat.max(0)
    range_value = max_value - min_value
    range_value = numpy.where(range_value == 0, 1, range_value)
    # Broadcasting applies the per-column vectors to every row.
    normMat = (dataMat - min_value) / range_value
    return normMat, range_value, min_value


def datingClassTest(filename='datingTestSet2.txt', test_ratio=0.1, k=3):
    """Measure the kNN error rate on the dating data with a hold-out split.

    After normalisation, the first ``int(n_rows * test_ratio)`` rows are
    classified against the remaining rows. The error rate is printed and
    returned.

    Args:
        filename: dating data file in the format read by ``file2matrix``.
        test_ratio: fraction of rows held out for testing (default 10%).
        k: number of neighbours used by ``classify0``.

    Returns:
        The fraction of held-out rows that were misclassified.

    Raises:
        ValueError: if the hold-out set would be empty.
    """
    datingData, datingLabels = file2matrix(filename)
    datingData, _, _ = Normalize(datingData)
    testnum = int(datingData.shape[0] * test_ratio)
    if testnum == 0:
        raise ValueError('data set too small to hold out any test rows')
    trainData = datingData[testnum:]
    trainLabels = datingLabels[testnum:]
    error_count = 0
    for i in range(testnum):
        if classify0(datingData[i], trainData, trainLabels, k) != datingLabels[i]:
            error_count += 1
    error_rate = error_count / float(testnum)
    print('错误率:%f' % error_rate)
    return error_rate


def plotDatingData(filename='datingTestSet2.txt'):
    """Scatter-plot normalised game-time % against ice-cream litres per week.

    Marker size and colour are both proportional to the class label.

    Raises:
        ImportError: if matplotlib is not installed.
    """
    if plt is None:
        raise ImportError('matplotlib is required for plotDatingData()')
    datingDataMat, datingLabels = file2matrix(filename)
    datingDataMat, _, _ = Normalize(datingDataMat)
    scale = 15.0 * array(datingLabels)
    fig = plt.figure()
    ax = fig.add_subplot(111)
    # Columns 1 and 2 are the game-time and ice-cream features, matching the
    # axis labels below (column 0 is flight miles).
    ax.scatter(datingDataMat[:, 1], datingDataMat[:, 2], scale, scale)
    plt.xlabel('玩视频耗时百分比')
    plt.ylabel('周消耗冰激凌公升数')
    plt.show()


def classifyPerson(filename='datingTestSet2.txt'):
    """Interactively classify one person from three typed-in features.

    Prompts for annual flight miles, the percentage of time spent playing
    games, and litres of ice cream per week. The answers are scaled with the
    training data's min/range, and the predicted category name is printed.

    Args:
        filename: dating data file used as the kNN reference set.
    """
    datingData, datingLabels = file2matrix(filename)
    datingData, datingRange, datingMinValue = Normalize(datingData)
    resultClass = ['不喜欢', '一般', '有魅力']
    miles = float(input('每年飞行里程数:'))
    game = float(input('玩游戏消耗百分比:'))
    ice = float(input('每周冰淇淋公升:'))
    data = (array([miles, game, ice]) - datingMinValue) / datingRange
    label = classify0(data, datingData, datingLabels, 3)
    print('类型是:', resultClass[label - 1])  # class labels in the file are 1, 2, 3


# ******************************************************************************
# ************************** Handwritten digit recognition ********************
# ******************************************************************************

def img2vector(filename):
    """Read a 32x32 text image of '0'/'1' characters into a (1, 1024) vector.

    Row ``i`` and column ``j`` of the image land at index ``32 * i + j``.
    """
    returnVec = numpy.zeros((1, 1024))
    with open(filename) as fh:
        for i in range(32):
            line = fh.readline()
            for j in range(32):
                returnVec[0, 32 * i + j] = int(line[j])
    return returnVec


def _loadDigitDir(directory):
    """Load every file in ``directory`` as a digit image plus its label.

    Returns:
        An (n_files, 1024) float array and a list of int labels. Each label is
        the first character of the corresponding file name.
    """
    names = sorted(os.listdir(directory))  # sorted for a deterministic row order
    images = numpy.zeros((len(names), 1024))
    labels = []
    for row, name in enumerate(names):
        labels.append(int(name[0]))
        images[row, :] = img2vector(os.path.join(directory, name))
    return images, labels


def handWriteClassTest(train_dir='trainingDigits', test_dir='testDigits', k=3):
    """Measure the kNN error rate on handwritten digit images.

    Every image in ``test_dir`` is classified against all images in
    ``train_dir``. The error rate is printed and returned.

    Args:
        train_dir: directory of reference images.
        test_dir: directory of images to classify. It must be separate from
            ``train_dir``; otherwise each test image finds itself at distance
            0 and the error rate is meaningless.
        k: number of neighbours used by ``classify0``.

    Returns:
        The fraction of test images that were misclassified.

    Raises:
        ValueError: if ``test_dir`` contains no files.
    """
    trainArray, trainLabels = _loadDigitDir(train_dir)
    testArray, testLabels = _loadDigitDir(test_dir)
    if not testLabels:
        raise ValueError('no test images found in %r' % test_dir)
    error_count = 0
    for vector, label in zip(testArray, testLabels):
        if classify0(vector, trainArray, trainLabels, k) != label:
            error_count += 1
    error_rate = error_count / len(testLabels)
    print('错误率:%f' % error_rate)
    return error_rate
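# 机器学习实战-KNN算法 (Machine Learning in Action: the kNN algorithm)
# 最新推荐文章于 2024-03-23 16:40:14 发布 (source-page footer: "latest recommended article published 2024-03-23 16:40:14")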