python实现SVM手写体数字识别

一、MNIST数据集介绍

MNIST数据集是一个二进制图像数据集,广泛用于机器学习中的训练和测试,该数据集于1998年发布。数据集共有70000张图像,其中训练集60000张,测试集10000张。所有图像均为28x28的灰度图像,每张图像包含一个手写数字。数据集包含10个类别,每个类别代表0~9之间的一个数字,每张图像只有一个类别。为了方便后续训练,在加载数据时,将28*28的图像保存为1x784

二、实验代码

1.导入相关的库

import os
import struct
from datetime import datetime
from matplotlib import pyplot as plt
import numpy as np
from sklearn import  svm
from PIL import Image

2.加载数据集

在加载数据集时,将图片转换成numpy的格式,便于后续训练

def load_mnist(path, kind='train'):
    labels_path = os.path.join(path, '%s-labels-idx1-ubyte' % kind) 
    images_path = os.path.join(path, '%s-images-idx3-ubyte' % kind) 
 
    with open(labels_path, 'rb') as lbpath:
        magic, n = struct.unpack('>II', lbpath.read(8)) 
        labels = np.fromfile(lbpath, dtype=np.uint8)
 
    with open(images_path, 'rb') as imgpath:
        magic, num, rows, cols = struct.unpack(">IIII", imgpath.read(16)) 
        images = np.fromfile(imgpath, dtype=np.uint8).reshape(len(labels), 784)
    return images, labels

3.训练函数

训练代码如下:

def train(train_num):
    X_train, y_train = load_mnist('.//dataset//MNIST//raw//', kind='train')
    # 加载训练集
    #X = preprocessing.StandardScaler().fit_transform(X_train)
    X=X_train
    X_train = X[0:train_num]  # 训练60000张
    y_train = y_train[0:train_num]

    dt = datetime.now()
    print('time is ' + dt.strftime('%Y-%m-%d %H:%M:%S'))

    model_svc = svm.SVC(kernel='rbf', gamma='scale')
    model_svc.fit(X_train, y_train) #fit()函数会返回一个训练好的模型

    dt = datetime.now()
    print('time is ' + dt.strftime('%Y-%m-%d %H:%M:%S'))

    return model_svc

sklearn里面提供了svm的包,直接调用即可,这里核函数选择rbf(高斯径向核函数),gamma表示内核系数,默认为scale。而后进行训练
训练时调用的是fit()函数,该函数最终会返回一个训练好的模型。

4.测试函数

def test(model_svc, test_num):
    test_images, test_labels = load_mnist('.//dataset//MNIST//raw//', kind='t10k')  # 加载测试集
    #x = preprocessing.StandardScaler().fit_transform(test_images)
    x=test_images
    x_test = x[0:test_num]
    y_test = test_labels[0:test_num]

    print(model_svc.score(x_test, y_test))  # 根据训练的模型,进行分类得分计算
    #return model_svc.score(x_test, y_test)
    return test_images, test_labels, x

score()函数可返回测试训练集带入模型后的正确率

5.预测函数

def pred(model_svc, pred_num, test_images, test_labels, x):
    y_pred = model_svc.predict(x[9690 - pred_num:9690])  # 进行预测,能得到一个结果
    print(y_pred)

    X_show = test_images[9690 - pred_num:9690]
    #Y_show = test_labels[9690 - pred_num:9690]

    for i in range(pred_num):
        x_show = X_show[i].reshape(28, 28)
        plt.subplot(1, pred_num, i + 1)
        plt.imshow(x_show, cmap=plt.cm.gray_r)
        plt.title(str(y_pred[i]))
        plt.axis('off')
    plt.savefig('./picture.png')#Linux环境下运行时加入
    plt.show()

调用predict()函数对图片进行预测,并且返回预测结果。
注:在Linux环境下运行时,matplotlib无法显示图片,需要先保存成图片并在程序运行后查看

6.实验结果

model = train(5000)#训练个数
test_images, test_labels, x = test(model,9900)
pred(model,10,test_images, test_labels, x)

训练个数:1000,有两个数字识别错误

在这里插入图片描述
在这里插入图片描述

训练个数:3000,有一个数字识别错误

在这里插入图片描述
在这里插入图片描述

训练个数5000,全部预测正确

在这里插入图片描述
在这里插入图片描述

测试自己手写的图片

image_file = Image.open(".//mynum.png") # open colour image
image_file = image_file.resize((28,28))
image_file = image_file.convert('L') # convert image to black and white
image_file = np.array(image_file,dtype=np.uint8)
image_file = image_file.reshape(1,784)
# image_file=image_file-255
mypred = model.predict(image_file)
print(mypred)
plt.imshow(Image.open(".//mynum.png"))
plt.title(str(mypred[0]))
plt.savefig('./write.png')
plt.show()

将手写的数字2处理成1x784的大小后,输入到训练好的模型中,得到结果如下:
在这里插入图片描述
在这里插入图片描述

参考博客

https://blog.csdn.net/Brinshy/article/details/122483150

  • 11
    点赞
  • 37
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值