一开始看问题还是蛮多的,看python的基础语法看完就忘了,只有到真正需要用的时候才会了解到相应的用法。
下面是带注释的代码。
from numpy import *
import operator
from os import listdir
def classify0(inX, dataSet, labels, k):#inx需要进行分类的样本向量
dataSetSize = dataSet.shape[0]#获得dataset的行数,也就是训练样本的数量
diffMat = tile(inX, (dataSetSize,1)) - dataSet#将inx复制多列,列数和样本数一样,然后和样本相减
sqDiffMat = diffMat**2#将结果平方 ps:将结果数组中每个元素平方
sqDistances = sqDiffMat.sum(axis=1)#axis=1的意思是按行求和,就是将每一行所有元素相加
distances = sqDistances**0.5#和开方
sortedDistIndicies = distances.argsort() #对矩阵进行排序,将行号或者说索引按递增序列返回,注意变成了一维数组
classCount={}
for i in range(k):
voteIlabel = labels[sortedDistIndicies[i]]#距离最近的第i个点的标签(类别)
classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1#对应的类别的数量加一(字典中没有此类别就返回0)
sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)#根据value值进行降序排序
return sortedClassCount[0][0]#返回value值最大的类
- zeros的用法
- strip函数的用法
- readlines的用法
- 注意切片是前切后不切
- 循环之前记得再把文件读一遍
- 教你一个小技巧,在控制台导入模块的时候,用os包去更改工作目录或者工作文件夹。具体见此
def file2matrix(filename):
fr = open(filename)
numberoflines = len(fr.readlines())
returnMat = zeros([numberoflines, 3])
classlabelvec = []
fr = open(filename)#这行很重要,一开始以为是多余的就截掉了,后来发现截掉以后循环不执行
index = 0
for line in fr.readlines():
line = line.strip()#截取掉所有的回车符
listfromline = line.split('\t')
returnMat[index, :] = listfromline[0:3]
classlabelvec.append(int(listfromline[-1]))#注意类型转换,所谓的向量其实就是个列表
index += 1
return returnMat, classlabelvec
直接贴一下代码,拉长一下篇幅。
def autonorm(dataset):
minval = dataset.min(0)#使得函数可以从列中选取最小值,而不是当前行的最小值
maxval = dataset.max(0)#
ranges = maxval - minval
m = dataset.shape[0]
normdataset = zeros(shape(dataset))
normdataset = dataset - tile(minval, (m, 1))
normdataset = normdataset/tile(ranges, (m, 1))
return normdataset, ranges
def datingClassTest():
hoRatio = 0.50 #hold out 10%
datingDataMat,datingLabels = file2matrix('datingTestSet.txt') #load data setfrom file
normMat, ranges, minVals = autoNorm(datingDataMat)
m = normMat.shape[0]
numTestVecs = int(m*hoRatio)
errorCount = 0.0
for i in range(numTestVecs):
classifierResult = classify0(normMat[i,:],normMat[numTestVecs:m,:],datingLabels[numTestVecs:m],3)
print "the classifier came back with: %d, the real answer is: %d" % (classifierResult, datingLabels[i])#这里要改成%s,%s源代码有点问题
if (classifierResult != datingLabels[i]): errorCount += 1.0
print "the total error rate is: %f" % (errorCount/float(numTestVecs))
print errorCount