保存与加载Keras训练好的模型

简介

可以在训练过程中和训练完成后保存模型,这样就可以很方便地恢复和重用模型,节省模型训练时间。

这样也便于别人使用你的模型,一般有两种方式共享模型:

  • 创建模型的源码
  • 训练好的模型(包括权重、参数等)

这里主要使用第二种方式。

使用的框架是TensorFlow2.4的高阶API:Keras进行模型训练。

验证环境

假设你已经安装好了TensorFlow2.4的运行环境。

如未安装,请稳步 install

安装依赖:

pip install -q pyyaml h5py # Required to save models in HDF5 format

运行以下代码验证:

import os

import tensorflow as tf
from tensorflow import keras

print(tf.version.VERSION)

输出:

2.4.1

训练模型

使用mnist数据集进行数字分类,代码如下:

(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

train_labels = train_labels[:1000]
test_labels = test_labels[:1000]

train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0
test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0

# Define a simple sequential model
def create_model():
  model = tf.keras.models.Sequential([
    keras.layers.Dense(512, activation='relu', input_shape=(784,)),
    keras.layers.Dropout(0.2),
    keras.layers.Dense(10)
  ])

  model.compile(optimizer='adam',
                loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True),
                metrics=[tf.metrics.SparseCategoricalAccuracy()])

  return model

# Create a basic model instance
model = create_model()

# Display the model's architecture
model.summary()

以上代码即可完成模型的训练(为了说明问题,只使用了前1000个元素,节省时间),可以从输出中看到训练过程和结果。

训练中保存快照

可以在训练过程中保存模型,以便后续继续执行。

这需要使用回调函数:tf.keras.callbacks.ModelCheckpoint。

创建回调的代码如下:

checkpoint_path = "training_1/cp.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

# Create a callback that saves the model's weights
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                 save_weights_only=True,
                                                 verbose=1)

# Train the model with the new callback
model.fit(train_images, 
          train_labels,  
          epochs=10,
          validation_data=(test_images, test_labels),
          callbacks=[cp_callback])  # Pass callback to training

# This may generate warnings related to saving the state of the optimizer.
# These warnings (and similar warnings throughout this notebook)
# are in place to discourage outdated usage, and can be ignored.

这会创建一个TensorFlow节点文件,并在每轮训练后更新。

只要两个模型共用相同的网络结构,它们就可以共用权重。所以仅从权重恢复模型的时候,需要以原始模型相同的网络结构创建这个模型,再设置权重。

重建一个新的、未训练的模型,它的精度约10%:

# Create a basic model instance
model = create_model()

# Evaluate the model
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Untrained model, accuracy: {:5.2f}%".format(100 * acc))

加载刚才已经训练的权重并重新评估精确度,精度可恢复到原来的水平:

# Loads the weights
model.load_weights(checkpoint_path)

# Re-evaluate the model
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100 * acc))

也可以手动保存权重:

# Save the weights
model.save_weights('./checkpoints/my_checkpoint')

# Create a new model instance
model = create_model()

# Restore the weights
model.load_weights('./checkpoints/my_checkpoint')

# Evaluate the model
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100 * acc))
保存整个模型

这个是使用的比较多的方式。

一般会使用Keras训练好模型,保存为文件,再使用c++等方式加载模型,用于生产环境。

使用 model.save 就可以把模型的结构、权重和训练配置保存到单个文件或文件夹。

完整的模型可以保存为两种格式:

  • SavedModel,这是TF2.x的默认存储格式
  • HDF5
SavedModel格式

保存代码如下:

# Create and train a new model instance.
model = create_model()
model.fit(train_images, train_labels, epochs=5)

# Save the entire model as a SavedModel.
os.system('mkdir -p saved_model')
model.save('saved_model/my_model')

SavedModel格式生成的文件夹包括pb文件和TensorFlow节点文件。

加载也很简单:

new_model = tf.keras.models.load_model('saved_model/my_model')

# Check its architecture
new_model.summary()

# Evaluate the restored model
loss, acc = new_model.evaluate(test_images, test_labels, verbose=2)
print('Restored model, accuracy: {:5.2f}%'.format(100 * acc))

print(new_model.predict(test_images).shape)
HDF5格式

这是Keras的基础格式,保存代码如下:

# Create and train a new model instance.
model = create_model()
model.fit(train_images, train_labels, epochs=5)

# Save the entire model to a HDF5 file.
# The '.h5' extension indicates that the model should be saved to HDF5.
model.save('my_model.h5')

加载:

# Recreate the exact same model, including its weights and the optimizer
new_model = tf.keras.models.load_model('my_model.h5')

# Show the model architecture
new_model.summary()

loss, acc = new_model.evaluate(test_images, test_labels, verbose=2)
print('Restored model, accuracy: {:5.2f}%'.format(100 * acc))

Keras会检查网络结构,并保存模型相关的所有内容:

  • 权重值
  • 模型结构
  • 模型训练配置,也就是传入compile的参数
  • 优化器和状态

SavedModel和HDF5格式的关键区别在于:

  • HDF5使用对象配置保存模型框架
  • SavedModel保存的是可执行图

因此,SavedModel不用查询源码即可保存诸如子类模型和定制层等定制化的对象。而HDF5就要复杂一些。

具体的不再详述,可参考相关资料。

参考资料

https://tensorflow.google.cn/tutorials/keras/save_and_load

  • 3
    点赞
  • 36
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值