前言
由于我们使用的数据集是 MNIST 数据集,而该数据集的手写数字为灰度图像,所以对色差非常敏感,如果手写数字图像的质量不高或者存在噪声,极有可能会导致识别错误。不过,对于初学者来说,这类数据集已经足够用于理解和学习 KNN 算法的应用。
训练模型
import pandas as pd # 处理数据集
import joblib # 加载和保存训练好的模型
from sklearn.datasets import fetch_openml # 手写数字图像
from sklearn.model_selection import train_test_split # 分割训练数据和测试数据并调整比例
from sklearn.neighbors import KNeighborsClassifier # K-最近邻分类器 也就是KNN算法 根据最近的K个邻居进行预测
from sklearn.metrics import accuracy_score # 计算模型的准确率
mnist = fetch_openml('mnist_784',version=1)
# 取出名为mnist_784的数据集 版本为1 该数据集里包含了7万张0~9的手写数字图像
# 一个有趣的背景 该数据集中的6万张训练图像是由美国人口普查局的工作人员所写 而一万张测试图像是来自美国高中生
X = pd.DataFrame(mnist['data'])
# 提取该数据集中的特征 这些特征为二维向量且在数据中的索引为data 所以需要用DataFrame来处理他们
y = pd.Series(mnist.target).astype('int')
# 对于此数据集中的标签 也就是一个28*28的二维向量所对应的目标值 是一维向量 所以用Series处理即可
# 因为目标值被设计为字符串类型 所以要用astype('int')将其转化为int类型
X_train,X_test,y_train,y_test = train_test_split(X,y,test_size=0.2)
# 使用train_test_split函数将特征和标签(也就是上面的X和y)以80%和20%的比例进行分割
# 使X_train和y_train用于训练 X_test和y_test用于测试
estimator = KNeighborsClassifier(n_neighbors=3)
# 实例化一个K-最近邻分类器 并将n_neighbors设置为3 这也就意味着在以后的数据测试或模型使用中
# 该算法会自动查找其在特征空间中距离最近的三个邻居(一般使用欧氏距离)
# 并通过这三个邻居的类别来决定新数据点的类别
estimator.fit(X_train,y_train)
# 开始正式训练模型 使用这个分类器的fit方法 传入训练数据和其对应的标签
# 让模型学习并记住数据特征和目标值之间的关系 以便以后可以用来做预测
y_pred = estimator.predict(X_test)
# 模型训练完成之后 就需要使用上面拆分出的测试数据来进行预测 计算模型的准确度。
# 通过调用分类器的predict方法 传入测试数据X_test 就会根据之前学到的规律 返回对每个测试样本的预测结果。
print(accuracy_score(y_test,y_pred))
# 通过测试结果的对比来计算出模型的精准度
joblib.dump(estimator, '../mnist_784.pth')
# 最后 调用joblib库中的dump方法 传入训练好的模型并设置文件名为‘mnist_784.pth'
# 这样模型就被保存在了磁盘上 以后就可以在不需要重新训练模型的情况下 直接加载使用这个已经训练好的模型
数字预测
import warnings # 忽略警告信息
import joblib # 加载和保存训练好的模型
import numpy as np # 操作数组
from PIL import Image # 处理图像文件
def digit_test():
warnings.filterwarnings("ignore")
# 因为训练时的特征名称和现在的特征名称不一样 会报warning但是不影响运行和结果
# 所以直接忽略这个warning就可以了 如果想解决的话在测试模型前的代码中加上
# X.columns = [f'pixel{i}' for i in range(X.shape[1])] 即可
model = joblib.load('../mnist_784.pth')
# 使用joblib方法加载刚才训练好的模型
filename = '../handwritten_digits/digit_1.png'
# 将图片文件的路径储存在filename里 方便下面直接调用
img = Image.open(filename).convert('L')
# 用Image.open方法将该路径上的图片打开 并且使用convert方法将其转换为灰度图
img = img.resize((28, 28))
# 将图片大小压缩为28*28的格式 以符合模型的输入要求
img = np.array(img).reshape(1,-1)
# 用np.array方法把img转成数组类型 reshape方法中第一个‘1’ 是用于将这个二位数组拉伸成一个只有一行的数组
# 第二个‘-1’是让Numpy自动计算出剩余维度的大小 使得数组展平成一个包含784个元素的数组
predict = model.predict(img)
# 将处理好的img投入到模型中进行预测 然后该模型会根据设定中最近的三个‘邻居'
# 进行投票后返回最终结果
print(f'测试结果为: {predict[0]}')
# 因为返回的是一个数组的形式 所以其中的第一个数输出 就是结果
if __name__ == '__main__': # 主函数
digit_test()
加权优化
import warnings # 忽略警告信息
import joblib # 加载和保存训练好的模型
import numpy as np # 操作数组
from PIL import Image # 处理图像文件
class DigitRecognizer:
# 创建一个手写数字识别器的类
def __init__(self, model_path):
self.model = joblib.load(model_path)
# 使用joblib方法加载训练好的模型
self.X_model = self.model._fit_X
self.y_model = self.model._y
# 取出模型的训练数据和标签
def compute_distance(self, x1, x2):
return np.sqrt(np.sum((x1 - x2) ** 2))
# 通过公式计算两个数之间的欧式距离
def compute_weight(self, distance):
a, b = 1, 1
# 定义用来调整权重的参数 a可以理解为一个平滑引子 用来避免发生除以0的情况
return b / (distance + a)
# 计算并返回权重值 这个公式通过带入不同的距离可以发现 距离越小 权重越大 距离越大 权重越小
def predict_digit(self, filename):
img = Image.open(filename).convert('L')
img = img.resize((28, 28))
img = np.array(img).reshape(1, -1)
# 处理传入进来的图像 之前解释过了就不再说一遍了
distances = []
# 创建一个空列表 用于储存每个训练样本与输入图像的距离和标签
for i, X_train in enumerate(self.X_model):
# 遍历训练集中的每一个样本
distance = self.compute_distance(img, X_train.reshape(1, -1))
# 将图像和训练样本投入到compute_distance函数中进行计算 别忘了转换训练样本以匹配imn的形状
weight = self.compute_weight(distance)
# 计算权重
distances.append((weight, self.y_model[i]))
# 将权重和对应的标签作为元组添加到列表中
distances.sort(key=lambda x: x[0], reverse=True)
# 用lambda表达式让这个列表按照权重进行降序排序
k_neighbors = distances[:3]
# 选择排序后权重最大的三个邻居
weighted_votes = {}
# 创建一个空字典用于记录每个标签的加权投票结果
for weight, label in k_neighbors:
# 遍历三个邻居的权重和标签
if label in weighted_votes:
weighted_votes[label] += weight
# 如果这个标签已经在字典中存在 累加这个权重即可
else:
weighted_votes[label] = weight
# 反之如果不在的话就创建一个这个标签的字典来保存当前权重
predictions = max(weighted_votes, key=weighted_votes.get)
# 从加权投票结果中选出权重最大的标签作为最终的预测结果
return predictions
# 返回这个结果
def digit_test():
warnings.filterwarnings("ignore")
recognizer = DigitRecognizer('../mnist_784.pth')
filename = '../handwritten_digits/digit_1.png'
prediction = recognizer.predict_digit(filename)
print(f'测试结果为: {prediction}')
# 关于预测图像 之前的文件也详细说过了 就不再讲了
if __name__ == '__main__': # 主函数
digit_test()
(扩展)简单GUI界面软件实现
import warnings # 忽略警告信息
import sys # 系统相关功能
import joblib # 加载和保存训练好的模型
import numpy as np # 操作数组
from PyQt5.QtWidgets import QApplication, QWidget, QVBoxLayout, QPushButton, QFileDialog, QLabel # GUI构建模块 下面用到了就了解了
from PyQt5.QtGui import QPixmap # 处理图像显示
from PIL import Image # 处理图像文件
class MainWindow(QWidget):
# 创建GUI主窗口的类 并且继承QWidget的方法 可以完全自定义窗口的外观和操作
def __init__(self):
super().__init__()
# 在执行这个初始化函数的同时 调用父类(QWidget)方法
self.init_ui()
self.model = joblib.load('../mnist_784.pth')
# 使用joblib方法加载训练好的模型
def init_ui(self):
self.setWindowTitle('手写数字识别')
# 设置这个窗口的标题 也就是打开窗口后在最上方居中边框的位置
self.resize(1000, 600)
# 调整窗口的大小
layout = QVBoxLayout()
# 创建垂直布局 这样后面被添加到垂直布局里的子部件就会从上到下垂直排列
# 并且根据窗口大小的变化实时调整子部件的大小和位置 并均匀分布 防止重叠
self.btn = QPushButton('加载图片', self)
# 创建一个按钮 并在按钮上水平居中显示"Add Photo"的文本
# 对括号中的self做一个简单的解释: 这个self是指定了这个类作为按钮的"父控件“
# 将按钮添加到这个窗口中 等这个窗口被关闭时 该按钮就被自动销毁 用于避免内存泄露
self.btn.setFixedSize(200, 200)
# 调整按钮的大小
self.btn.clicked.connect(self.load_Image)
# 将该按钮的点击信号连接到self.loadImage这个函数 这样当按钮被点击时就会触发这个函数
layout.addWidget(self.btn)
# 将这个按钮添加到布局中
self.resultLabel = QLabel('测试结果为:', self)
# 创建一个标签用于显示最后的结果
layout.addWidget(self.resultLabel)
# 将结果标签也添加到布局中
self.imageLabel = QLabel(self)
# 创建一个标签用于显示测试的图片
layout.addWidget(self.imageLabel)
# 将这个图片标签也添加到布局中
self.setLayout(layout)
# 将刚才创建的布局设置为当前窗口的布局管理器 这样刚才添加到布局设置里的标签就可以被自动调整后显示了
def load_Image(self):
options = QFileDialog.Options()
# 这个方法就是创建点击图片的时候触发的那个文件选择框
filename, _ = QFileDialog.getOpenFileName(self, "请选择图片", "", "All Files (*)", options=options)
# 打开文件选择框 并且选取文件传递给对话框 ”“代表默认目录 ”All Files (*)"则可以显示选择所有类型的文件
if filename:
pixmap = QPixmap(filename)
# 使用QPixmap方法加载选择的图像 用QPixmap的主要原因是因为其可以和QLabel兼容 可以直接加载到imageLabel中
self.imageLabel.setPixmap(pixmap)
# 将加载的图像设置为imageLabel 这样可以在窗口中显示出来
self.imageLabel.adjustSize()
# 将ImageLabel调整为合适的大小以适应图像
prediction = self.predict_Digit(filename)
# 调用predictDigit函数进行预测 并将值返回给prediction
self.resultLabel.setText(f'测试结果为:{prediction}')
# 将result_Label的文本显示内容添加上刚刚预测的结果
def predict_Digit(self, filename):
img = Image.open(filename).convert('L')
# 用Image.open方法将该路径上的图片打开 并且使用convert方法将其转换为灰度图
img = img.resize((28, 28))
# 将图片大小压缩为28*28的格式 以符合模型的输入要求
img = np.array(img).reshape(1, -1)
# 用np.array方法把img转成数组类型 reshape方法中第一个‘1’ 是用于将这个二位数组拉伸成一个只有一行的数组
# 第二个‘-1’是让Numpy自动计算出剩余维度的大小 使得数组展平成一个包含784个元素的数组
prediction = self.model.predict(img)
# 将处理好的img投入到模型中进行预测 然后该模型会根据设定中最近的三个‘邻居'
# 进行投票后返回最终结果
return prediction[0]
# 因为返回的是一个数组的形式 所以其中的第一个数输出 就是结果 返回该结果即可
if __name__ == '__main__': # 主函数
warnings.filterwarnings("ignore")
# 因为训练时的特征名称和现在的特征名称不一样 会报warning但是不影响运行和结果
# 所以直接忽略这个warning就可以了 如果想解决的话在测试模型前的代码中加上
# X.columns = [f'pixel{i}' for i in range(X.shape[1])] 即可
app = QApplication(sys.argv)
# 创建一个QApplication对象 负责管理该程序的控制流和其他设置
ex = MainWindow()
ex.show()
sys.exit(app.exec_())
# app.exec_()是进入该程序的主循环 开始上述事件的处理 exit时确保该程序可以干净地退出