【Google 机器学习笔记】
十、TensorFlow 2.1 实战(二)基本图像分类
为节省时间,降低学习成本,本节实战的图片分类对象 tf.keras 中内置的 MNIST 数据集。
首先回顾机器学习编程的几个基本步骤
1. 数据
① 获取数据
② 处理数据
③ 拆分数据
④ 检查数据
2. 模型
① 构建模型
② 检查模型
③ 训练模型
④ 进行预测
现根据以上步骤进行实战训练
数据
获取数据
# 首先先导入 TF
import tensorflow as tf
# TF 的高级 API —— Keras
from tensorflow import keras
from tensorflow.keras import layers
# 用到的第三方包
import numpy as np
import matplotlib.pyplot as plt
# TensorFlow 中为方便机器学习,内置了很多数据集在 tf.keras.datasets 中,
# 而且已经将数据处理并拆分好了。具体使用方法可前往官网查看 API 文档
# 获取数据
(train_images, train_labels), (test_images, test_labels) = keras.datasets.mnist.load_data()
处理数据
处理数据前,我们先认识一下这份数据集。先看看 train_images 的情况,
train_images.shape
(60000, 28, 28)
这里说明 train_images 中包含 60000 张图片,每张图片都是 28*28 像素。
再看看 train_labels 的情况,
train_labels.shape
(60000,)
共 60000 个标签,一一对应 train_images 中的图片。
你也可以通过 train_lables
看看 train_labels 中的内容,结果应该是:
array([5, 0, 4, ..., 5, 6, 8], dtype=uint8)
对于 test_images、test_labels,你也可以通过类似的方法了解这份数据。
test_images.shape
(10000, 28, 28)
test_labels.shape
(10000,)
OK,我们已经基本知道这个数据集的结构了,但是对于 train_images 中的图片的情况还不是很了解,这时我们可以使用 Numpy 中的 plt 来查看这些图片。
plt.figure()
plt.imshow(train_images[0])
plt.colorbar()
plt.grid(False)
plt.xlabel(train_labels[0])
plt.show()
这是 train_images 中的第一张图片,数字 5 (第一眼我没看出来这是 5 …),颜色范围 0~255,是一个离散值。
之前讲到,如果数据是一个离散值,一般将其缩放至 [-1,1],这样可以加快模型训练速度。这里我们采用线性缩放的方式将其缩放至 [0,1]。
train_images = train_images / 255.0
test_images = test_images / 255.0
至此数据处理完毕
拆分数据
此步骤在 “获取数据” 已完成,略过。
检查数据
为了确保数据的正确性,我们需要在将它放进模型训练之前检查这份数据的正确性。
plt