简介
可以在训练过程中和训练完成后保存模型,这样就可以很方便地恢复和重用模型,节省模型训练时间。
这样也便于别人使用你的模型,一般有两种方式共享模型:
- 创建模型的源码
- 训练好的模型(包括权重、参数等)
这里主要使用第二种方式。
使用的框架是TensorFlow2.4的高阶API:Keras进行模型训练。
验证环境
假设你已经安装好了TensorFlow2.4的运行环境。
如未安装,请稳步 install
安装依赖:
pip install -q pyyaml h5py # Required to save models in HDF5 format
运行以下代码验证:
import os
import tensorflow as tf
from tensorflow import keras
print(tf.version.VERSION)
输出:
2.4.1
训练模型
使用mnist数据集进行数字分类,代码如下:
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
train_labels = train_labels[:1000]
test_labels = test_labels[:1000]
train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0
test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0
# Define a simple sequential model
def create_model():
model = tf.keras.models.Sequential([
keras.layers.Dense(512, activation='relu', input_shape=(784,)),
keras.layers.Dropout(0.2),
keras.layers.Dense(10)
])
model.compile(optimizer='adam',
loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=[tf.metrics.SparseCategoricalAccuracy()])
return model
# Create a basic model instance
model = create_model()
# Display the model's architecture
model.summary()
以上代码即可完成模型的训练(为了说明问题,只使用了前1000个元素,节省时间),可以从输出中看到训练过程和结果。
训练中保存快照
可以在训练过程中保存模型,以便后续继续执行。
这需要使用回调函数:tf.keras.callbacks.ModelCheckpoint。
创建回调的代码如下:
checkpoint_path = "training_1/cp.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)
# Create a callback that saves the model's weights
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
save_weights_only=True,
verbose=1)
# Train the model with the new callback
model.fit(train_images,
train_labels,
epochs=10,
validation_data=(test_images, test_labels),
callbacks=[cp_callback]) # Pass callback to training
# This may generate warnings related to saving the state of the optimizer.
# These warnings (and similar warnings throughout this notebook)
# are in place to discourage outdated usage, and can be ignored.
这会创建一个TensorFlow节点文件,并在每轮训练后更新。
只要两个模型共用相同的网络结构,它们就可以共用权重。所以仅从权重恢复模型的时候,需要以原始模型相同的网络结构创建这个模型,再设置权重。
重建一个新的、未训练的模型,它的精度约10%:
# Create a basic model instance
model = create_model()
# Evaluate the model
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Untrained model, accuracy: {:5.2f}%".format(100 * acc))
加载刚才已经训练的权重并重新评估精确度,精度可恢复到原来的水平:
# Loads the weights
model.load_weights(checkpoint_path)
# Re-evaluate the model
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100 * acc))
也可以手动保存权重:
# Save the weights
model.save_weights('./checkpoints/my_checkpoint')
# Create a new model instance
model = create_model()
# Restore the weights
model.load_weights('./checkpoints/my_checkpoint')
# Evaluate the model
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100 * acc))
保存整个模型
这个是使用的比较多的方式。
一般会使用Keras训练好模型,保存为文件,再使用c++等方式加载模型,用于生产环境。
使用 model.save 就可以把模型的结构、权重和训练配置保存到单个文件或文件夹。
完整的模型可以保存为两种格式:
- SavedModel,这是TF2.x的默认存储格式
- HDF5
SavedModel格式
保存代码如下:
# Create and train a new model instance.
model = create_model()
model.fit(train_images, train_labels, epochs=5)
# Save the entire model as a SavedModel.
os.system('mkdir -p saved_model')
model.save('saved_model/my_model')
SavedModel格式生成的文件夹包括pb文件和TensorFlow节点文件。
加载也很简单:
new_model = tf.keras.models.load_model('saved_model/my_model')
# Check its architecture
new_model.summary()
# Evaluate the restored model
loss, acc = new_model.evaluate(test_images, test_labels, verbose=2)
print('Restored model, accuracy: {:5.2f}%'.format(100 * acc))
print(new_model.predict(test_images).shape)
HDF5格式
这是Keras的基础格式,保存代码如下:
# Create and train a new model instance.
model = create_model()
model.fit(train_images, train_labels, epochs=5)
# Save the entire model to a HDF5 file.
# The '.h5' extension indicates that the model should be saved to HDF5.
model.save('my_model.h5')
加载:
# Recreate the exact same model, including its weights and the optimizer
new_model = tf.keras.models.load_model('my_model.h5')
# Show the model architecture
new_model.summary()
loss, acc = new_model.evaluate(test_images, test_labels, verbose=2)
print('Restored model, accuracy: {:5.2f}%'.format(100 * acc))
Keras会检查网络结构,并保存模型相关的所有内容:
- 权重值
- 模型结构
- 模型训练配置,也就是传入compile的参数
- 优化器和状态
SavedModel和HDF5格式的关键区别在于:
- HDF5使用对象配置保存模型框架
- SavedModel保存的是可执行图
因此,SavedModel不用查询源码即可保存诸如子类模型和定制层等定制化的对象。而HDF5就要复杂一些。
具体的不再详述,可参考相关资料。
参考资料
https://tensorflow.google.cn/tutorials/keras/save_and_load