目的
采用k-近邻算法实现手写识别系统。这里采用0和1组成数字0-9的形状,再用算法对这些形状进行识别,来分辨出形状属于0-9那个数字。并计算出k-近邻算法识别手写数字的错误率。
数据说明
数据来自《机器学习实战》,分为测试集和训练集。单个数据如下图所示,表示数据0。
算法过程
- 收集数据:提供文本文件。
- 准备数据:编写函数classify0() ,将图像格式转换为分类器使用的制格式。
- 分析数据:在Python命令提示符中检查数据,确保它符合要求。
- 训练算法:此步驟不适用于k-近邻算法。
- 测试算法:编写函数使用提供的部分数据集作为测试样本,测试样本与非测试样本的区别在于测试样本是已经完成分类的数据,如果预测分类与实际类别不同,则标记为一个错误。
- 使用算法
准备数据:将图像转换为测试向量
def img2vector(filename):
"""
该函数创建1*1024的NumPy数组,然后打开给定的文件,循环读出文件的前32行。
并将每行的头32个字符值存储在NumPy数组中,然后返回数组
"""
return_vect = np.zeros((1, 1024))
fr = open(filename)
for i in range(32):
line_str = fr.readline()
for j in range(32):
return_vect[0, 32*i+j] = int(line_str[j])
return return_vect
测试算法:使用k-近邻算法识别手写数字
def hand_writing_class_test():
"""
手写数字识别系统的测试代码
"""
hw_labels = []
traing_file_list = os.listdir('../Data/Ch02/trainingDigits') # 获取目录内容
m = len(traing_file_list)
matrix_of_training = np.zeros((m, 1024))
for i in range(m):
# 从文件名解析分类数字
file_name_str = traing_file_list[i]
file_str = file_name_str.split('.')[0]
class_num_str = int(file_str.split('_')[0])
hw_labels.append(class_num_str)
matrix_of_training[i, :] = img2vector('../Data/Ch02/trainingDigits/%s' % file_name_str)
test_file_list = os.listdir('../Data/Ch02/testDigits')
error_count = 0.0
m_test = len(test_file_list)
for i in range(m_test):
file_name_str = test_file_list[i]
file_str = file_name_str.split('.')[0]
class_num_str = int(file_str.split('_')[0])
vector_under_test = img2vector('../Data/Ch02/testDigits/%s' % file_name_str)
classifier_result = classify0(vector_under_test, matrix_of_training,
hw_labels, 3)
print("the clasifier came back with: %d, the real answer is %d" % (classifier_result, class_num_str))
if classifier_result != class_num_str:
error_count += 1.0
print("\nthe total number of errors is %d" % error_count)
print("\nthe total error rate is: %f" % (error_count/float(m_test)))
完整代码
# -*- coding: utf-8 -*-
# @Function : 使用k-近邻算法识别手写数字
import numpy as np
import os
import operator
def classify0(in_x, data_set, labels, k):
"""
k-近邻算法
:param in_x: 用于分类的输入向量X
:param data_set: 输入的训练样本集data_set
:param labels: 标签向量,其元素数目与矩阵data_set的行数相同
:param k: 选择最近邻居的数目
:return: 发生频率最高的元素标签
"""
dataset_size = data_set.shape[0]
# 原型:numpy.tile(A,reps)
# tile共有2个参数,A指待输入数组,reps则决定A重复的次数。整个函数用于重复数组A来构建新的数组。
# 计算距离,欧式距离公式:sqrt(pow(xA0-xB0, 2) + pow(xA1-xB1, 2))
diff_mat = np.tile(in_x, (dataset_size, 1)) - data_set
sq_diff_mat = diff_mat ** 2
sq_distances = sq_diff_mat.sum(axis=1)
distances = sq_distances ** 0.5
# numpy.argsort() 返回排好序的序列的索引
sorted_dist_indicies = distances.argsort()
class_count = {}
# 选择距离最小的k个节点
for i in range(k):
vote_I_label = labels[sorted_dist_indicies[i]]
class_count[vote_I_label] = class_count.get(vote_I_label, 0) + 1
sorted_class_count = sorted(class_count.items(), key=operator.itemgetter(1), reverse=True)
return sorted_class_count[0][0]
def img2vector(filename):
"""
该函数创建1*1024的NumPy数组,然后打开给定的文件,循环读出文件的前32行。
并将每行的头32个字符值存储在NumPy数组中,然后返回数组
"""
return_vect = np.zeros((1, 1024))
fr = open(filename)
for i in range(32):
line_str = fr.readline()
for j in range(32):
return_vect[0, 32*i+j] = int(line_str[j])
return return_vect
def hand_writing_class_test():
"""
手写数字识别系统的测试代码
"""
hw_labels = []
traing_file_list = os.listdir('../Data/Ch02/trainingDigits') # 获取目录内容
m = len(traing_file_list)
matrix_of_training = np.zeros((m, 1024))
for i in range(m):
# 从文件名解析分类数字
file_name_str = traing_file_list[i]
file_str = file_name_str.split('.')[0]
class_num_str = int(file_str.split('_')[0])
hw_labels.append(class_num_str)
matrix_of_training[i, :] = img2vector('../Data/Ch02/trainingDigits/%s' % file_name_str)
test_file_list = os.listdir('../Data/Ch02/testDigits')
error_count = 0.0
m_test = len(test_file_list)
for i in range(m_test):
file_name_str = test_file_list[i]
file_str = file_name_str.split('.')[0]
class_num_str = int(file_str.split('_')[0])
vector_under_test = img2vector('../Data/Ch02/testDigits/%s' % file_name_str)
classifier_result = classify0(vector_under_test, matrix_of_training,
hw_labels, 3)
print("the clasifier came back with: %d, the real answer is %d" % (classifier_result, class_num_str))
if classifier_result != class_num_str:
error_count += 1.0
print("\nthe total number of errors is %d" % error_count)
print("\nthe total error rate is: %f" % (error_count/float(m_test)))
if __name__ == '__main__':
hand_writing_class_test()