“”“第五章 TensorFlow进阶(三) 加载经典数据集 minist”""
import tensorflow as tf
from tensorflow import keras
from keras import datasets
(x, y), (x_test, y_test) = datasets.mnist.load_data()
print(x.shape)
train_db = tf.data.Dataset.from_tensor_slices((x, y))
print(train_db)
随机打散
train_db = train_db.shuffle(10000) # 随机打散样本,不会打乱样本与标签映射关系
批训练
train_db = train_db.batch(128) # 设置批训练,batch size 为 128
print(x.shape)
预处理
def preprocess(x,y):
x = tf.cast(x, dtype=tf.float32)/255
x = tf.reshape(x, [-1, 28*28]) # 打平
y = tf.cast(y, dtype=tf.int32)
y = tf.one_hot(y, depth=10)
return x,y
train_db = train_db.map(preprocess(x,y))
Dataset对象通过提供map(func)工具函数,可以非常方便地调用用户自定义的预处理逻辑,它实现在func函数里。
循环训练
for epoch in range(20): # 训练 Epoch 数
for step, (x,y) in enumerate(train_db): # 迭代 Step 数