KNN实现简单手写数字识别

前言

由于我们使用的数据集是 MNIST 数据集,而该数据集的手写数字为灰度图像,所以对色差非常敏感,如果手写数字图像的质量不高或者存在噪声,极有可能会导致识别错误。不过,对于初学者来说,这类数据集已经足够用于理解和学习 KNN 算法的应用。

训练模型

import pandas as pd    # 处理数据集
import joblib   # 加载和保存训练好的模型
from sklearn.datasets import fetch_openml   # 手写数字图像
from sklearn.model_selection import train_test_split    # 分割训练数据和测试数据并调整比例
from sklearn.neighbors import KNeighborsClassifier  # K-最近邻分类器 也就是KNN算法 根据最近的K个邻居进行预测
from sklearn.metrics import accuracy_score  # 计算模型的准确率

mnist = fetch_openml('mnist_784',version=1)
# 取出名为mnist_784的数据集 版本为1 该数据集里包含了7万张0~9的手写数字图像
# 一个有趣的背景 该数据集中的6万张训练图像是由美国人口普查局的工作人员所写 而一万张测试图像是来自美国高中生

X = pd.DataFrame(mnist['data'])
# 提取该数据集中的特征 这些特征为二维向量且在数据中的索引为data 所以需要用DataFrame来处理他们
y = pd.Series(mnist.target).astype('int')
# 对于此数据集中的标签 也就是一个28*28的二维向量所对应的目标值 是一维向量 所以用Series处理即可
# 因为目标值被设计为字符串类型 所以要用astype('int')将其转化为int类型
    
X_train,X_test,y_train,y_test = train_test_split(X,y,test_size=0.2)
# 使用train_test_split函数将特征和标签(也就是上面的X和y)以80%和20%的比例进行分割
# 使X_train和y_train用于训练 X_test和y_test用于测试

estimator = KNeighborsClassifier(n_neighbors=3)
# 实例化一个K-最近邻分类器 并将n_neighbors设置为3 这也就意味着在以后的数据测试或模型使用中
# 该算法会自动查找其在特征空间中距离最近的三个邻居(一般使用欧氏距离)
# 并通过这三个邻居的类别来决定新数据点的类别

estimator.fit(X_train,y_train)
# 开始正式训练模型 使用这个分类器的fit方法 传入训练数据和其对应的标签
# 让模型学习并记住数据特征和目标值之间的关系 以便以后可以用来做预测

y_pred = estimator.predict(X_test)
# 模型训练完成之后 就需要使用上面拆分出的测试数据来进行预测 计算模型的准确度。
# 通过调用分类器的predict方法 传入测试数据X_test 就会根据之前学到的规律 返回对每个测试样本的预测结果。

print(accuracy_score(y_test,y_pred))
# 通过测试结果的对比来计算出模型的精准度

joblib.dump(estimator, '../mnist_784.pth')
# 最后 调用joblib库中的dump方法 传入训练好的模型并设置文件名为‘mnist_784.pth'
# 这样模型就被保存在了磁盘上 以后就可以在不需要重新训练模型的情况下 直接加载使用这个已经训练好的模型

数字预测

import warnings   # 忽略警告信息
import joblib   # 加载和保存训练好的模型
import numpy as np  # 操作数组
from PIL import Image  # 处理图像文件

def digit_test():
    warnings.filterwarnings("ignore")
    # 因为训练时的特征名称和现在的特征名称不一样 会报warning但是不影响运行和结果
    # 所以直接忽略这个warning就可以了 如果想解决的话在测试模型前的代码中加上
    # X.columns = [f'pixel{i}' for i in range(X.shape[1])] 即可

    model = joblib.load('../mnist_784.pth')
    # 使用joblib方法加载刚才训练好的模型

    filename = '../handwritten_digits/digit_1.png'
    # 将图片文件的路径储存在filename里 方便下面直接调用

    img = Image.open(filename).convert('L')
    # 用Image.open方法将该路径上的图片打开 并且使用convert方法将其转换为灰度图

    img = img.resize((28, 28))
    # 将图片大小压缩为28*28的格式 以符合模型的输入要求

    img = np.array(img).reshape(1,-1)
    # 用np.array方法把img转成数组类型 reshape方法中第一个‘1’ 是用于将这个二位数组拉伸成一个只有一行的数组
    # 第二个‘-1’是让Numpy自动计算出剩余维度的大小 使得数组展平成一个包含784个元素的数组

    predict = model.predict(img)
    # 将处理好的img投入到模型中进行预测 然后该模型会根据设定中最近的三个‘邻居'
    # 进行投票后返回最终结果

    print(f'测试结果为: {predict[0]}')
    # 因为返回的是一个数组的形式 所以其中的第一个数输出 就是结果

if __name__ == '__main__': # 主函数
    digit_test()

加权优化

import warnings   # 忽略警告信息
import joblib   # 加载和保存训练好的模型
import numpy as np  # 操作数组
from PIL import Image  # 处理图像文件

class DigitRecognizer:
    # 创建一个手写数字识别器的类

    def __init__(self, model_path):
        self.model = joblib.load(model_path)
        # 使用joblib方法加载训练好的模型
        self.X_model = self.model._fit_X
        self.y_model = self.model._y
        # 取出模型的训练数据和标签

    def compute_distance(self, x1, x2):
        return np.sqrt(np.sum((x1 - x2) ** 2))
        # 通过公式计算两个数之间的欧式距离

    def compute_weight(self, distance):
        a, b = 1, 1
        # 定义用来调整权重的参数 a可以理解为一个平滑引子 用来避免发生除以0的情况
        return b / (distance + a)
        # 计算并返回权重值 这个公式通过带入不同的距离可以发现 距离越小 权重越大 距离越大 权重越小

    def predict_digit(self, filename):
        img = Image.open(filename).convert('L')
        img = img.resize((28, 28))
        img = np.array(img).reshape(1, -1)
        # 处理传入进来的图像 之前解释过了就不再说一遍了

        distances = []
        # 创建一个空列表 用于储存每个训练样本与输入图像的距离和标签
        for i, X_train in enumerate(self.X_model):
            # 遍历训练集中的每一个样本

            distance = self.compute_distance(img, X_train.reshape(1, -1))
            # 将图像和训练样本投入到compute_distance函数中进行计算 别忘了转换训练样本以匹配imn的形状

            weight = self.compute_weight(distance)
            # 计算权重

            distances.append((weight, self.y_model[i]))
            # 将权重和对应的标签作为元组添加到列表中

        distances.sort(key=lambda x: x[0], reverse=True)
        # 用lambda表达式让这个列表按照权重进行降序排序

        k_neighbors = distances[:3]
        # 选择排序后权重最大的三个邻居

        weighted_votes = {}
        # 创建一个空字典用于记录每个标签的加权投票结果

        for weight, label in k_neighbors:
        # 遍历三个邻居的权重和标签

            if label in weighted_votes:
                weighted_votes[label] += weight
                # 如果这个标签已经在字典中存在 累加这个权重即可

            else:
                weighted_votes[label] = weight
                # 反之如果不在的话就创建一个这个标签的字典来保存当前权重

        predictions = max(weighted_votes, key=weighted_votes.get)
        # 从加权投票结果中选出权重最大的标签作为最终的预测结果

        return predictions
        # 返回这个结果

def digit_test():
    warnings.filterwarnings("ignore")
    recognizer = DigitRecognizer('../mnist_784.pth')
    filename = '../handwritten_digits/digit_1.png'
    prediction = recognizer.predict_digit(filename)
    print(f'测试结果为: {prediction}')
    # 关于预测图像 之前的文件也详细说过了 就不再讲了

if __name__ == '__main__':  # 主函数
    digit_test()

(扩展)简单GUI界面软件实现

import warnings  # 忽略警告信息
import sys  # 系统相关功能
import joblib  # 加载和保存训练好的模型
import numpy as np  # 操作数组
from PyQt5.QtWidgets import QApplication, QWidget, QVBoxLayout, QPushButton, QFileDialog, QLabel  # GUI构建模块 下面用到了就了解了
from PyQt5.QtGui import QPixmap  # 处理图像显示
from PIL import Image  # 处理图像文件

class MainWindow(QWidget):
    # 创建GUI主窗口的类 并且继承QWidget的方法 可以完全自定义窗口的外观和操作

    def __init__(self):
        super().__init__()
        # 在执行这个初始化函数的同时 调用父类(QWidget)方法

        self.init_ui()
        self.model = joblib.load('../mnist_784.pth')
        # 使用joblib方法加载训练好的模型

    def init_ui(self):
        self.setWindowTitle('手写数字识别')
        # 设置这个窗口的标题 也就是打开窗口后在最上方居中边框的位置

        self.resize(1000, 600)
        # 调整窗口的大小

        layout = QVBoxLayout()
        # 创建垂直布局 这样后面被添加到垂直布局里的子部件就会从上到下垂直排列
        # 并且根据窗口大小的变化实时调整子部件的大小和位置 并均匀分布 防止重叠

        self.btn = QPushButton('加载图片', self)
        # 创建一个按钮 并在按钮上水平居中显示"Add Photo"的文本
        # 对括号中的self做一个简单的解释: 这个self是指定了这个类作为按钮的"父控件“
        # 将按钮添加到这个窗口中 等这个窗口被关闭时 该按钮就被自动销毁 用于避免内存泄露

        self.btn.setFixedSize(200, 200)
        # 调整按钮的大小

        self.btn.clicked.connect(self.load_Image)
        # 将该按钮的点击信号连接到self.loadImage这个函数 这样当按钮被点击时就会触发这个函数

        layout.addWidget(self.btn)
        # 将这个按钮添加到布局中

        self.resultLabel = QLabel('测试结果为:', self)
        # 创建一个标签用于显示最后的结果

        layout.addWidget(self.resultLabel)
        # 将结果标签也添加到布局中

        self.imageLabel = QLabel(self)
        # 创建一个标签用于显示测试的图片

        layout.addWidget(self.imageLabel)
        # 将这个图片标签也添加到布局中

        self.setLayout(layout)
        # 将刚才创建的布局设置为当前窗口的布局管理器 这样刚才添加到布局设置里的标签就可以被自动调整后显示了
    def load_Image(self):
        options = QFileDialog.Options()
        # 这个方法就是创建点击图片的时候触发的那个文件选择框

        filename, _ = QFileDialog.getOpenFileName(self, "请选择图片", "", "All Files (*)", options=options)
        # 打开文件选择框 并且选取文件传递给对话框 ”“代表默认目录 ”All Files (*)"则可以显示选择所有类型的文件

        if filename:
            pixmap = QPixmap(filename)
            # 使用QPixmap方法加载选择的图像 用QPixmap的主要原因是因为其可以和QLabel兼容 可以直接加载到imageLabel中

            self.imageLabel.setPixmap(pixmap)
            # 将加载的图像设置为imageLabel 这样可以在窗口中显示出来

            self.imageLabel.adjustSize()
            # 将ImageLabel调整为合适的大小以适应图像

            prediction = self.predict_Digit(filename)
            # 调用predictDigit函数进行预测 并将值返回给prediction

            self.resultLabel.setText(f'测试结果为:{prediction}')
            # 将result_Label的文本显示内容添加上刚刚预测的结果


    def predict_Digit(self, filename):
        img = Image.open(filename).convert('L')
        # 用Image.open方法将该路径上的图片打开 并且使用convert方法将其转换为灰度图

        img = img.resize((28, 28))
        # 将图片大小压缩为28*28的格式 以符合模型的输入要求

        img = np.array(img).reshape(1, -1)
        # 用np.array方法把img转成数组类型 reshape方法中第一个‘1’ 是用于将这个二位数组拉伸成一个只有一行的数组
        # 第二个‘-1’是让Numpy自动计算出剩余维度的大小 使得数组展平成一个包含784个元素的数组

        prediction = self.model.predict(img)
        # 将处理好的img投入到模型中进行预测 然后该模型会根据设定中最近的三个‘邻居'
        # 进行投票后返回最终结果

        return prediction[0]
        # 因为返回的是一个数组的形式 所以其中的第一个数输出 就是结果 返回该结果即可

if __name__ == '__main__':  # 主函数
    warnings.filterwarnings("ignore")
    # 因为训练时的特征名称和现在的特征名称不一样 会报warning但是不影响运行和结果
    # 所以直接忽略这个warning就可以了 如果想解决的话在测试模型前的代码中加上
    # X.columns = [f'pixel{i}' for i in range(X.shape[1])] 即可

    app = QApplication(sys.argv)
    # 创建一个QApplication对象 负责管理该程序的控制流和其他设置

    ex = MainWindow()
    ex.show()
    sys.exit(app.exec_())
    # app.exec_()是进入该程序的主循环 开始上述事件的处理 exit时确保该程序可以干净地退出

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值