文章目录
1. 为什么要有 Hook?
SessionRunHook
用来扩展那些将 session
封装起来的高级 API 的 session.run
的行为。
2. Hook 有什么用?
SessionRunHook
对于追踪训练过程、报告进度、实现提前停止等非常有用。
SessionRunHook
以观察者模式运行。SessionRunHook
的设计中有几个非常重要的时间点:
session
使用前session.run()
调用之前session.run()
调用之后session
关闭前
SessionRunHook
封装了一些可重用、可组合的计算,并且可以顺便完成 session.run()
的调用。利用 Hook,我们可以为 run()
调用添加任何的 ops或tensor/feeds;并且在 run()
调用完成后获得请求的输出。Hook 可以利用 hook.begin()
方法向图中添加 ops,但请注意:在 begin()
方法被调用后,计算图就 finalized 了。
3. TF 内置了哪些 Hook?
TensorFlow 中已经内置了一些 Hook:
StopAtStepHook
:根据 global_step 来停止训练。CheckpointSaverHook
:保存 checkpoint。LoggingTensorHook
:以日志的形式输出一个或多个 tensor 的值。NanTensorHook
:如果给定的Tensor
包含 Nan,就停止训练。SummarySaverHook
:保存 summaries 到一个 summary writer。
4. TF 怎么自定义 Hook?
上节,我们已经介绍了预制 Hook,使用其可以实现一些常见功能。如果这些 Hook 不能满足你的需求,那么自定义 Hook 是比较好的选择。
下面是自定义 Hook 的编写模板:
class ExampleHook(tf.train.SessionRunHook):
def begin(self):
# You can add ops to the graph here.
print('Starting the session.')
self.your_tensor = ...
def after_create_session(self, session, coord):
# When this is called, the graph is finalized and
# ops can no longer be added to the graph.
print('Session created.')
def before_run(self, run_context):
print('Before calling session.run().')
return SessionRunArgs(self.your_tensor)
def after_run(self, run_context, run_values): # run_values 为 sess.run 的结果
print('Done running one step. The value of my tensor: %s',
run_values.results)
if you-need-to-stop-loop:
run_context.request_stop()
def end(self, session):
print('Done with the session.')
上面是官方给的解释