tf.train.MonitoredSession 简介

在run过程中的集成一些操作,比如输出log,保存,summary 等


基类一般用在infer阶段,训练阶段使用它的子类
tf.train.MonitoredTrainingSession

1 MonitoredTrainingSession

1.1 构造函数

MonitoredTrainingSession(
    master='',
    is_chief=True,
    checkpoint_dir=None,
    scaffold=None,
    hooks=None,
    chief_only_hooks=None,
    save_checkpoint_secs=600,
    save_summaries_steps=USE_DEFAULT,
    save_summaries_secs=USE_DEFAULT,
    config=None,
    stop_grace_period_secs=120,
    log_step_count_steps=100,
    max_wait_secs=7200
)

官方例子

saver_hook = CheckpointSaverHook(...)
summary_hook = SummarySaverHook(...)
with MonitoredSession(session_creator=ChiefSessionCreator(...),
                      hooks=[saver_hook, summary_hook]) as sess:
  while not sess.should_stop():
    sess.run(train_op)

首先,当MonitoredSession初始化的时候,会按顺序执行下面操作:

  • 调用hook的begin()函数,我们一般在这里进行一些hook内的初始化。比如在上面猫狗大战中的_LoggerHook里面的_step属性,就是用来记录执行步骤的,但是该参数只在本类中起作用。
  • 通过调用scaffold.finalize()初始化计算图
    创建会话
  • 通过初始化Scaffold提供的操作(op)来初始化模型
  • 如果checkpoint存在的话,restore模型的参数
  • launches queue runners
  • 调用hook.after_create_session()

然后,当run()函数运行的时候,按顺序执行下列操作:

  • 调用hook.before_run()
  • 调用TensorFlow的 session.run()
  • 调用hook.after_run()
  • 返回用户需要的session.run()的结果
  • 如果发生了AbortedError或者UnavailableError,则在再次执行run()之前恢复或者重新初始化会话

最后,当调用close()退出时,按顺序执行下列操作:

  • 调用hook.end()
  • 关闭队列和会话
  • 阻止OutOfRange错误

1.2 Hook

所以这些钩子函数就是重点关注的对象

.1 LoggingTensorHook

tf.train.LoggingTensorHook 官方说明

Prints the given tensors every N local steps, every N seconds, or at end.

__init__(
    tensors,
    every_n_iter=None,
    every_n_secs=None,
    formatter=None
)
  • tensors: dict that maps string-valued tags to tensors/tensor names, or iterable of tensors/tensor names.

用法举例

# Set up logging for predictions
  tensors_to_log = {"probabilities": "softmax_tensor"}
  logging_hook = tf.train.LoggingTensorHook(
      tensors=tensors_to_log, every_n_iter=50)

.2 SummarySaverHook

tf.train.SummarySaverHook

Saves summaries every N steps

__init__(
    save_steps=None,
    save_secs=None,
    output_dir=None,
    summary_writer=None,
    scaffold=None,
    summary_op=None
)

output_dir 填 路径
summary_op 填 tf.summary.merge_all

.3 CheckpointSaverHook

tf.train.CheckpointSaverHook
MonitoredTrainingSession 只有 save_checkpoint_secs, 没有按step保存的选项
* Saves checkpoints every N steps or seconds

__init__(
    checkpoint_dir,
    save_secs=None,
    save_steps=None,
    saver=None,
    checkpoint_basename='model.ckpt',
    scaffold=None,
    listeners=None
)

必填 saver, save_secs 或者 save_steps

.4 NanTensorHook

tf.train.NanTensorHook
感觉是用来调试的,加到训练过程中可能会拖慢train

  • Monitors the loss tensor and stops training if loss is NaN.
    Can either fail with exception or just stop training.
__init__(
    loss_tensor,
    fail_on_nan_loss=True
)

.5 FeedFnHook

tf.train.FeedFnHook
看着像用来产生 feed_dict

Runs feed_fn and sets the feed_dict accordingly

__init__(feed_fn)

.6 GlobalStepWaiterHook

tf.train.GlobalStepWaiterHook
分布式用

.7 ProfilerHook

tf.train.ProfilerHook

This hook delays execution until global step reaches to wait_until_step. It is used to gradually start workers in distributed settings. One example usage would be setting wait_until_step=int(K*log(task_id+1)) assuming that task_id=0 is the chief

reference

tf.train.MonitoredSession
https://www.tensorflow.org/versions/master/api_docs/python/tf/train/MonitoredSession
resnet_main.py
https://github.com/tensorflow/models/blob/master/research/resnet/resnet_main.py
tf.train.MonitoredTrainingSession
https://www.tensorflow.org/versions/master/api_docs/python/tf/train/MonitoredTrainingSession
使用自己的数据集进行一次完整的TensorFlow训练
https://zhuanlan.zhihu.com/p/32490882
tf.train.LoggingTensorHook
https://www.tensorflow.org/api_docs/python/tf/train/LoggingTensorHook
tf.train.SummarySaverHook
https://www.tensorflow.org/versions/master/api_docs/python/tf/train/SummarySaverHook
tf.train.CheckpointSaverHook
https://www.tensorflow.org/versions/master/api_docs/python/tf/train/CheckpointSaverHook
tf.train.NanTensorHook
https://www.tensorflow.org/versions/master/api_docs/python/tf/train/NanTensorHook#__init__

  • 0
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
TensorFlow 中,`tf.train.Saver` 是用于保存和恢复模型参数的类。它可以将模型的变量保存到一个二进制的 checkpoint 文件中,也可以从 checkpoint 文件中恢复模型的变量。 `tf.train.Saver` 的用法如下: ```python saver = tf.train.Saver(var_list=None, max_to_keep=5, keep_checkpoint_every_n_hours=10000.0, name=None, restore_sequentially=False, saver_def=None, builder=None, defer_build=False, allow_empty=False, sharded=False, write_version=tf.train.SaverDef.V2, pad_step_number=False, save_relative_paths=False, filename=None) ``` 其中,`var_list` 参数指定需要保存或恢复的变量列表,如果不指定,则默认保存或恢复所有变量。`max_to_keep` 参数指定最多保存的 checkpoint 文件数量,`keep_checkpoint_every_n_hours` 参数指定保存 checkpoint 文件的时间间隔,`name` 参数指定 saver 的名称。 保存模型的变量: ```python import tensorflow as tf # 创建计算图 x = tf.placeholder(tf.float32, shape=[None, 784]) y = tf.placeholder(tf.float32, shape=[None, 10]) W = tf.Variable(tf.zeros([784, 10])) b = tf.Variable(tf.zeros([10])) logits = tf.matmul(x, W) + b loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits)) train_op = tf.train.GradientDescentOptimizer(0.5).minimize(loss) # 训练模型 with tf.Session() as sess: sess.run(tf.global_variables_initializer()) for i in range(1000): batch_xs, batch_ys = ... sess.run(train_op, feed_dict={x: batch_xs, y: batch_ys}) # 保存模型参数 saver = tf.train.Saver() saver.save(sess, './model.ckpt') ``` 在这个例子中,我们创建了一个包含一个全连接层的简单神经网络,并使用梯度下降法训练模型。在训练完成后,我们调用 `tf.train.Saver` 类的 `save` 方法将模型的参数保存到文件 `'./model.ckpt'` 中。 恢复模型的变量: ```python import tensorflow as tf # 创建计算图 x = tf.placeholder(tf.float32, shape=[None, 784]) y = tf.placeholder(tf.float32, shape=[None, 10]) W = tf.Variable(tf.zeros([784, 10])) b = tf.Variable(tf.zeros([10])) logits = tf.matmul(x, W) + b loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits)) train_op = tf.train.GradientDescentOptimizer(0.5).minimize(loss) # 恢复模型参数 saver = tf.train.Saver() with tf.Session() as sess: saver.restore(sess, './model.ckpt') # 使用模型进行预测 test_x, test_y = ... predictions = sess.run(logits, feed_dict={x: test_x}) ``` 在这个例子中,我们创建了与之前相同的计算图,并使用 `tf.train.Saver` 类的 `restore` 方法从文件 `'./model.ckpt'` 中恢复模型的参数。恢复参数后,我们可以使用模型进行预测。需要注意的是,恢复模型参数时,需要在调用 `tf.global_variables_initializer()` 之前调用 `saver.restore` 方法。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值