tensorflow中使用save和restore保存和恢复模型

我们在训练模型过程中,有时训练一段时间后,往往想要在验证集上验证一下,模型是否存在过拟合,然后视验证情况,再选择继续训练还是修改模型参数。这时tensorflow提供的Saver类,就能很好的帮助到我们。
当我们保存一个模型到指定路径后,还目录下将会出现四种类型的文件:

checkpoint: 具有最近检查点列表的协议缓存区
.data: 保存模型中的变量
.index: 标识检查点
.meta: 保存模型中计算图的结构信息

1、tf.train.Saver( )

首先需要在程序中定义一个saver操作,该定义在会话结构之外。

import tensorflow as tf
...
saver = tf.train.Saver()
...
with tf.Session() as sess:
    ...

这样一个saver操作就定义好了。tf.train.Saver( )有几个我们平时常用到的参数,具体如下:

max_to_keep: 设置保存最近的检查点文件的个数,例如max_to_keep=4,就是只保存最新的四个模型。
keep_checkpoint_every_n_hours: 设置每隔多长时间保存一次模型。
savable_variables: 可以设置将要保存的tensor。如tf.train.Saver([w1, w2]),就是只保存w1和w2。如果不指定任何想要保存的tensor,saver默认保存所有的tensor。

2、saver.save( )

在使用tf.train.Saver( )创建了saver操作之后,我们就可以在一个会话中保存我们的模型。

...
with tf.Session() as sess:
    ...
    for epoch in range(10):
        ...
        saver.save(sess, model_path, global_step=epoch, write_meta_graph=False)
        ...

使用上面代码中的saver.save( )就可以按照我们的要求保存模型。其中参数说明如下:

sess: 会话对象
model_path: 模型保存的路径
global_step=epoch: 可选,在我们保存的文件名字中,加上迭代次数,以方便我们区分保 存的文件是经过多少次的训练迭代。如global_step = 2,则我们保存的文件名字为-2.data-00000-of-00001,-2.index,-2.meta。
write_meta_graph: 可选,False: 只保存一次.meta文件;True:根据我们设置的保存次 数,保存多次.meta文件。这里对这个参数加一点说明:因为模型一旦建立好之后,计算图的结构就确定了,所以每次保存的.meta文件都是一样的,有时为了节省存储空间,我们选择只保存一次.meta文件。

3、saver.restore( )

在保存了一个模型之后,我们使用saver.restore( )来恢复模型。恢复操作也需要在session会话中。我们可以创建一个新的会话:

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('model_path/-2.meta') # 以.meta文件名为-2.meta为例
    saver.restore(sess, tf.train.latest_checkpoint(model_path))
    ...

我们首先要通过saver = tf.train.import_meta_graph(‘model_path/-2.meta’)加载模型的计算图结构,然后通过saver.restore(sess, tf.train.latest_checkpoint(‘model_path’))来恢复我们保存的所有变量和操作。其中tf.train.latest_checkpoint(‘model_path’)是从最近的检查点中恢复模型。

以上就是保存和恢复全部模型的操作。在实际进行模型优化时,有时我们会对原来的模型进行修改,如增加网络的深度,重新定义一些权重变量,重新定义精确度指标等。这时,我们就可以通过变量或操作的名字来加载指定的变量或操作。

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('model_path/-2.meta') # 以.meta文件名为-2.meta为例
    saver.restore(sess, tf.train.latest_checkpoint(model_path))
    graph = tf.get_default_graph()
    # 加载网络权重变量w1和w2,"w1:0"中,weight1为定义w1变量时指定的名字,当此tensor没有重复时,后面加上0
    w1 = graph.get_tensor_by_name("weight1:0")
    w2 = graph.get_tensor_by_name("weight2:0")
    # 恢复网络中的第七全连接层,fully_connected7为定义fc7时指定的名字,当此tensor没有重复时,后面加上0
    fc7 = graph.get_tensor_by_name("fully_connected7:0")

加载到指定的变量后,我们就可以在其基础上,对原来的模型进行修改。
下面是我在github上对save和restore验证的代码地址:
https://github.com/Demohai/my_tensorflow_learn/tree/master/save_and_restore_models

TensorFlow ,可以使用 `tf.train.Saver` 类来管理和保存模型。下面是一个简单的示例: ```python import tensorflow as tf # 假设你已经构建了一个模型并定义了损失和优化器 # ... # 创建 Saver 对象 saver = tf.train.Saver() with tf.Session() as sess: # 训练模型 # 在训练结束后保存模型 save_path = saver.save(sess, "/path/to/model.ckpt") print("Model saved in path: %s" % save_path) ``` 这段代码,我们首先导入 TensorFlow,然后创建一个 `Saver` 对象。在训练结束后,我们可以调用 `saver.save()` 方法保存模型。该方法需要两个参数:`sess` 表示当前的 TensorFlow 会话,`save_path` 表示模型保存的路径。在保存模型时,TensorFlow 会将模型的变量值保存在一个名为 `model.ckpt` 的文件。 如果要恢复模型,可以使用 `Saver` 类的 `restore()` 方法。例如: ```python import tensorflow as tf # 假设你已经构建了一个模型并定义了损失和优化器 # ... # 创建 Saver 对象 saver = tf.train.Saver() with tf.Session() as sess: # 恢复模型 saver.restore(sess, "/path/to/model.ckpt") print("Model restored.") # 进行预测或评估等操作 ``` 在这个示例,我们首先导入 TensorFlow,然后创建一个 `Saver` 对象。在恢复模型时,我们可以调用 `saver.restore()` 方法,并传入当前 TensorFlow 会话和模型保存的路径。这个方法会将模型的变量值从文件加载到当前的 TensorFlow 会话
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值