关闭

Machine Learning In Action -- kNN (k Nearest Neighbors)

153人阅读 评论(0) 收藏 举报
分类:

k最近邻分类算法:k Nearest Neighbors

k最近邻分类算法是最简单的机器学习算法之一,主要应用在对未知事物的识别。

主要思想:

如果一个样本在特性空间的k个最相似样本的大多数都以属于同一个类别,那么这个样本也属于该类别。

算法优点

  • 算法准确度较高
  • 对数据不作假设
  • 适用于交叉或重叠较多的待分样本集

算法缺点

  • 计算量大
  • 内存消耗大
  • 样本数量不平衡时易受影响

示例图

这里写图片描述
绿色圆点表示未分类的样本,令其为A。如果我们把k设成3, 那么离A最近的3个样本就是黑色圆中所包含的样本。由于红色三角形有2个,而蓝色正方形只有一个。所以最终的分类结果为红色三角形。

代码 Python

这里主要参考了Machine Learning In Action这本书中的代码。其kNN的具体python实现代码如下。
Note: 运行代码之前,请安装好matplotlib。

import numpy as np
import operator
from os import listdir

def knn_classify(vec_in, data_set, labels, k):
    rows = data_set.shape[0]
    diffs = np.tile(vec_in, (rows, 1)) - data_set
    sq_diffs = diffs ** 2
    sq_distances = sq_diffs.sum(axis = 1)
    distances = sq_distances ** 0.5
    sorted_dist_indices = distances.argsort()

    class_cnt = {}
    for i in range(k):
        vote_label = labels[sorted_dist_indices[i]]
        class_cnt[vote_label] = class_cnt.get(vote_label, 0) + 1
    sorted_class_cnt = sorted(class_cnt.iteritems(), key=operator.itemgetter(1), reverse=True)
    return sorted_class_cnt[0][0]

实例:手写数字0-9的识别

在Machine Learning In Action中,有一个将kNN算法用于手写数字识别的例子。该例子中的training set共包含2000个样本,也就是说每个数字大约有200个样本。每个数字通过处理,以32x32大小的0/1数字组成。
应用kNN算法,先将每个手写数字变成1x1024的向量。保存在training set数组中。当未判别的数字出现时,用该数字的向量于training set中的每一个向量计算距离,选其中的top k个样本进行投票,最后哪个类别的数量最多,就将该数字判定成那个类别。
这里写图片描述
具体代码:

def img2vector(filename):
    ret_vec = np.zeros((1, 1024))
    fp = open(filename)
    for i in range(32):
        line_str = fp.readline()
        for j in range(32):
            ret_vec[0, 32*i + j] = int(line_str[j])
    fp.close()
    return ret_vec

def hand_writing_test():
    training_path = './digits/trainingDigits'
    test_path = './digits/testDigits'
    training_files = listdir(training_path)
    m = len(training_files)
    training_mat = np.zeros((m, 1024))
    labels = []
    for i in range(m):
        filename = training_files[i]
        class_num_str = filename.split('_')[0]
        labels.append(class_num_str)
        training_mat[i, :] = img2vector(training_path + '/%s' % filename)
    # print training_mat
    test_files = listdir(test_path)
    m = len(test_files)
    err_cnt = 0.0
    for i in range(m):
        filename = test_files[i]
        class_num_str = filename.split('_')[0]
        vec_in = img2vector(test_path + '/%s' % filename)
        ret = knn_classify(vec_in, training_mat, labels, 3)
        if str(ret) != class_num_str:
            print "file %s, classifier result: %s, real ans: %s." % (filename, ret, class_num_str)
            err_cnt += 1.0
    print "The total number of error is %d." % err_cnt
    print "The failure rate is %f." % (err_cnt / float(m))



if __name__ == '__main__':
    hand_writing_test()

这里写图片描述

0
0

查看评论
* 以上用户言论只代表其个人观点,不代表CSDN网站的观点或立场
    个人资料
    • 访问:399178次
    • 积分:5272
    • 等级:
    • 排名:第5157名
    • 原创:136篇
    • 转载:60篇
    • 译文:2篇
    • 评论:60条
    文章分类
    最新评论