实战FashionMNIST数据集分类任务之数据处理及验证

该文详细介绍了如何加载GZ格式的MNIST数据集,包括读取训练和测试数据,并对数据进行了预处理,如归一化到0-1范围,以便适应常见的神经网络激活函数。此外,还展示了数据的可视化和标签映射,为后续的图像识别模型训练做好准备。
摘要由CSDN通过智能技术生成

1、加载数据

本数据集是GZ格式,以下使用了加载GZ格式数据集的方法

import os
import gzip
import numpy as np
import matplotlib.pyplot as plt
 
#加载数据
def load_data(data_file):
    files = ['train-labels-idx1-ubyte.gz', 'train-images-idx3-ubyte.gz', 't10k-labels-idx1-ubyte.gz', 't10k-images-idx3-ubyte.gz']
    paths = []
    for fileName in files:
        paths.append(os.path.join(data_file, fileName))
        
    # 读取每个文件夹的数据    
    with gzip.open(paths[0], 'rb') as train_labels_path:
        train_labels = np.frombuffer(train_labels_path.read(), np.uint8, offset=8)
      
    with gzip.open(paths[1], 'rb') as train_images_path:
        train_images = np.frombuffer(train_images_path.read(), np.uint8, offset=16).reshape(len(train_labels), 784)
       
    with gzip.open(paths[2], 'rb') as test_labels_path:
        test_labels = np.frombuffer(test_labels_path.read(), np.uint8, offset=8)
        
    with gzip.open(paths[3], 'rb') as test_images_path:
        test_images = np.frombuffer(test_images_path.read(), np.uint8, offset=16).reshape(len(test_labels), 784)
        
    return train_labels,train_images,test_labels,test_images
 
train_labels,train_images,test_labels,test_images = load_data('MNIST/')


2、预处理数据

先将第一个数据进行可视化,检查数据的正确性

plt.figure()
plt.imshow(train_images[0])
plt.colorbar()
plt.grid(False)
plt.show()

发现图像的像素值处于 0 到 255 之间,也就是说数据的范围都在0到255之间。

因激活函数通常是 sigmoid 或 ReLU,它们的输出范围是 [0, 1] 或 [-1, 1]。如果输入数据的像素值超出了这个范围,就会导致梯度消失或梯度爆炸的问题,从而影响模型的训练效果,所以要先进行归一化处理,将这些值缩小至 0 到 1 之间。

train_images = train_images / 255.0
test_images = test_images / 255.0

再次查看数据

plt.figure()
plt.imshow(train_images[0])
plt.colorbar()
plt.grid(False)
plt.show()


3、验证数据 

因图像是 28x28 的 NumPy 数组,标签是整数数组,介于 0 到 9 之间。这些标签对应于图像所代表的服装类别,由于数据集不包括类名称,所以将根据标签的整数自定义映射名称的数组。

标签类别映射名称
0T恤/上衣

T-shirt/top

1裤子

Trouser

2套头衫

Pullover

3连衣裙

Dress

4外套

Coat

5凉鞋

Sandal

6衬衫

Shirt

7运动鞋

Sneaker

8

Bag

9短靴

Ankle boot

class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

接下来验证数据集,显示训练集中的前 30个图像,并在每个图像下方显示类名称

plt.figure(figsize=(20,20))
for i in range(30):
    plt.subplot(10,10,i+1)
    plt.xticks([])
    plt.yticks([])
    plt.grid()
    plt.imshow(train_images[i], cmap=plt.cm.binary)
    plt.xlabel(class_names[train_labels[i]])
plt.show()

这段代码是用于显示训练集中的图像及其标签。其中,figsize 参数设置了图像的大小为 20x20。通过 for 循环遍历训练集中的所有图像,使用 plt.subplot() 函数将图像排列成一个 10x10的网格。然后使用 plt.imshow() 函数将每个图像以灰度图的形式显示出来,并使用 class_names[train_labels[i]] 作为每个图像的标签。最后使用 plt.xticks([]) 和 plt.yticks([]) 将 x 轴和 y 轴的刻度线去掉,避免出现不必要的刻度干扰。使用 plt.grid(False) 关闭网格线。调用 plt.show() 函数显示图像。

  • 2
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

缘起性空、

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值