1. Using tf.keras.callbacks.ModelCheckpoint()
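Define a tf.keras.callbacks.ModelCheckpoint() callback and pass it to model.fit(). During training, the callback is invoked automatically and saves the training state, the model architecture, and the parameter weights. These are saved under the given checkpoint path (here, one ending in .ckpt).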
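To make the callback mechanism concrete, here is a toy sketch in plain Python of a training loop that notifies a checkpoint callback at the end of every epoch. It is only an illustration of the idea, not how Keras implements it: ToyCheckpoint, toy_fit, the JSON file format, and the "weight" state are all hypothetical names and choices made for this sketch.

```python
import json
import os
import tempfile


class ToyCheckpoint:
    """Hypothetical stand-in for a checkpoint callback (not the Keras class)."""

    def __init__(self, filepath):
        self.filepath = filepath
        self.saved_epochs = []

    def on_epoch_end(self, epoch, state):
        # Make sure the target directory exists, then overwrite the file.
        os.makedirs(os.path.dirname(self.filepath), exist_ok=True)
        with open(self.filepath, "w") as f:
            json.dump({"epoch": epoch, "state": state}, f)
        self.saved_epochs.append(epoch)


def toy_fit(state, epochs, callbacks):
    """Hypothetical training loop: update the state, then notify each callback."""
    for epoch in range(epochs):
        state["weight"] += 1.0  # placeholder for a real optimization step
        for cb in callbacks:
            cb.on_epoch_end(epoch, state)
    return state


# Write the toy checkpoint into a temporary directory.
checkpoint_path = os.path.join(tempfile.mkdtemp(), "toy_ckpt_logs", "model.json")
cp = ToyCheckpoint(checkpoint_path)
final_state = toy_fit({"weight": 0.0}, epochs=2, callbacks=[cp])

# Reload what the callback saved after the last epoch.
with open(checkpoint_path) as f:
    restored = json.load(f)
```

The training loop never saves anything itself. It only calls the hooks it was given, which is why passing the callback in the callbacks list is enough to get a checkpoint written after each epoch.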
Example:
import os
import tensorflow as tf
from tensorflow import keras
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
train_labels = train_labels[:1000]
test_labels = test_labels[:1000]
train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0
test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0
# Define a simple sequential model
def create_model():
    model = tf.keras.models.Sequential([
        keras.layers.Dense(512, activation='relu', input_shape=(784,)),
        keras.layers.Dropout(0.2),
        keras.layers.Dense(10)
    ])
    batch_size, num_classes = train_labels.shape[0], 10
    # The last layer outputs raw logits, so the loss needs from_logits=True
    model.compile(optimizer='adam',
                  loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  metrics=['accuracy'])
    return model
# Create a basic model instance
model = create_model()
# Display the model's architecture
model.summary()
checkpoint_path = "tf_ckpt_logs/model.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)
# Create a callback that saves the model at the end of each epoch
# (pass save_weights_only=True to save only the weights)
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                 verbose=1)
# Train the model with the new callback
model.fit(train_images,
          train_labels,
          epochs=2,
          validation_data=(test_images, test_labels),
          callbacks=[cp_callback])  # Pass callback to training
# This may generate warnings related to saving the state of the optimizer.
# These warnings (and similar warnings throughout this notebook)
# are in place to discourage outdated usage, and can be ignored.
# Create a basic model instance
model = create_model()
# Evalua