使用 TensorFlow 构建图像分类模型:以 MNIST 为例详解流程
本文将通过手把手的方式,带你使用 TensorFlow 框架构建一个完整的图像分类模型,从加载数据到模型部署,全面了解每一个关键步骤。
🧠 为什么选择 MNIST?
MNIST 是一个手写数字图片的数据集,包含了 0 到 9 的灰度图像(28×28 像素),总共 70,000 张,分为 60,000 张训练集与 10,000 张测试集。它是图像分类入门最经典的数据集。
🧰 准备环境
确保你已经安装了 TensorFlow:
pip install tensorflow
📦 第一步:导入必要库
import tensorflow as tf
from tensorflow.keras import layers, models
import matplotlib.pyplot as plt
import numpy as np
📥 第二步:加载并查看数据
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
# 标准化处理
x_train, x_test = x_train / 255.0, x_test / 255.0
查看样本图像:
plt.figure(figsize=(6,6))
for i in range(9):
plt.subplot(3,3,i+1)
plt.imshow(x_train[i], cmap='gray')
plt.title(f"Label: {y_train[i]}")
plt.axis('off')
plt.show()
🏗 第三步:构建模型结构
我们使用 Keras 的 Sequential
API:
model = models.Sequential([
layers.Flatten(input_shape=(28, 28)), # 输入层:将28x28图像拉平为784维向量
layers.Dense(128, activation='relu'), # 隐藏层:128个神经元,ReLU 激活函数
layers.Dropout(0.2), # Dropout层:减少过拟合
layers.Dense(10, activation='softmax') # 输出层:10类数字,使用softmax激活函数
])
⚙ 第四步:编译模型
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
🏃♂️ 第五步:训练模型
history = model.fit(x_train, y_train, epochs=5, batch_size=32, validation_split=0.1)
epochs=5
:训练轮次为5validation_split=0.1
:留出10%作为验证集
📊 第六步:可视化训练过程
plt.plot(history.history['accuracy'], label='训练准确率')
plt.plot(history.history['val_accuracy'], label='验证准确率')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.title('训练过程准确率变化')
plt.show()
✅ 第七步:模型评估
test_loss, test_acc = model.evaluate(x_test, y_test)
print(f'测试集准确率: {test_acc:.4f}')
🔍 第八步:进行预测
predictions = model.predict(x_test)
# 查看前5张图的预测结果
for i in range(5):
plt.imshow(x_test[i], cmap='gray')
plt.title(f"预测值: {np.argmax(predictions[i])} / 实际值: {y_test[i]}")
plt.axis('off')
plt.show()
📁 第九步:保存与加载模型
# 保存
model.save('mnist_model.h5')
# 加载
new_model = tf.keras.models.load_model('mnist_model.h5')
🧩 第十步:下一步探索方向
你可以尝试:
- 使用卷积神经网络(CNN)提升准确率
- 引入数据增强(如旋转、缩放)
- 在其他图像数据集(如 CIFAR-10)上复用思路
✅ 总结
步骤 | 内容 |
---|---|
数据准备 | 下载 MNIST 并标准化 |
模型构建 | Sequential 模型堆叠 |
编译模型 | 设置损失函数、优化器 |
模型训练 | 使用 fit 方法进行训练 |
模型评估 | 在测试集上评估准确率 |
预测使用 | 模型进行数字识别预测 |
通过本案例,你已经掌握了图像分类的基本流程和 TensorFlow 的核心用法。继续加油,拥抱深度学习的广阔天地!