机器学习-sklearn-classification, Recognize hand-written代码分析

本文介绍了如何使用scikit-learn库中的SVM模型对手写数字进行识别,包括数据预处理(如图片扁平化和灰度化),训练过程,以及评估模型性能的分类报告和混淆矩阵。
摘要由CSDN通过智能技术生成

import matplotlib.pyplot as plt

from sklearn import datasets, metrics, svm
from sklearn.model_selection import train_test_split

# 加载图片数据集(1797个8*8像素的手写数字图片)
digits = datasets.load_digits()

# 下划线“_”表示我们忽略了第一个返回值,它是一组包含所有子图的对象。
# 第二个返回值 axes 是一个包含所有子图对象的数组。
# 通过 axes 数组,我们可以访问子图对象
_,axes = plt.subplots(nrows=1, ncols=4, figsize=(10, 3))

# 真值图片展示,利用zip一一对应,展示其中4张
for ax, image, label in zip(axes, digits.images, digits.target):
    ax.set_axis_off() # 隐藏图表的轴线和刻度标签,使图表更加简洁。
    ax.imshow(image, cmap=plt.cm.gray_r, interpolation="nearest")  # 用于显示图像,将二维数组表示的图像数据绘制成图像。
    ax.set_title("Training: %i" % label) # 打上标签

# 扁平化图片
n_samples = len(digits.images) # 获取样本数量

# 这一行代码将图片的二维数组(形状为 [n_samples, 8, 8])重塑为一个一维的特征向量数组。
# reshape 函数的第一个参数是目标形状元组 (n_samples, -1),其中 n_samples 表示样本数量,而 -1 表示自动计算该维度的大小,以使得总元素数量保持不变。
# 因此,data 数组的形状将变为 [n_samples, 64],即每个样本都变成了一个长度为 64 的一维特征向量。
data = digits.images.reshape((n_samples, -1))

# 创建一个支持向量机分类器(SVC)对象。
clf = svm.SVC(gamma=0.001)

# 划分训练集和测试集,55,shuffle=False,不对数据集进行洗牌
X_train, X_test, y_train, y_test = train_test_split(
    data, digits.target, test_size=0.5, shuffle=False
)

# 用训练集拟合模型
clf.fit(X_train, y_train)

# 利用训练完的模型对测试集进行预测
predicted = clf.predict(X_test)

# 展示预测结果中的四张结果
_, axes = plt.subplots(nrows=1, ncols=4, figsize=(10, 3))
for ax, image, prediction in zip(axes, X_test, predicted):
    ax.set_axis_off()
    image = image.reshape(8, 8) # 重构图片为8*8的数组,才能进行图片展示
    ax.imshow(image, cmap=plt.cm.gray_r, interpolation="nearest")
    ax.set_title(f"Prediction: {prediction}")

# 打印模型参数信息 生成分类器性能指标的报告。
print(
    f"Classification report for classifier {clf}:\n"
    f"{metrics.classification_report(y_test, predicted)}\n"
)

# 可视化分类器预测结果的混淆矩阵。
disp = metrics.ConfusionMatrixDisplay.from_predictions(y_test, predicted)
# plot 展示
disp.figure_.suptitle("Confusion Matrix")
# 控制台打印混淆矩阵
print(f"Confusion matrix:\n{disp.confusion_matrix}")

plt.show()

# 根据混淆矩阵 cm 的内容来重建真实标签列表 y_true 和预测标签列表 y_pred。
y_true = []
y_pred = []
cm = disp.confusion_matrix
for gt in range(len(cm)):
    for pred in range(len(cm)):
        y_true += [gt] * cm[gt][pred]
        y_pred += [pred] * cm[gt][pred]

print(
    "Classification report rebuilt from confusion matrix:\n"
    f"{metrics.classification_report(y_true, y_pred)}\n"
)

心得:

1.手写数字识别采用的模型是SVC,所以在进行模型训练前有一步很关键。就是对图片8*8的矩阵进行扁平化,变成1*64,这样才适用于该模型。除此以外,还有就是图片的灰度化处理(单通道),没有过多的其余颜色干扰。

2.通过这个示例代码,还能了解到了对训练完后的模型进行查看的一些方式。查看一些性能指标,这些在sklearn中都有提供对应的接口。

3.机器学习,对数据的预处理也是很关键的一步。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值