源码
"""
# -*- coding: utf-8 -*-
# @Time : 2023/6/13 17:31
# @Author : 王摇摆
# @FileName: Multiclass.py
# @Software: PyCharm
# @Blog :https://blog.csdn.net/weixin_44943389?type=blog
"""
# 导入数据
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
digits = load_digits()
# 数据集切分
X_train, X_test, y_train, y_test = train_test_split(digits['data'], digits['target'], test_size=0.2)
fig, axes = plt.subplots(10, 16, figsize=(8, 8))
fig.subplots_adjust(hspace=0.1, wspace=0.1)
for i, ax in enumerate(axes.flat):
ax.imshow(X_train[i, :].reshape(8, 8), cmap='binary', interpolation='nearest')
ax.text(0.05, 0.05, str(y_train[i]),
transform=ax.transAxes, color='blue')
ax.set_xticks([])
ax.set_yticks([])
plt.show()
学习
这段代码的目标是加载手写数字数据集(load_digits()
),并显示训练集中的一部分样本图像。
以下是代码的解释:
- 导入所需的库:
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
- 加载手写数字数据集:
digits = load_digits()
- 切分数据集为训练集和测试集:
X_train, X_test, y_train, y_test = train_test_split(digits['data'], digits['target'], test_size=0.2)
- 创建一个包含10行16列的图像子图,并设置图像之间的间距和尺寸:
fig, axes = plt.subplots(10, 16, figsize=(8, 8))
fig.subplots_adjust(hspace=0.1, wspace=0.1)
- 使用循环遍历每个子图,显示训练集中的图像和对应的标签:
for i, ax in enumerate(axes.flat):
ax.imshow(X_train[i, :].reshape(8, 8), cmap='binary', interpolation='nearest')
ax.text(0.05, 0.05, str(y_train[i]), transform=ax.transAxes, color='blue')
ax.set_xticks([])
ax.set_yticks([])
ax.imshow()
用于在子图中显示图像。X_train[i, :].reshape(8, 8)
将图像数据从一维数组形状转换为8x8的二维数组形状,并使用二进制颜色映射显示图像。ax.text()
用于在子图中添加文本,显示对应的标签。ax.set_xticks([])
和ax.set_yticks([])
用于隐藏子图中的刻度标签。
- 最后,使用
plt.show()
显示图像子图:
plt.show()
运行代码后,应该会显示一个窗口,其中包含一部分训练集图像及其标签。您可以通过拖动窗口或使用放大/缩小选项来查看图像。请注意,plt.show()
会阻塞代码的执行,直到您关闭显示窗口。
实验结果
.