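1. What are Hooks?
The TensorFlow source describes the concept as follows: "Hooks are tools that run in the process of training/evaluation of the model."
In other words, hooks are tools that perform specific tasks while a model is being trained or evaluated. For example:
- Controlling training, such as early stopping
- Changing the learning rate
- Printing intermediate logs such as loss and AUC
- Saving checkpoints
- …
These hooks can take effect at the following points: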
- when a session starts being used
- before a call to session.run()
- after a call to session.run()
- when the session is closed
2. How to define Hooks?
In TensorFlow, hooks are created from the tf.train.SessionRunHook class and its subclasses. tf.train.SessionRunHook has five interface methods: begin, after_create_session, before_run, after_run, and end. A custom hook class looks like this:
import tensorflow.compat.v1 as tf  # TF 1.x-style API


class ExampleHook(tf.train.SessionRunHook):
    def __init__(self):
        # You can initialize the hook here.
        self.should_stop = False  # placeholder flag; set it from your own stop condition

    def begin(self):
        """Called before the session is created.

        When begin() is called, the default graph has already been created
        and new ops can still be added to it here. After begin() returns,
        the default graph can no longer be modified.
        """
        print('Starting the session.')
        self.your_tensor = ...

    def after_create_session(self, session, coord):
        """Called after the tf.Session has been created.

        Signals to all hooks that a new session has been created.

        Args:
            session: A TensorFlow Session that has been created.
            coord: A Coordinator object which keeps track of all threads.
        """
        # When this is called, the graph is finalized and
        # ops can no longer be added to the graph.
        print('Session created.')

    def before_run(self, run_context):
        """Called before each sess.run().

        Returns a tf.train.SessionRunArgs(fetches, feed_dict); fetches and
        feed_dict mean the same as in sess.run(). They are merged with the
        fetches and feed_dict already passed to sess.run() and run together.

        Args:
            run_context: A `SessionRunContext` object holding information
                about the session.
        """
        print('Before calling session.run().')
        return tf.train.SessionRunArgs(self.your_tensor)

    def after_run(self, run_context, run_values):
        """Called after each sess.run().

        run_values holds the results of the ops/tensors requested in
        before_run(). Call run_context.request_stop() to stop iterating.
        If sess.run() raises any exception, after_run() is not called.
        """
        print('Done running one step. The value of my tensor: %s' % (run_values.results,))
        if self.should_stop:
            run_context.request_stop()

    def end(self, session):
        """Called at the end of the session."""
        print('Done with the session.')
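One way to fill in the stop condition that after_run() leaves open is the early-stopping case mentioned in section 1. The following is a plain-Python sketch, not a TensorFlow API: the class name PatienceTracker and its parameters patience and min_delta are hypothetical, and it assumes the hook fetches a scalar loss in before_run().

```python
class PatienceTracker:
    """Tracks a loss value and reports when it has stopped improving.

    Hypothetical helper (not part of TensorFlow): a hook's after_run() could
    feed it the fetched loss and request a stop when update() returns True.
    """

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience    # consecutive non-improving steps tolerated
        self.min_delta = min_delta  # minimum decrease that counts as improvement
        self.best = float("inf")
        self.bad_steps = 0

    def update(self, loss):
        """Record one loss value; return True if training should stop."""
        if loss < self.best - self.min_delta:
            self.best = loss
            self.bad_steps = 0
        else:
            self.bad_steps += 1
        return self.bad_steps >= self.patience


# Feed a short, made-up loss sequence through the tracker.
tracker = PatienceTracker(patience=2)
losses = [1.0, 0.8, 0.9, 0.85, 0.7]
stop_flags = [tracker.update(x) for x in losses]
```

Inside a hook, after_run() would call tracker.update() with the fetched loss from run_values.results and call run_context.request_stop() as soon as it returns True. A real training loop would stop at that point rather than continue through the remaining values as this demonstration does.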
Besides custom hooks, several prebuilt hook classes are available for use with estimators:
- StopAtStepHook: requests a stop based on global_step
- CheckpointSaverHook: saves checkpoints
- LoggingTensorHook: outputs one or more tensor values to the log
- NanTensorHook: requests a stop if the given Tensor contains NaNs
- SummarySaverHook: saves summaries to a summary writer
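3. How to run Hooks
Hooks are invoked by MonitoredSession.run(), like this:
hook1 = ExampleHook()
# checkpoint_dir and save_steps are example values
hook2 = tf.train.CheckpointSaverHook(checkpoint_dir='/tmp/ckpt', save_steps=100)
your_hooks = [hook1, hook2]
with tf.train.MonitoredTrainingSession(hooks=your_hooks) as sess:  # other arguments omitted
    while not sess.should_stop():
        sess.run(your_fetches)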
Behind the scenes, the execution flow is roughly as follows:
call hooks.begin()
sess = tf.compat.v1.Session()
call hooks.after_create_session()
while not stop is requested:
    call hooks.before_run()
    try:
        results = sess.run(merged_fetches, feed_dict=merged_feeds)
    except (errors.OutOfRangeError, StopIteration):
        break
    call hooks.after_run()
call hooks.end()
sess.close()
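The flow above can be imitated in plain Python to see the callback order and the merging of fetches. This is only a simulation under stated assumptions: FakeRunContext, StopAfterNSteps, fake_session_run and run_loop are hypothetical stand-ins, no TensorFlow is involved, and the session and coordinator arguments are passed as None.

```python
class FakeRunContext:
    """Stand-in for SessionRunContext: only carries the stop request."""

    def __init__(self):
        self.stop_requested = False

    def request_stop(self):
        self.stop_requested = True


class StopAfterNSteps:
    """Toy hook that records every callback and requests a stop after n steps."""

    def __init__(self, n, events):
        self.n = n
        self.steps = 0
        self.events = events

    def begin(self):
        self.events.append("begin")

    def after_create_session(self, session, coord):
        self.events.append("after_create_session")

    def before_run(self, run_context):
        self.events.append("before_run")
        return ["step_count"]  # extra fetches requested by this hook

    def after_run(self, run_context, run_values):
        self.events.append("after_run")
        self.steps += 1
        if self.steps >= self.n:
            run_context.request_stop()

    def end(self, session):
        self.events.append("end")


def fake_session_run(fetches):
    """Pretend session.run(): returns one dummy value per fetch name."""
    return {name: 0 for name in fetches}


def run_loop(hooks, user_fetches):
    """Drive the hooks in the same order as the pseudocode above."""
    for h in hooks:
        h.begin()
    for h in hooks:
        h.after_create_session(None, None)
    ctx = FakeRunContext()
    while not ctx.stop_requested:
        hook_fetches = [h.before_run(ctx) for h in hooks]
        # Merge the caller's fetches with those requested by the hooks.
        merged = list(user_fetches) + [f for fs in hook_fetches for f in fs]
        results = fake_session_run(merged)
        # Each hook only sees the results for the fetches it asked for.
        for h, fs in zip(hooks, hook_fetches):
            h.after_run(ctx, {f: results[f] for f in fs})
    for h in hooks:
        h.end(None)


events = []
run_loop([StopAfterNSteps(2, events)], user_fetches=["train_op"])
```

After the call, events holds the callbacks in the order they fired: begin and after_create_session once, then a before_run/after_run pair per step, then end.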
Here is a concrete example (from qq924178473: https://blog.csdn.net/h_jlwg6688/article/details/117514323):
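import logging

import numpy as np
import tensorflow.compat.v1 as tf  # TF 1.x-style API

logger = logging.getLogger(__name__)


# Define your own hook class that prints logs after every step.
class YourOwnHook(tf.train.SessionRunHook):
    def __init__(self):
        np.set_printoptions(suppress=True)
        np.set_printoptions(linewidth=400)

    def before_run(self, run_context):
        """Return SessionRunArgs to be run together with session.run()."""
        # get_collection returns a list of the tensors stored under each key.
        v1 = tf.get_collection('logis')
        prob = tf.get_collection('prob')
        return tf.train.SessionRunArgs(fetches=[v1, prob])

    def after_run(self, run_context, run_values):
        # Results come back in the same order as the fetches above.
        v1, prob = run_values.results
        logger.info("logis value:{}".format(v1))
        print("prob :", prob)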
# Implement the estimator.
class MyEstimator(tf.estimator.Estimator):
    def __init__(self,
                 model_dir,
                 hidden_units,
                 optimizer,
                 activation_fn,
                 dropout=None,
                 batch_norm=False,
                 weight_column=None,
                 label_vocabulary=None,
                 loss_reduction=tf.losses.Reduction.SUM_OVER_BATCH_SIZE,
                 params=None,
                 config=None,
                 warm_start_from=None):
        def model_fn(features, labels, mode):
            # NOTE: feature_columns is not defined in this snippet; it is
            # assumed to be available from the enclosing scope.
            inputs_layers = tf.feature_column.input_layer(features, feature_columns)
            # Custom network layers (DNNModel is a user-defined class not shown here).
            user_hidden_fn = DNNModel(
                hidden_units=hidden_units,
                activation_fn=activation_fn,
                dropout=dropout,
                batch_norm=batch_norm,
                name="user_dnn"
            )
            user_hidden_net = user_hidden_fn(inputs_layers, mode=mode)
            with tf.name_scope("logits"):
                logits = tf.keras.layers.Dense(units=2, activation=None)(user_hidden_net)
            loss = tf.reduce_mean(tf.losses.sparse_softmax_cross_entropy(
                labels=tf.reshape(labels['label'], [-1]), logits=logits))
            train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
            # Compute predictions.
            predicted_classes = tf.argmax(logits, 1)
            # Set up the model evaluation metrics.
            accuracy = tf.metrics.accuracy(labels=labels["label"],
                                           predictions=predicted_classes,
                                           name='acc_op')
            auc = tf.metrics.auc(labels=labels["label"], predictions=predicted_classes, name='auc_op')
            metrics = {'accuracy': accuracy, 'auc': auc}
            tf.summary.scalar('accuracy', accuracy[1])
            if mode == tf.estimator.ModeKeys.TRAIN:
                # Create the custom hook and register the intermediate
                # values to output under the collection names it reads.
                ownhook = YourOwnHook()
                tf.add_to_collection('logis', logits)
                tf.add_to_collection('prob', predicted_classes)
                # Attach the custom hook to the training estimator.
                return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op,
                                                  training_hooks=[ownhook])
            if mode == tf.estimator.ModeKeys.EVAL:
                return tf.estimator.EstimatorSpec(mode=tf.estimator.ModeKeys.EVAL, loss=loss,
                                                  eval_metric_ops=metrics)

        super(MyEstimator, self).__init__(
            model_fn=model_fn, model_dir=model_dir, params=params, config=config,
            warm_start_from=warm_start_from
        )
References
- session_run_hook.py source code
- Hook? An introduction to tf.train.SessionRunHook()
- TensorFlow series: using TensorBoard and printing intermediate data in a custom standard estimator