【Tensorflow教程笔记】常用模块 tf.train.Checkpoint :变量的保存与恢复

本文详细介绍了 TensorFlow 中如何使用 tf.train.Checkpoint 类来保存和恢复模型参数。通过实例展示了保存、加载模型的步骤,包括如何在即时执行模式下延迟恢复变量,以及如何利用 CheckpointManager 控制 Checkpoint 文件数量和自定义文件编号。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

  1. 基础
    TensorFlow 基础
    TensorFlow 模型建立与训练
    基础示例:多层感知机(MLP)
    卷积神经网络(CNN)
    循环神经网络(RNN)
    深度强化学习(DRL)
    Keras Pipeline
    自定义层、损失函数和评估指标
    常用模块 tf.train.Checkpoint :变量的保存与恢复
    常用模块 TensorBoard:训练过程可视化
    常用模块 tf.data :数据集的构建与预处理
    常用模块 TFRecord :TensorFlow 数据集存储格式
    常用模块 tf.function :图执行模式
    常用模块 tf.TensorArray :TensorFlow 动态数组
    常用模块 tf.config:GPU 的使用与分配

  2. 部署
    TensorFlow 模型导出
    TensorFlow Serving
    TensorFlow Lite

  3. 大规模训练与加速
    TensorFlow 分布式训练
    使用 TPU 训练 TensorFlow 模型

  4. 扩展
    TensorFlow Hub 模型复用
    TensorFlow Datasets 数据集载入

  5. 附录
    强化学习基础简介


Checkpoint 只保存模型的参数,不保存模型的计算过程,因此一般用于在具有模型源代码的时候恢复之前训练好的模型参数。如果需要导出模型(无需源代码也能运行模型),请参考 “部署” 章节中的 SavedModel

tf.train.Checkpoint

很多时候,我们希望在模型训练完成后能将训练好的参数(变量)保存起来。在需要使用模型的其他地方载入模型和参数,就能直接得到训练好的模型。可能你第一个想到的是用 Python 的序列化模块 pickle 存储 model.variables。但不幸的是,TensorFlow 的变量类型 ResourceVariable 并不能被序列化。

好在 TensorFlow 提供了 tf.train.Checkpoint 这一强大的变量保存与恢复类,可以使用其 save()restore() 方法将 TensorFlow 中所有包含 Checkpointable State 的对象进行保存和恢复。具体而言,tf.keras.optimizertf.Variabletf.keras.Layer 或者 tf.keras.Model 实例都可以被保存。其使用方法非常简单,我们首先声明一个 Checkpoint:

checkpoint = tf.train.Checkpoint(model=model)

这里 tf.train.Checkpoint() 接受的初始化参数比较特殊,是一个 **kwargs 。具体而言,是一系列的键值对,键名可以随意取,值为需要保存的对象。例如,如果我们希望保存一个继承 tf.keras.Model 的模型实例 model 和一个继承 tf.train.Optimizer 的优化器 optimizer ,我们可以这样写:

checkpoint = tf.train.Checkpoint(myAwesomeModel=model, myAwesomeOptimizer=optimizer)

这里 myAwesomeModel 是我们为待保存的模型 model 所取的任意键名。注意,在恢复变量的时候,我们还将使用这一键名。

保存参数

接下来,当模型训练完成需要保存的时候,使用:

checkpoint.save(save_path_with_prefix)

就可以。 save_path_with_prefix 是保存文件的目录 + 前缀

  • 例如,在源代码目录建立一个名为 save 的文件夹并调用一次 checkpoint.save('./save/model.ckpt') ,我们就可以在 save 目录下发现名为 checkpointmodel.ckpt-1.indexmodel.ckpt-1.data-00000-of-00001 的三个文件,这些文件就记录了变量信息。checkpoint.save() 方法可以运行多次,每运行一次都会得到一个 .index 文件和 .data 文件,序号依次累加。

载入之前保存的参数

当在其他地方需要为模型重新载入之前保存的参数时,需要再次实例化一个 checkpoint,同时保持键名的一致。再调用 checkpoint 的 restore 方法。就像下面这样:

model_to_be_restored = MyModel()                                        # 待恢复参数的同一模型
checkpoint = tf.train.Checkpoint(myAwesomeModel=model_to_be_restored)   # 键名保持为“myAwesomeModel”
checkpoint.restore(save_path_with_prefix_and_index)

即可恢复模型变量。 save_path_with_prefix_and_index 是之前保存的文件的目录 + 前缀 + 编号

  • 例如,调用 checkpoint.restore('./save/model.ckpt-1') 就可以载入前缀为 model.ckpt ,序号为 1 的文件来恢复模型。

当保存了多个文件时,我们往往想载入最近的一个。可以使用 tf.train.latest_checkpoint(save_path) 这个辅助函数返回目录下最近一次 checkpoint 的文件名。

  • 例如如果 save 目录下有 model.ckpt-1.indexmodel.ckpt-10.index 的 10 个保存文件, tf.train.latest_checkpoint('./save') 即返回 ./save/model.ckpt-10

保存变量+恢复变量

总体而言,恢复与保存变量的典型代码框架如下:

# train.py 模型训练阶段

model = MyMod
### 回答1: 你好,我理解你是想通过 CSDN 学习 C 知识,并且你也使用 ChatGPT 这种类型的 AI。请问你有什么问题需要我回答吗? 关于你提到的问题,我可以解释一下。tf.train.Saver() 是 TensorFlow 中用于保存恢复模型的类。它允许你将 TensorFlow 变量保存到磁盘,并从这些变量恢复。通常在训练过程中,你需要定期保存模型以便稍后恢复或进行推断。tf.train.Saver() 的常见用法如下: ```python import tensorflow as tf # 声明 TensorFlow 变量 weights = tf.Variable(...) biases = tf.Variable(...) # 初始化 TensorFlow 变量 init_op = tf.global_variables_initializer() # 训练模型 # 创建 Saver 对象 saver = tf.train.Saver() # 保存模型 with tf.Session() as sess: sess.run(init_op) # 训练模型 saver.save(sess, 'my_model') # 保存模型到文件 'my_model' ``` 当你需要恢复模型时,可以使用 tf.train.Saver() 的 restore() 方法,如下所示: ```python # 创建 Saver 对象 saver = tf.train.Saver() # 加载模型 with tf.Session() as sess: saver.restore(sess, 'my_model') # 从文件 'my_model' 恢复模型 # 执行推断 ``` 希望这可以回答你的问题。如果你有其他问题,请随时提出。 ### 回答2: tf.train.Saver()是一个用于保存、加载以及管理模型参数的类。在TensorFlow中,我们通常使用tf.train.Saver()来保存恢复模型的变量。 在TensorFlow中,模型的训练过程通常会生成一些变量,比如神经网络的权重和偏差。而tf.train.Saver()类提供了一种方法,可以将这些变量保存到文件中。通过调用tf.train.Saver().save()方法,可以将模型的变量保存在一个checkpoint文件中,以供将来使用。 除了保存模型变量tf.train.Saver()还可以用于加载已保存的模型变量。通过调用tf.train.Saver().restore()方法,可以从checkpoint文件中载入模型的变量,并且将其赋值给指定的TensorFlow变量。这样,我们就可以在程序中使用这些已保存的模型变量,而无需重新训练模型。 另外,tf.train.Saver()还具备一些其他的功能,比如可以指定保存和加载的变量以及保存恢复模型的过程是否应该包含模型的图结构。 总结起来,tf.train.Saver()是一个用于保存、加载和管理TensorFlow模型参数的类。它提供了保存恢复模型变量的功能,可以确保模型的训练结果可以方便地在之后的使用中进行加载和重用。 ### 回答3: tf.train.Saver()是tensorflow中用于模型参数的保存恢复的类。 在tensorflow中,模型参数通常是在训练过程中不断更新的,而为了保留训练过程中的模型参数,我们可以使用tf.train.Saver()类来保存这些参数。tf.train.Saver()类提供了保存恢复模型的方法,可以将模型的参数保存到文件中,并在需要的时候恢复这些参数。 保存模型参数是通过调用tf.train.Saver()类的save()方法实现的。save()方法需要传入一个session和一个保存路径,表示将当前模型的参数保存到指定的路径下。保存的参数可以是全局变量、权重、偏置等等。 恢复模型参数是通过调用tf.train.Saver()类的restore()方法实现的。restore()方法需要传入一个session和一个保存路径,表示从指定的路径中恢复模型的参数。恢复参数时,tensorflow会自动判断模型的参数是否当前模型的参数匹配,如果匹配,则恢复参数;如果不匹配,则会抛出异常。 使用tf.train.Saver()类可以实现模型的断点续训。即在训练过程中,可以将当前的模型参数保存到文件中。如果训练过程中发生意外,可以在恢复训练时,加载之前保存的模型参数,从上一次中断的地方继续训练。 总之,tf.train.Saver()是tensorflow中用于保存恢复模型参数的重要工具,它提供了方便的接口,使得我们可以灵活地管理模型参数,实现模型的保存恢复和断点续训。
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值