tf.train.MonitoredTrainingSession()解析【精】

最近看了下cifar10源码,训练代码中使用了tf.train.SessionRunHook(),tf.train.MonitoredTrainingSession();查看官方API后终于有些眉目了,特记录备忘。

 

首先,先讲下tf.train.MonitoredTrainingSession();

 一 .MonitoredTrainingSession()

 

首先,tf.train.MonitorSession()从单词的字面意思理解是用于监控训练的回话,返回值是tf.train.MonitorSession()类的一个实例Object, tf.train.MonitorSession()会在下面讲。

MonitoredTrainingSession(
    master='',
    is_chief=True,
    checkpoint_dir=None,
    scaffold=None,
    hooks=None,
    chief_only_hooks=None,
    save_checkpoint_secs=600,
    save_summaries_steps=USE_DEFAULT,
    save_summaries_secs=USE_DEFAULT,
    config=None,
    stop_grace_period_secs=120,
    log_step_count_steps=100
)

Args:

  •  is_chief:用于分布式系统中,用于判断该系统是否是chief,如果为True,它将负责初始化并恢复底层TensorFlow会话。如果为False,它将等待chief初始化或恢复TensorFlow会话。

  •  checkpoint_dir:一个字符串。指定一个用于恢复变量的checkpoint文件路径。

  •  scaffold:用于收集或建立支持性操作的脚手架。如果未指定,则会创建默认一个默认的scaffold。它用于完成图表

  •  hooks:SessionRunHook对象的可选列表。可自己定义SessionRunHook对象,也可用已经预定义好的SessionRunHook对象,如:tf.train.StopAtStepHook()设置停止训练的条件;tf.train.NanTensorHook(loss):如果loss的值为Nan则停止训练;

  •  chief_only_hooks:SessionRunHook对象列表。如果is_chief== True,则激活这些挂钩,否则忽略。

  •  

     save_checkpoint_secs:用默认的checkpoint saver保存checkpoint的频率(以秒为单位)。如果save_checkpoint_secs设置为None,不保存checkpoint。

  • save_summaries_steps:使用默认summaries saver将摘要写入磁盘的频率(以全局步数表示)。如果save_summaries_steps和save_summaries_secs都设置为None,则不使用默认的summaries saver保存summaries。默认为100

  •  

    save_summaries_secs:使用默认summaries saver将摘要写入磁盘的频率(以秒为单位)。如果save_summaries_steps和save_summaries_secs都设置为None,则不使用默认的摘要保存。默认未启用。

  •  

    config:用于配置会话的tf.ConfigProtoproto的实例。它是tf.Session的构造函数的config参数。

  •  

     stop_grace_period_secs:调用close()后线程停止的秒数。

  •  

     log_step_count_steps:记录全局步/秒的全局步数的频率

Returns:          

       一个MonitoredSession() 实例。


下面主要介绍tf.train.MonitoredSession()类

 

二tf.train.MonitoredSession()类

官方文档给的定义是:

Session-like object that handles initialization, recovery and hooks.

是一个处理初始化,模型恢复,和处理Hooks的类似与Session的类。

Args:

  •  

    session_creator:制定用于创建回话的ChiefSessionCreator

  •  

    hooks:tf.train.SessionRunHook()实例的列表

Returns:          

         一个MonitoredSession 实例。

 

Example usage:

saver_hook = CheckpointSaverHook(...)
summary_hook = SummarySaverHook(...)
with MonitoredSession(session_creator=ChiefSessionCreator(...),
                      hooks=[saver_hook, summary_hook]) as sess:
    while not sess.should_stop():
        sess.run(train_op)

初始化:在创建一个MonitoredSession时,会按顺序执行以下操作:

  • 调用[Hooks]列表中每一个Hook的begin()函数
  • 通过scaffold.finalize()完成图graph的定义
  • 创建会话
  • 用Scaffold提供的初始化操作(op)来初始化模型
  • 如果给定checkpoint_dir中存在checkpoint文件,则用checkpoint恢复变量
  • 启动队列线程
  • 调用hook.after_create_session()

Run:当调用run()函数时,按顺序执行以下操作

  • 调用hook.before_run()
  • 用合并后的fetches 和feed_dict调用TensorFlow的session.run() (这里是真正调用tf.Session().run(fetches ,feed_dict))
  • 调用hook.after_run()
  • 返回用户需要的session.run()的结果
  • 如果发生了AbortedError或者UnavailableError,则在再次执行run()之前恢复或者重新初始化会话

Exit:当调用close()退出时,按顺序执行下列操作

  • 调用hook.end()
  • 关闭队列线程queuerunners和会话session
  • 在monitored_session的上下文中,抑制由于处理完所有输入抛出的OutOf Range错误。

 

tf.train.MonitoredTrainingSession()中的tf.train.SessRunHook()介绍与使用请看我的另一篇博客

才疏学浅,如有错误欢迎留言指出!

 

 

  • 9
    点赞
  • 27
    收藏
    觉得还不错? 一键收藏
  • 9
    评论
TensorFlow 中,`tf.train.Saver` 是用于保存和恢复模型参数的类。它可以将模型的变量保存到一个二进制的 checkpoint 文件中,也可以从 checkpoint 文件中恢复模型的变量。 `tf.train.Saver` 的用法如下: ```python saver = tf.train.Saver(var_list=None, max_to_keep=5, keep_checkpoint_every_n_hours=10000.0, name=None, restore_sequentially=False, saver_def=None, builder=None, defer_build=False, allow_empty=False, sharded=False, write_version=tf.train.SaverDef.V2, pad_step_number=False, save_relative_paths=False, filename=None) ``` 其中,`var_list` 参数指定需要保存或恢复的变量列表,如果不指定,则默认保存或恢复所有变量。`max_to_keep` 参数指定最多保存的 checkpoint 文件数量,`keep_checkpoint_every_n_hours` 参数指定保存 checkpoint 文件的时间间隔,`name` 参数指定 saver 的名称。 保存模型的变量: ```python import tensorflow as tf # 创建计算图 x = tf.placeholder(tf.float32, shape=[None, 784]) y = tf.placeholder(tf.float32, shape=[None, 10]) W = tf.Variable(tf.zeros([784, 10])) b = tf.Variable(tf.zeros([10])) logits = tf.matmul(x, W) + b loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits)) train_op = tf.train.GradientDescentOptimizer(0.5).minimize(loss) # 训练模型 with tf.Session() as sess: sess.run(tf.global_variables_initializer()) for i in range(1000): batch_xs, batch_ys = ... sess.run(train_op, feed_dict={x: batch_xs, y: batch_ys}) # 保存模型参数 saver = tf.train.Saver() saver.save(sess, './model.ckpt') ``` 在这个例子中,我们创建了一个包含一个全连接层的简单神经网络,并使用梯度下降法训练模型。在训练完成后,我们调用 `tf.train.Saver` 类的 `save` 方法将模型的参数保存到文件 `'./model.ckpt'` 中。 恢复模型的变量: ```python import tensorflow as tf # 创建计算图 x = tf.placeholder(tf.float32, shape=[None, 784]) y = tf.placeholder(tf.float32, shape=[None, 10]) W = tf.Variable(tf.zeros([784, 10])) b = tf.Variable(tf.zeros([10])) logits = tf.matmul(x, W) + b loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits)) train_op = tf.train.GradientDescentOptimizer(0.5).minimize(loss) # 恢复模型参数 saver = tf.train.Saver() with tf.Session() as sess: saver.restore(sess, './model.ckpt') # 使用模型进行预测 test_x, test_y = ... predictions = sess.run(logits, feed_dict={x: test_x}) ``` 在这个例子中,我们创建了与之前相同的计算图,并使用 `tf.train.Saver` 类的 `restore` 方法从文件 `'./model.ckpt'` 中恢复模型的参数。恢复参数后,我们可以使用模型进行预测。需要注意的是,恢复模型参数时,需要在调用 `tf.global_variables_initializer()` 之前调用 `saver.restore` 方法。
评论 9
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值