机器学习实战--k近邻算法(续)

    接续上次的k近邻算法,上一篇博文地址,这里用一个新的实例进行算法的验证。

一个手写数字识别系统,为了处理方便,书中已经将样本训练好,并转化为txt格式方便后续处理。具体格式如下:

这是0的其中一种表示方式。

    我们的目标是输入一个类似的数字,系统能够识别出来即可。

————————————————————————————————————————

一、具体算法实现:

# coding: UTF-8
import numpy as np
import os
import operator


# 将txt格式的数字转化为1*1024的向量格式
def img2vector(filename):
    return_vector = np.arange(1024)
    with open(filename) as f:
        for i in range(32):
            line = f.readline()
            for j in range(32):
                return_vector[32 * i + j] = int(line[j])
    return return_vector


# 分类算法实现
def classify(input_vector, trained_mat, class_list, k=3):
    # 欧式距离计算
    rows = trained_mat.shape[0]
    input_mat = np.tile(input_vector, (rows, 1))
    diff_mat = input_mat - trained_mat
    squ_mat = diff_mat ** 2
    sum_mat = squ_mat.sum(axis=1)
    d = sum_mat ** 0.5
    # 根据距离排序,获得排序后的索引
    sorted_d = d.argsort()
    # 创建用来统计某一类标签的字典
    class_count = {}
    for i in xrange(k):
        class_label = class_list[sorted_d[i]]
        class_count[class_label] = class_count.get(class_label, 0) + 1
    # 根据统计得到的类别数量,进行排序,返回一个包含元组的列表,[(),(),...()]
    sorted_class = sorted(class_count.iteritems(), key=operator.itemgetter(1), reverse=True)
    return sorted_class[0][0]


# 用N个行向量构成训练矩阵
# 通过txt的命名获得分类
def vector2mat():
    training_file_list = os.listdir('./trainingDigits')
    rows = len(training_file_list)
    trained_mat = np.zeros((rows, 1024))
    class_list = []
    for index, each_file in enumerate(training_file_list):
        digits, _ = each_file.split('.')
        class_list.append(digits.split('_')[0])
        trained_mat[index, :] = img2vector('./trainingDigits/%s' % each_file)
    print trained_mat
    return trained_mat, class_list


# 系统错误率测试
# 通过另一组test数据作为测试样本输入
def handwriting_test(trained_mat, class_list):
    test_file_list = os.listdir('./testDigits')
    test_num = len(test_file_list)
    err_count = 0.0
    for each_file in test_file_list:
        # 去掉.txt的后缀
        digits = each_file.split('.')[0]
        # 得到已知分类,known_label类型应该与class_list中元素类别一致
        known_label = digits.split('_')[0]
        input_vector = img2vector('./testDigits/%s' % each_file)
        classify_result = classify(input_vector, trained_mat, class_list)
        print "Predict:%s\tReal answer:%s\n" % (classify_result, known_label)
        if known_label != classify_result:
            err_count += 1.0
    print "total err num:%d" % err_count
    print "err rate:%.2f" % (err_count / (float(test_num)))


def main():
    trained_mat, class_list = vector2mat()
    handwriting_test(trained_mat, class_list)


if __name__ == '__main__':
    main()

————————————————————————————————————————

二、结果分析


错误率在1%左右,说明识别准确率还是挺高的;但是因为每次输入一个样本,都要计算与所有训练样本之间的欧式距离,运算量还是挺大的,速度上稍显不足。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值