四段代码实现参数保存、下载、续训练
1.设置ckpt文件保存路径
其中保存有模型参数(特定文件类型
checkpoint_save_path = "checkpoint/mnist.ckpt"
2.判断是否有索引,如果保存了参数模型,会有相应索引文件
if os.path.exists(checkpoint_save_path + '.index'):
print("---------------loading the model---------------")
model.load_weights(checkpoint_save_path)
3.callback定义模型参数保存相关内容
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
# save_best_only=True,
save_weights_only=True
)
4.在训练中加入callback保存模型参数,创建checkpoint文件
model.fit(x_train, y_train, epochs=5, callbacks=[cp_callback])
5.源码
训练fashion_mnist全连接神经网络
import os
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
# # callback类
# class myCallback(tf.keras.callbacks.Callback):
# def on_epoch_end(self, epoch, logs=None):
#
# if logs.get('loss') < 0.4:
# print("\n Loss is low so cancelling training!")
# self.model.stop_training = True
#
#
# # 实例化类
# callback = myCallback()
(x_train, y_train), (x_test, y_test) = keras.datasets.fashion_mnist.load_data()
# print(x_train[0].shape)
x_train = x_train / 255.0
y_train = y_train / 255.0
# plt.imshow(x_train[0])
# print(x_train[0])
# print(y_train[0])
# plt.show()
model = keras.models.Sequential([
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(512, activation=tf.nn.relu),
tf.keras.layers.Dense(10, activation=tf.nn.softmax)
])
# 断点续训
# 寻找路径,没有则生成
checkpoint_save_path = "checkpoint/mnist.ckpt"
# 判断是否存在索引
if os.path.exists(checkpoint_save_path + '.index'):
print("---------------loading the model---------------")
# 从路径中下载模型
model.load_weights(checkpoint_save_path)
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
# save_best_only=True,
save_weights_only=True
)
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')
model.fit(x_train, y_train, epochs=5, callbacks=[cp_callback])