KNN(k-NearestNeighbor)识别minist数据集

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/Gavin__Zhou/article/details/49383891

KNN算法介绍

KNN(K Nearest Neighbors,K近邻 )算法是机器学习所有算法中理论最简单,最好理解的。KNN是一种基于实例的学习,通过计算新数据与训练数据特征值之间的距离,然后选取K(K>=1)个距离最近的邻居进行分类判断(投票法)或者回归。如果K=1,那么新数据被简单分配给其近邻的类。KNN算法算是监督学习还是无监督学习呢?首先来看一下监督学习和无监督学习的定义。对于监督学习,数据都有明确的label(分类针对离散分布,回归针对连续分布),根据机器学习产生的模型可以将新数据分到一个明确的类或得到一个预测值。对于非监督学习,数据没有label,机器学习出的模型是从数据中提取出来的pattern(提取决定性特征或者聚类等)。例如聚类是机器根据学习得到的模型来判断新数据“更像”哪些原数据集合。KNN算法用于分类时,每个训练数据都有明确的label,也可以明确的判断出新数据的label,KNN用于回归时也会根据邻居的值预测出一个明确的值,因此KNN属于监督学习。

KNN算法流程

  • 选择一种距离计算方式, 通过数据所有的特征计算新数据与已知类别数据集中的数据点的距离
  • 按照距离递增次序进行排序,选取与当前距离最小的k个点
  • 对于离散分类,返回k个点出现频率最多的类别作预测分类;对于回归则返回k个点的加权值作为预测值
  • 按照投票的方式在返回的k个类别中选择出现次数最多的类别作为最终的预测类别

KNN算法关键

  • 数据的所有特征都要做可比较的量化。
    若是数据特征中存在非数值的类型,必须采取手段将其量化为数值。举个例子,若样本特征中包含颜色(红黑蓝)一项,颜色之间是没有距离可言的,可通过将颜色转换为灰度值来实现距离计算。另外,样本有多个参数,每一个参数都有自己的定义域和取值范围,他们对distance计算的影响也就不一样,如取值较大的影响力会盖过取值较小的参数。为了公平,样本参数必须做一些scale处理,最简单的方式就是所有特征的数值都采取归一化处置。
  • 需要一个distance函数以计算两个样本之间的距离。
    距离的定义有很多,如欧氏距离、余弦距离、汉明距离、曼哈顿距离等等。 一般情况下,选欧氏距离作为距离度量,但是这是只适用于连续变量。通常情况下,如果运用一些特殊的算法来计算度量的话,K近邻分类精度可显著提高,如运用大边缘最近邻法或者近邻成分分析法。
  • 确定K的值
    K是一个自定义的常数,K的值也直接影响最后的估计,一种选择K值得方法是使用 cross-validate(交叉验证)误差统计选择法。交叉验证的概念之前提过,就是数据样本的一部分作为训练样本,一部分作为测试样本,比如选择95%作为训练样本,剩下的用作测试样本。通过训练数据训练一个机器学习模型,然后利用测试数据测试其误差率。 cross-validate(交叉验证)误差统计选择法就是比较不同K值时的交叉验证平均误差率,选择误差率最小的那个K值。例如选择K=1,2,3,… , 对每个K=i做100次交叉验证,计算出平均误差,然后比较、选出最小的那个。

KNN算法优缺点

一、 简单、有效。
二、 重新训练的代价较低(基本不需要训练)。
三、 计算时间和空间线性于训练集的规模(在一些场合不算太大),样本过大识别时间会很长。
四、 k值比较难以确定。

mnist手写数据识别

mnist是一个手写数字的库,包含数字从0-9,每个图像大小为32*32,详细介绍和数据下载见这里


用到了PIL,numpy这两个python库,没有安装的可以参照我的另外一篇博客去配置安装,这就不多说了
代码是我修改的大牛的原始代码生成的,参见下面的参考文献,我也已经上传CSDN,一份是大牛的原始代码,一份是新的


我们需要使用KNN算法去识别mnist手写数字,具体步骤如下:
首先需要将手写数字做成0 1串,将原图中黑色像素点变成1,白色为0,写成TXT文件;
python代码:

def img2vector(impath,savepath):
    '''
    convert the image to an numpy array
    Black pixel set to 1,white pixel set to 0
    '''
    im = Image.open(impath)
    im = im.transpose(Image.ROTATE_90)
    im = im.transpose(Image.FLIP_TOP_BOTTOM)

    rows = im.size[0]
    cols = im.size[1]
    imBinary = zeros((rows,cols))
    for row in range(0,rows):
        for col in range(0,cols):
            imPixel = im.getpixel((row,col))[0:3]
            if imPixel == (0,0,0):
                imBinary[row,col] = 0
    #save temp txt like 1_5.txt whiich represent the class is 1 and the index is 5
    fp = open(savepath,'w')
    for x in range(0,imBinary.shape[0]):
        for y in range(0,imBinary.shape[1]):
            fp.write(str(int(imBinary[x,y])))
        fp.write('\n')
    fp.close()

结果大概像这样:
这是一个mnist数据的数字3

将所有的TXT文件中的0 1串变成行向量
python代码:

def vectorOneLine(filename):
    rows = 32
    cols = 32
    imgVector = zeros((1, rows * cols)) 
    fileIn = open(filename)
    for row in xrange(rows):
        lineStr = fileIn.readline()
        for col in xrange(cols):
            imgVector[0, row * 32 + col] = int(lineStr[col])
    return imgVector

KNN识别
python代码:

def kNNClassify(testImput, TrainingDataSet, TrainingLabels, k):
    numSamples = dataSet.shape[0] # shape[0] stands for the num of row
    #calculate the Euclidean distance
    diff = tile(newInput, (numSamples, 1)) - dataSet # Subtract element-wise
    squaredDiff = diff ** 2 # squared for the subtract
    squaredDist = sum(squaredDiff, axis = 1) # sum is performed by row
    distance = squaredDist ** 0.5
    #sort the distance vector
    sortedDistIndices = argsort(distance)
    #choose k elements
    classCount = {} # define a dictionary (can be append element)
    for i in xrange(k):
        voteLabel = labels[sortedDistIndices[i]]
        #initial the dict
        classCount[voteLabel] = classCount.get(voteLabel, 0) + 1
    #vote the label as final return
    maxCount = 0
    for key, value in classCount.items():
        if value > maxCount:
            maxCount = value
            maxIndex = key
    return maxIndex 

*识别结果*
识别结果
参考文献:
[1]大牛的博客:http://blog.csdn.net/zouxy09/article/details/16955347
[2]matlab 实现KNN: http://blog.csdn.net/rk2900/article/details/9080821
[3]分类算法的优缺点:http://bbs.pinggu.org/thread-2604496-1-1.html
[4]代码下载地址:http://download.csdn.net/detail/gavin__zhou/9208821
http://download.csdn.net/detail/gavin__zhou/9208827

没有更多推荐了,返回首页

私密
私密原因:
请选择设置私密原因
  • 广告
  • 抄袭
  • 版权
  • 政治
  • 色情
  • 无意义
  • 其他
其他原因:
120
出错啦
系统繁忙,请稍后再试

关闭