目录
1. 模型保存的两种方式
- Checkpoints:只保存模型的参数,不保存模型的训练过程
- 使用
tf.keras.Model.save_weights
- 使用
tf.train.Checkpoint()
或者tf.train.CheckpointManager()
- 使用
- SavedModel format,保存完整的tensorflow程序,适用于模型部署 tensorflow serving
tf.saved_model
tf.keras.Model
TF2 模型保存与加载HDF5 和pb文件
2. 构建模型
import os
import tensorflow as tf
from tensorflow import keras
# 加载数据集
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
train_labels = train_labels[:1000]
test_labels = test_labels[:1000]
train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0
test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0
# Define a simple sequential model
def create_model():
model = tf.keras.models.Sequential([
keras.layers.Dense(512, activation='relu', input_shape=(784,)),
keras.layers.Dropout(0.2),
keras.layers.Dense(10)
])
model.compile(optimizer='adam',
loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=[tf.metrics.SparseCategoricalAccuracy()])
return model
3 模型保存
3.1 Checkpoints形式
在训练期间保存模型,您可以使用经过训练的模型而无需重新训练,或者在训练过程中断的情况下从离开处继续训练。tf.keras.callbacks.ModelCheckpoint
回调允许您在训练期间和结束时持续保存模型。
3.1.1 Checkpoint回调用法
# Include the epoch in the file name (uses `str.format`)
checkpoint_path = "training_2/cp-{epoch:04d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)
batch_size = 32
# Create a callback that saves the model's weights every 5 epochs
cp_callback = tf.keras.callbacks.ModelCheckpoint(
filepath=checkpoint_path,
verbose=1,
save_weights_only=True,
save_freq=5*batch_size)
# Create a new model instance
model = create_model()
# model.summary() # 打印模型结构
# Train the model with the new callback
model.fit(train_images,
train_labels,
epochs=50,
batch_size=batch_size,
callbacks=[cp_callback],
validation_data=(test_images, test_labels),
verbose=0)
只要两个模型共享相同的架构,您就可以在它们之间共享权重。因此,当从仅权重恢复模型时,创建一个与原始模型具有相同架构的模型,然后设置其权重。
latest = tf.train.latest_checkpoint(checkpoint_dir)
# Loads the weights
model.load_weights(latest)
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100 * acc))
3.1.2 手动保存权重
# Save the weights
model.save_weights('./checkpoints/my_checkpoint')
# Create a new model instance
model = create_model()
# Restore the weights
model.load_weights('./checkpoints/my_checkpoint')
# Evaluate the model
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100 * acc))
3.2 SavedModel format
3.2.1 模型保存
SavedModel
格式是另一种序列化模型的方式。以这种格式保存的模型可以使用 tf.keras.models.load_model
恢复,并且与 TensorFlow Serving 兼容
# Create and train a new model instance.
model = create_model()
model.fit(train_images, train_labels, epochs=5)
# Save the entire model as a SavedModel.
!mkdir -p saved_model
model.save('saved_model/my_model')
3.2.2 模型加载
new_model = tf.keras.models.load_model('saved_model/my_model')
# Check its architecture
new_model.summary()
# Evaluate the restored model
loss, acc = new_model.evaluate(test_images, test_labels, verbose=2)
print('Restored model, accuracy: {:5.2f}%'.format(100 * acc))
print(new_model.predict(test_images).shape)