Tensorflow2.x框架-MNIST数据集

MNIST数据集

博主微信公众号(左)、Python+智能大数据+AI学习交流群(右):欢迎关注和加群,大家一起学习交流,共同进步!

目录

摘要

一、导入模块

二、导入 MNIST 数据集

三、作为输入特征,输入神经网络时,将数据拉伸为一维数组

四、可视化训练集输入特征的第一个元素

五、打印出训练集输入特征的第一个元素

六、打印出训练集标签的第一个元素

七、打印出训练集(输入特征、标签)、测试集(输入特征、标签)的形状

八、完整代码

九、Sequential() 实现手写数字识别

十、Class() 实现手写数字识别


摘要

MNIST 数据集:

    提供 6w 张 28*28 像素点的 0~9 手写数字图片和标签,用于训练。

    提供 1w 张 28*28 像素点的 0~9 手写数字图片个标签,用于测试。

一、导入模块

# 模块导入
import tensorflow as tf
from matplotlib import pyplot as plt

二、导入 MNIST 数据集

# 导入数据集,分别为输入特征和标签
mnist = tf.keras.datasets.mnist
# (x_train, y_train):(训练集输入特征,训练集标签)
# (x_test, y_test):(测试集输入特征,测试集标签)
(x_train, y_train), (x_test, y_test) = mnist.load_data()

三、作为输入特征,输入神经网络时,将数据拉伸为一维数组

tf.keras.layers.Flatten()    # 将输入特征拉直为一维数组,也就是拉直为28*28=784个数值

把28*28=784个像素点的灰度值作为输入特征送入神经网络。

四、可视化训练集输入特征的第一个元素

# 可视化训练集输入特征的第一个元素
plt.imshow(x_train[0], cmap="gray")     # 绘制灰度图
plt.show()

可视化后的图片:一张黑底白字的手写数字图片5。

五、打印出训练集输入特征的第一个元素

# 打印出训练集输入特征的第一个元素
print(f"x_train[0]: \n {x_train[0]}")

输出结果:手写数字5的28行28列个像素值。0表示纯黑色,255表示纯白色。

x_train[0]: 
 [[  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   3  18  18  18 126 136
  175  26 166 255 247 127   0   0   0   0]
 [  0   0   0   0   0   0   0   0  30  36  94 154 170 253 253 253 253 253
  225 172 253 242 195  64   0   0   0   0]
 [  0   0   0   0   0   0   0  49 238 253 253 253 253 253 253 253 253 251
   93  82  82  56  39   0   0   0   0   0]
 [  0   0   0   0   0   0   0  18 219 253 253 253 253 253 198 182 247 241
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0  80 156 107 253 253 205  11   0  43 154
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0  14   1 154 253  90   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0 139 253 190   2   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0  11 190 253  70   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0  35 241 225 160 108   1
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0  81 240 253 253 119
   25   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0  45 186 253 253
  150  27   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0  16  93 252
  253 187   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0 249
  253 249  64   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0  46 130 183 253
  253 207   2   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0  39 148 229 253 253 253
  250 182   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0  24 114 221 253 253 253 253 201
   78   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0  23  66 213 253 253 253 253 198  81   2
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0  18 171 219 253 253 253 253 195  80   9   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0  55 172 226 253 253 253 253 244 133  11   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0 136 253 253 253 212 135 132  16   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]]

六、打印出训练集标签的第一个元素

# 打印出训练集标签的第一个元素
print(f"y_train[0]: \n {y_train[0]}")

输出结果:

y_train[0]: 
 5

七、打印出训练集(输入特征、标签)、测试集(输入特征、标签)的形状

# 打印出整个训练集输入特征的形状
print(f"x_train.shape: \n {x_train.shape}")
# 打印出整个训练集标签的形状
print(f"y_train.shape: \n {y_train.shape}")
# 打印出整个测试集输入特征的形状
print(f"x_test.shape: \n {x_test.shape}")
# 打印出整个测试集标签的形状
print(f"y_test.shape: \n {y_test.shape}")

输出结果:

    x_train —— 60000 个 28 行 28 列的数据

    y_train —— 60000 个标签

    x_test —— 10000 个 28 行 28 列的数据

    y_test —— 10000 个标签

x_train.shape: 
 (60000, 28, 28)
y_train.shape: 
 (60000,)
x_test.shape: 
 (10000, 28, 28)
y_test.shape: 
 (10000,)

八、完整代码

# 模块导入
import tensorflow as tf
from matplotlib import pyplot as plt

# 导入数据集,分别为输入特征和标签
mnist = tf.keras.datasets.mnist
# (x_train, y_train):(训练集输入特征,训练集标签)
# (x_test, y_test):(测试集输入特征,测试集标签)
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# 可视化训练集输入特征的第一个元素
plt.imshow(x_train[0], cmap="gray")     # 绘制灰度图
plt.show()

# 打印出训练集输入特征的第一个元素
print(f"x_train[0]: \n{x_train[0]}")
# 打印出训练集标签的第一个元素
print(f"y_train[0]: \n {y_train[0]}")

# 打印出整个训练集输入特征的形状
print(f"x_train.shape: \n {x_train.shape}")
# 打印出整个训练集标签的形状
print(f"y_train.shape: \n {y_train.shape}")
# 打印出整个测试集输入特征的形状
print(f"x_test.shape: \n {x_test.shape}")
# 打印出整个测试集标签的形状
print(f"y_test.shape: \n {y_test.shape}")

九、Sequential() 实现手写数字识别

# 模块导入
import tensorflow as tf

# 导入数据集,分别为输入特征和标签
mnist = tf.keras.datasets.mnist
# (x_train, y_train):(训练集输入特征,训练集标签)
# (x_test, y_test):(测试集输入特征,测试集标签)
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# 对输入网络的输入特征进行归一化,使原本0到255之间的灰度值,变为0到1之间的数值
# (把输入特征的数值变小更适合神经网络吸收)
x_train, x_test = x_train / 255.0, x_test / 255.0

# 搭建网络结构
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),  # 将输入特征拉直为一维数组,也就是拉直为28*28=784个数值
    tf.keras.layers.Dense(128, activation="relu"),  # 第一层网络128个神经元,使用relu激活函数
    tf.keras.layers.Dense(10, activation="softmax")     # 第二层网络10个神经元,使用softmax激活函数,使输出符合概率分布
])

# 配置训练方法
model.compile(optimizer="adam",     # 优化器
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),    # 损失函数,输出是概率分布,from_logits=False
              metrics=["sparse_categorical_accuracy"])  # 数据集中的标签是数值,神经网络输出y是概率分布

# 执行训练过程
model.fit(x_train,  # 训练集输入特征
          y_train,  # 训练集标签
          batch_size=32,    # 每次喂入网络32组数据
          epochs=5,     # 数据集迭代5次
          validation_data=(x_test, y_test),     # 测试集输入特征,测试集标签
          validation_freq=1)     # 每迭代1次训练集执行一次测试集的评测

# 打印出网络结构和参数统计
model.summary()

十、Class() 实现手写数字识别

# 模块导入
import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import Dense, Flatten

# 导入数据集,分别为输入特征和标签
mnist = tf.keras.datasets.mnist
# (x_train, y_train):(训练集输入特征,训练集标签)
# (x_test, y_test):(测试集输入特征,测试集标签)
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# 对输入网络的输入特征进行归一化,使原本0到255之间的灰度值,变为0到1之间的数值
# (把输入特征的数值变小更适合神经网络吸收)
x_train, x_test = x_train / 255.0, x_test / 255.0


# 搭建网络结构
class MnistModel(Model):
    def __init__(self):
        super(MnistModel, self).__init__()
        self.flatten = Flatten()     # 将输入特征拉直为一维数组,也就是拉直为28*28=784个数值
        self.d1 = Dense(128, activation="relu")  # 第一层网络128个神经元,使用relu激活函数
        self.d2 = Dense(10, activation="softmax")     # 第二层网络10个神经元,使用softmax激活函数,使输出符合概率分布

    def call(self, x):
        x = self.flatten(x)
        x = self.d1(x)
        y = self.d2(x)
        return y


model = MnistModel()

# 配置训练方法
model.compile(optimizer="adam",     # 优化器
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),    # 损失函数,输出是概率分布,from_logits=False
              metrics=["sparse_categorical_accuracy"])  # 数据集中的标签是数值,神经网络输出y是概率分布

# 执行训练过程
model.fit(x_train,  # 训练集输入特征
          y_train,  # 训练集标签
          batch_size=32,    # 每次喂入网络32组数据
          epochs=5,     # 数据集迭代5次
          validation_data=(x_test, y_test),     # 测试集输入特征,测试集标签
          validation_freq=1)     # 每迭代1次训练集执行一次测试集的评测

# 打印出网络结构和参数统计
model.summary()

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值