利用KNN算法实现手写数字识别

本文介绍了如何使用Python和scikit-learn库处理MNIST数据集,包括数据加载、图像显示、特征预处理(通过MinMaxScaler缩放)、KNN模型训练与评估,以及模型预测。代码示例展示了如何展示单个手写数字图像并进行分类。
摘要由CSDN通过智能技术生成

1.数据介绍

数据文件 train.csv 和 test.csv 包含从 0 到 9 的手绘数字的灰度图像。

  • 每个图像高 28 像素,宽28 像素,共784个像素。

  • 每个像素取值范围[0,255]

2.数据资料链接

链接: https://pan.baidu.com/s/1C0y-1-jF7mY31xrUBQ1Z6w?pwd=6666 提取码: 6666 复制这段内容后打开百度网盘手机App,操作更方便哦

注意:需要改成自己电脑上的文件路径,还有一点就是数据量有4万多条,对一些性能不太好的电脑,跑代码可能需要些时间.

2.代码实现

#1.导包
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
from sklearn.neighbors import KNeighborsClassifier
import joblib


#2.显示图像
def show_digits(idx):
    #2.1读取数据
    digit_data = pd.read_csv(r'D:\机器学习\02-KNN算法\03-代码\手写数字识别.csv')  #这里r是反转义,地址是你下载文件的地址

    #2.2判断值是否合理
    if idx < 0 or idx >= len(digit_data) - 1:
        return

    #2.3获取特征和目标
    x = digit_data.iloc[:,1:]
    y = digit_data.iloc[:,0]

    #2.4显示图像
    image_data = x.iloc[idx].values

    #2.5转换图像
    image_data = image_data.reshape(28,28) 将数据(1,784)转换成(28,28)

    #2.6显示图像 
    print(y[idx]) 
    plt.imshow(image_data,cmap='gray') #cmap参数是显示图片的颜色
    plt.show()

#3.模型训练
def train_model():
    #3.1加载数据
    digit_data = pd.read_csv(r'D:\机器学习\02-KNN算法\03-代码\手写数字识别.csv')

    #3.2获取特征和目标值
    x = digit_data.iloc[:, 1:]
    y = digit_data.iloc[:, 0]

    #3.3划分数据集
    x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, stratify=y, random_state=22)  #startify参数是表示,把训练集中的0-9每个数字的数量,按照训练集:测试集=8:2的比例均匀分开.

    #3.4特征预处理
    transfer = MinMaxScaler()
    x_train = transfer.fit_transform(x_train)
    x_text = transfer.transform(x_test)

    #3.5模型实例化
    knn = KNeighborsClassifier(n_neighbors = 5)

    #3.6模型训练
    knn.fit(x_train,y_train)

    #3.7模型评估
    print(knn.score(x_test,y_test))

    #3.8保存模型
    joblib.dump(knn,'knn.pth')

#4.模型预测
def model_predict():
    #4.1加载模型
    knn = joblib.load('knn.pth')

    #4.2读取数据
    img = plt.imread(r'D:\机器学习\02-KNN算法\03-代码\demo.png').reshape(1,-1) #1是从索引1起始,一直到最后一个的索引为-1,将图片(28,28)转化为(1,748)

    #4.3预测
    print(knn.predict(img)) #[2]

if __name__ == '__main__':
    show_digits(4)    #4
    train_model()     
    model_predict()

  • 6
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值