唐宇迪TensorFlow2.0入门到实战:神经网络分类任务MNIST手写数据集

MNIST数据集

  MNIST数据集是一个手写数字识别数据集,常用于机器学习和深度学习的训练与测试。这个数据集包含了大量的手写数字图像,包括0到9的数字。这些图像都是28x28像素的灰度图像,每个图像都对应着一个标签,表示图像中所包含的数字。

官网中http://yann.lecun.com/exdb/mnist/数据集一共分成了四个文件:

train-images-idx3-ubyte.gz:  training set images (9912422 bytes)
train-labels-idx1-ubyte.gz:  training set labels (28881 bytes)
t10k-images-idx3-ubyte.gz:   test set images (1648877 bytes)
t10k-labels-idx1-ubyte.gz:   test set labels (4542 bytes)

文件下载文件用途
train-images-idx3-ubyte.gz训练集图像
train-labels-idx1-ubyte.gz训练集标签
t10k-images-idx3-ubyte.gz测试集图像
t10k-labels-idx1-ubyte.gz测试集标签

训练集一共包含了 60,000 张图像和标签,而测试集一共包含了 10,000 张图像和标签。测试集中前5000个来自最初NIST项目的训练集.,后5000个来自最初NIST项目的测试集。前5000个比后5000个要规整,这是因为前5000个数据来自于美国人口普查局的员工,而后5000个来自于大学生。

读取 MNIST 数据集

根据 MNIST 数据集官网可知,读取数据集需要 offset,因为,在数据头部的数据存储了数据集的一些信息。

training set label file: 前 4 个字节为 魔术数,第 4-7 字节为数据的条数(number of items),因此需要 offset 8

training set images file: 前 4 个字节为 魔术数,第 4-7 字节为数据的条数(number of items),第 8-11 是每张图片的行数,第 12-15 是每张图片的列数, 因此需要 offset 16

test set label file: 前 4 个字节为 魔术数,第 4-7 字节为数据的条数(number of items),因此需要 offset 8

test set images file: 前 4 个字节为 魔术数,第 4-7 字节为数据的条数(number of items),第 8-11 是每张图片的行数,第 12-15 是每张图片的列数,因此需要 offset 16

import idx2numpy
import numpy as np
# 把gz文件先解压,再通过idx2numpy包读取文件
file1 = 'data/mnist/train-images.idx3-ubyte'
x_train = idx2numpy.convert_from_file(file1)

file2 = 'data/mnist/train-labels.idx1-ubyte'
y_train = idx2numpy.convert_from_file(file2)

file3 = 'data/mnist/t10k-images.idx3-ubyte'
x_valid = idx2numpy.convert_from_file(file3)

file4 = 'data/mnist/t10k-labels.idx1-ubyte'
y_valid = idx2numpy.convert_from_file(file4)

# Reshape data,变成一个矢量
x_train = x_train.reshape(-1, 28*28).astype('float32') / 255.0
x_valid = x_valid.reshape(-1, 28*28).astype('float32') / 255.0

显示图片

from matplotlib import pyplot 
import numpy as np

pyplot.imshow(x_train[0].reshape(28, 28), cmap="gray")
print(x_train.shape)
# print(y_train[0])

torch.Size([60000, 1, 28, 28]) #数据为28×28×1的灰度图

任务目标

输入图片,将该图片分类为0-9数字中的某一类

1. 将图片矩阵,拉成一个向量(已经通过reshape完成)

2. 构建模型(选择合适的loss函数,不同loss函数对标签的格式要求不一样,如tf.keras.losses.CategoricalCrossentropy需要one hot格式,而tf.keras.losses.SparseCategoricalCrossentropy则只需要一个数值分类标签即可

import tensorflow as tf
from tensorflow.keras import layers
model = tf.keras.Sequential()
model.add(layers.Dense(32, activation="relu"))
model.add(layers.Dense(32, activation="relu"))
model.add(layers.Dense(10, activation="softmax")) ## 得到10个分类各自的概率

选择和评估函数时候需要选择合适的API:https://www.tensorflow.org/api_docs/python/tf/keras/losses/CategoricalCrossentropy

选择合适的损失函数如tf.keras.losses.SparseCategoricalCrossentropy 和tf.keras.losses.CategoricalCrossentropy

model.compile(optimizer=tf.keras.optimizers.Adam(0.001),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(),
              metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

运行模型

model.fit(x_train, y_train, epochs=5, batch_size=64,
          validation_data=(x_valid, y_valid))

Epoch 1/5 938/938 ━━━━━━━━━━━━━━━━━━━━ 3s 2ms/step - loss: 0.7341 - sparse_categorical_accuracy: 0.7801 - val_loss: 0.2180 - val_sparse_categorical_accuracy: 0.9388 Epoch 2/5 938/938 ━━━━━━━━━━━━━━━━━━━━ 1s 1ms/step - loss: 0.1984 - sparse_categorical_accuracy: 0.9399 - val_loss: 0.1640 - val_sparse_categorical_accuracy: 0.9513 Epoch 3/5 938/938 ━━━━━━━━━━━━━━━━━━━━ 1s 1ms/step - loss: 0.1536 - sparse_categorical_accuracy: 0.9544 - val_loss: 0.1394 - val_sparse_categorical_accuracy: 0.9576 Epoch 4/5 938/938 ━━━━━━━━━━━━━━━━━━━━ 1s 1ms/step - loss: 0.1274 - sparse_categorical_accuracy: 0.9608 - val_loss: 0.1378 - val_sparse_categorical_accuracy: 0.9587 Epoch 5/5 938/938 ━━━━━━━━━━━━━━━━━━━━ 1s 1ms/step - loss: 0.1096 - sparse_categorical_accuracy: 0.9673 - val_loss: 0.1266 - val_sparse_categorical_accuracy: 0.9608

<keras.src.callbacks.history.History at 0x244444f6910>

将numpy数据转换为张量数据再进行建模分析

train = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train = train.batch(32)
train = train.repeat()

valid = tf.data.Dataset.from_tensor_slices((x_valid, y_valid))
valid = valid.batch(32)
valid = valid.repeat()

model.fit(train, epochs=5, steps_per_epoch=100, validation_data=valid, validation_steps=100)

  • 31
    点赞
  • 48
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值