工作原理:存在一个样本数据集合,也称为训练样本集,并且样本集中每个数据都存在标签,即我们知道样本集中每一数据与所属分类的对应关系。输入没有标签的新数据后,将新数据的每个特征与样本集中数据对应的特征进行比较,然后算法提取样本集中特征最相似数据(最近邻)的分类标签。一般只选择样本训练集中前k个最相似的数据。这也是k-近邻算法中k的来源
工作步骤:通过计算待预测对象与样本中各个对象的距离,然后排序,选择出k个距离最小的对象,查看这些对象的标签;最后vote,以多者的标签作为待测对象的标签。学者们在kNN研究中的深入和扩展,还提出了许多其他算法,不过这里只按照书本上的学习最基础的算法,后面的遇到再学习。
可修改的参数:k,训练集的大小
kNN的不足:1)计算量太大,耗时耗内存;2)如果样本不平衡,比如一类标签过多,那么vote时将会很大影响待测标签的预测,错误率会提升。
这里的代码来自于Machine Learning In Action书中Chapter 2,我只是做了些注释
简单介绍一下几个函数的功能:
classify0是kNN算法的实现函数,inX是待预测对象,dataSet是样本矩阵,labels是样本标签,k代表选取距离最小的k个样本,距离是欧氏距离;file2matrix是处理文本的函数,将文本格式的样本转换成numpy.matrix格式;autoNorm是将样本矩阵进行标准化,标准化的作用是避免在计算距离时,某个参数因为其数值较大而占较大的权重,比如第一个参数远大于另外两个参数,但不代表它比另外两个参数重要,所以进行标准化,使三个参数的权重相同;datingClassTest函数是将样本分为两部分,一部分作为样本,另一部分作为待测对象,以检查算法的预测值与真实值是否一样,从而计算算法的错误率。
接着对其中遇到的函数进行学习:
1)shape:返回array或matrix对象的维度,比如样本的维度为(1000,4),shape[0]就为1000;
2)tile(A, reps):构造一个以A为基元素,重复reps次数的array,其中reps可以为一个数,也可以为一个array,如
具体构造方法,我认为是从右向左,维数依次递增构造,比如reps=(4,3,2),首先在第一维上重复2次,得到[A,A]=B,然后再在第二维上重复3次,[B,B,B]=C,最后在第三维上重复4次,即[C,C,C,C],结果依次嵌套即可;所以classify0中的意思是构造一个(1000,1)的inX对象,即1000行*1列的inX对象,维数与dataSet相同,所以可以执行减操作;
3)sum(self, axis=None, dtype=None, out=None):通常采用arrayName.sum()这种形式来调用该函数。axis表示计算的轴,默认是None,即计算对象中所有元素值,以a = array([1,2],[3,4])为例,a.sum()为10;a.sum(axis=0)为[4,6];a.sum(axis=1)为[3,7];这里axis=1在二维中表示按行计算,axis=0表示按列计算。dtype和out参数通常不用,分别表示返回的元素类型和输出结果存放的array。
4)argsort(a,axis=None,kind='quicksort',order=None,fill_value=None):numpy.argsort()可以对array进行排序,并返回一个索引array,索引的顺序是最终排序完的顺序,索引值是在原array中的下标位置。如a=mat.array([3,1,2,4]),b=a.argsort(),则b为array([1,2,0,3]),b中的1,2,0,3分别表示a中的下标。此外,其他可选参数axis表示排序以哪个维为关键字排序;kind表示排序方法,有quicksort,mergesort,heapsort三种;order,若有多个关键字,指明排序的关键字顺序?;fill_value表示如果待排序array为mask_array时,将mask的值填充为fill_value的值。
5)dict.get(key[,default]):在dict中查找key,如果存在,则返回其值,不存在则返回default值。
6)sorted(iterable[,key][,reverse]):排序函数,排序的结果会生成一个新的list(与list.sort()的原地排序不同)。iterable是在list中的遍历,如语句中的.iteritems();key是排序的关键字或方法,如语句中的itemgetter(1)表示取list的第二个值作为排序关键字,也可以同时用好几个关键字,将按关键字的顺序进行排序;reverse表示是增序还是降序,等于True表示降序。
7)open():默认以只读方式打开文件。
8)readlines():读取当前一行数据;len(fr.readlines())返回行数。
9)zeros(shape,dtype=None,order='C'):返回一个维度为shape,数据类型为dtype的matrix,内容全部为0;dtype默认为float;order表示的是以什么方式存储,有C和F两种,C- 和Fortran-contiguous。
10)strip():读取整行后,去除首尾的whitespace,包括'\t','\n','r'等等。
11)min(axis,):这里主要解释axis参数,当等于None时,返回整个matrix或array最小的值;当axis=0时,返回列最小值,也就是该句中的用法;当axis=1时,返回行最小值;当axis>1时,按维度选择。
12)/ :这里的除号是指按元素除,normDataSet中的每个元素除tile(ranges,(m,1))中对应的每个元素。如果要做的是矩阵的除法,在numpy中则需要用到linalg.sovle(matA,matB)来实现。
补充一个函数:listdir(path):需要通过from os import listdir导入,功能是获得当前path下的所有文件列表,以list形式返回,可以利用下标访问。
用python实现kNN算法之简单分类