1、加载数据
本数据集是GZ格式,以下使用了加载GZ格式数据集的方法
import os
import gzip
import numpy as np
import matplotlib.pyplot as plt
#加载数据
def load_data(data_file):
files = ['train-labels-idx1-ubyte.gz', 'train-images-idx3-ubyte.gz', 't10k-labels-idx1-ubyte.gz', 't10k-images-idx3-ubyte.gz']
paths = []
for fileName in files:
paths.append(os.path.join(data_file, fileName))
# 读取每个文件夹的数据
with gzip.open(paths[0], 'rb') as train_labels_path:
train_labels = np.frombuffer(train_labels_path.read(), np.uint8, offset=8)
with gzip.open(paths[1], 'rb') as train_images_path:
train_images = np.frombuffer(train_images_path.read(), np.uint8, offset=16).reshape(len(train_labels), 784)
with gzip.open(paths[2], 'rb') as test_labels_path:
test_labels = np.frombuffer(test_labels_path.read(), np.uint8, offset=8)
with gzip.open(paths[3], 'rb') as test_images_path:
test_images = np.frombuffer(test_images_path.read(), np.uint8, offset=16).reshape(len(test_labels), 784)
return train_labels,train_images,test_labels,test_images
train_labels,train_images,test_labels,test_images = load_data('MNIST/')
2、预处理数据
先将第一个数据进行可视化,检查数据的正确性
plt.figure()
plt.imshow(train_images[0])
plt.colorbar()
plt.grid(False)
plt.show()
发现图像的像素值处于 0 到 255 之间,也就是说数据的范围都在0到255之间。
因激活函数通常是 sigmoid 或 ReLU,它们的输出范围是 [0, 1] 或 [-1, 1]。如果输入数据的像素值超出了这个范围,就会导致梯度消失或梯度爆炸的问题,从而影响模型的训练效果,所以要先进行归一化处理,将这些值缩小至 0 到 1 之间。
train_images = train_images / 255.0
test_images = test_images / 255.0
再次查看数据
plt.figure()
plt.imshow(train_images[0])
plt.colorbar()
plt.grid(False)
plt.show()
3、验证数据
因图像是 28x28 的 NumPy 数组,标签是整数数组,介于 0 到 9 之间。这些标签对应于图像所代表的服装类别,由于数据集不包括类名称,所以将根据标签的整数自定义映射名称的数组。
标签 | 类别 | 映射名称 |
0 | T恤/上衣 | T-shirt/top |
1 | 裤子 | Trouser |
2 | 套头衫 | Pullover |
3 | 连衣裙 | Dress |
4 | 外套 | Coat |
5 | 凉鞋 | Sandal |
6 | 衬衫 | Shirt |
7 | 运动鞋 | Sneaker |
8 | 包 | Bag |
9 | 短靴 | Ankle boot |
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
接下来验证数据集,显示训练集中的前 30个图像,并在每个图像下方显示类名称
plt.figure(figsize=(20,20))
for i in range(30):
plt.subplot(10,10,i+1)
plt.xticks([])
plt.yticks([])
plt.grid()
plt.imshow(train_images[i], cmap=plt.cm.binary)
plt.xlabel(class_names[train_labels[i]])
plt.show()
这段代码是用于显示训练集中的图像及其标签。其中,figsize 参数设置了图像的大小为 20x20。通过 for 循环遍历训练集中的所有图像,使用 plt.subplot() 函数将图像排列成一个 10x10的网格。然后使用 plt.imshow() 函数将每个图像以灰度图的形式显示出来,并使用 class_names[train_labels[i]] 作为每个图像的标签。最后使用 plt.xticks([]) 和 plt.yticks([]) 将 x 轴和 y 轴的刻度线去掉,避免出现不必要的刻度干扰。使用 plt.grid(False) 关闭网格线。调用 plt.show() 函数显示图像。