TF2.x的keras模型保存与加载
传送门:官方文档
Keras模型包含多个组件:
- 模型的结构或者配置文件,表明模型包含哪些网络层以及各层之间的连接方式。
- 当前状态模型的参数。
- 模型的optimizer,在complie里进行定义的。
- 模型的损失函数和度量函数(在complile函数中定义的或者通过
add_loss()
、add_metric()
函数添加)。
通过Keras的API可以将上述的所有组件保存成一个文件或者选择性的保存其中某些组件:
- 以
Tensorflow SavedModel格式
或者Keras H5格式
将整个模型保存为一个文件。 - 以
JSON
文件形式保存模型的结构或者配置。 - 只保留模型的权重,通常在训练模型的过程中使用。
Keras模型的保存与读取
保存 | 加载 |
---|---|
model.save() |
tf.keras.models.load_model() |
tf.keras.models.save_model() |
tf.keras.models.load_model() |
model.save_weights() |
model.load_weights() |
tf.saved_model.save() |
tf.saved_model.load() |
整个模型的保存与加载
- 模型的结构和配置
- 通过训练学习到的权重
- 模型的编译信息(如果保存前有调用
model.compile
)
APIs
model.save()
或者tf.keras.models.save_model()
tf.keras.models.load_model()
使用model.save()
或者tf.keras.models.save_model()
此种方式可以Keras H5格式
或者Tensorflow SavedModel格式
保存整个模型,在TF2.x
版本中默认以SavedModel格式
保存,如果想要使用Keras H5格式
,可以通过以下形式进行保存:
- 在
model.save()
函数中传递参数saved_format='h5'
; - 在
model.save()
函数传递文件名参数时以.h5
或者.keras
结尾。
SavedModel
格式
import tensorflow as tf
model = tf.keras.applications.ResNet50()