MNIST数据集相关知识可参考:MNIST + tf.layers
import tensorflow as tf
# 载入数据
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
# 构建模型
layer = tf.keras.layers
model = tf.keras.Sequential()
model.add(layer.Flatten())
model.add(layer.Dense(128, activation='relu'))
model.add(layer.Dropout(0.2))
model.add(layer.Dense(10, activation='softmax'))
# 设置模型流程
model.compile(optimizer=tf.train.AdamOptimizer(),
loss=tf.keras.losses.categorical_crossentropy,
metrics=[tf.keras.metrics.categorical_accuracy])
# 使用 fit 方法使模型与训练数据“拟合
model.fit(x_train, y_train, epochs=5)
# 评估和预测
model.evaluate(x_test, y_test)
# 模型权重保存,保存为tensorflow的checkpoint形式
model.save_weights('./weights/my_model')
# 模型权重恢复
# model.load_weights('./weights/my_model')