TensorFlow 2.0 模型的保存和恢复

本文详细介绍了TensorFlow 2.0中模型的五种保存方法:整体保存(包括框架和权重)、仅保存架构、权重保存、使用回调函数在训练期间保存检查点,以及自定义保存模型的检查点。通过这些方法,可以在不依赖原始代码的情况下恢复模型状态,便于训练的连续性和效率提升。
摘要由CSDN通过智能技术生成

共包含五终保存的方式

  1. 模型整体的保存(框架和权重)
  2. 保存模型的框架(代码的结构)
  3. 保存模型的权重
  4. 使用回调函数对模型进行保存
  5. 自定义训练模型的保存

1 模型整体的保存

整个模型可以保存到一个文件中,其中包含权重值、模型配置乃至优化器配置。可以为模型设置检查点,并稍后从完全相同的状态继续训练,而无需访问原始代码。
Keras使用HDF5标准提供基本的保存格式

#保存模型的代码
model.save('./save_model/less_model.h5')

#加载保存的模型
new_model = tf.keras.models.load_model('./save_model/less_model.h5')

此方法主要保存一下所有内容:

  • 权重值
  • 模型配置(架构)
  • 优化器配置

2 仅保存架构

如果对模型的架构感兴趣,而无需保存权重值和优化器。可以仅保存模型的‘配置’

#保存
json_model = model.to_json()


#加载
re_model = tf.keras.models.model_from_json(json_model)

#查看网络结构
re_model.summary()

3 保存模型权重

如果只需要保存模型的状态(权重值),可以通过get_weights()获取权重值,并通过set_weights()设置权重值

#获得权重
weights = model.get_weights()

#保存权重
model.save_weights('./save_model/less_weights.h5')


#加载权重
re_model.load_weights('./save_model/less_weights.h5')

4 回调函数在训练期间保存检查点

在训练期间或者结束时自动保存检查点。便于使用经过训练的模型,而无需重新训练该模型,或者从上次暂停的地方继续训练,以防止训练中断
回调函数:tf.keras.callback.ModelCheckpoint()

#保存的代码

model.compile(optimizer='adam',
          loss='mse',
          metrics=['acc'])

checkpoint_save_path = "./model_save/derain.ckpt" # 保存的路径

cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,s
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值