通过本篇blog,你将会学到
- 将所有内容以 TensorFlow SavedModel 格式(或较早的 Keras H5 格式)保存到单个归档。这是标准做法。
- 仅保存架构/配置,通常保存为 JSON 文件。
- 仅保存权重值。通常在训练模型时使用。
参考链接Tensorflow官方
保存和加载整个模型
模型的保存:
API:
model.save() 或 tf.keras.models.save_model()
参考此API,你将保存完整的模型架构、训练权重、优化器及其状态等各种信息
同时模型的保存有两种格式:
- TensorFlow SavedModel 格式(推荐使用/默认格式)
- Keras H5 格式(较早格式)
您可以通过以下方式切换到 H5 格式。
3. 将 save_format=‘h5’ 传递给 save()。
4. 将以 .h5 或 .keras 结尾的文件名传递给 save()。
模型的加载
API:
tf.keras.models.load_model()
举个例子:
如下所示,加载器动态地创建了一个与原始模型行为类似的新模型。
model和loaded为两个Model,一个是先前的自己写的模型。一个是加载保存后的模型。
class CustomModel(keras.Model):
def __init__(self, hidden_units):
super(CustomModel, self).__init__()
self.dense_layers = [keras.layers.Dense(u) for u in hidden_units]
def call(self, inputs):
x = inputs
for layer in self.dense_layers:
x = layer(x)
return x
model = CustomModel([16, 16, 10])
# Build the model by calling it
input_arr = tf.random.uniform((1, 5))
outputs = model(input_arr)
model.save("my_model")
# Delete the custom-defined model class to ensure that the loader does not have
# access to it.
del CustomModel
loaded = keras.models.load_model("my_model")
np.testing.assert_allclose(loaded(input_arr), outputs)
print("Original model:", model)
print("Loaded model:", loaded)
加载模型的参数
我们日常训练时,可能自己训练了老久的参数不想放弃,下次训练时接着上次的参数训练下去。
因此下面将讲述如何加载模型的参数,来实现接着上次训练结束后的参数训练。(可以节约很多时间)
- 模型的保存–参考上面模型的保存
- 权重值的加载(未看)
目前本人采取的方法: 采用if判断,如果有保存的模型,则加载之前保存的模型(即不再通过代码的模型),此时训练的参数一并会被加载。
如果没有保存的模型,则重新开始训练。
未完待续……