tensorflow笔记(二十二)—— Hooks

1.什么是Hooks?

中文直译为“钩子”,在tensorflow中概念:Hooks are tools that run in the process of training/evaluation of the model.*
Hooks是模型训练/测试过程中的工具,这些工具用于在训练/评估过程中执行特定任务。例如:

  • 控制训练EarlyStopping
  • 改变学习率
  • 打印一些中间日志,如loss、auc等
  • 保存checkpoint

这些hooks可以在以下几个地方生效:

  • when a session starts being used
  • before a call to the session.run()
  • after a call to the session.run()
  • when the session closed

2.怎么定义Hooks?

在tensorflow中,tf.training.SessionRunHook类及其派生类负责创建hooks,tf.training.SessionRunHook有5个接口函数,分别是begin, after_create_session, before_run, after_run, end。自定义一个Hook类:

class ExampleHook(SessionRunHook):
	def __init__(self):
		# Yor can init the hook here
	def begin(self):	
	    """在创建会话之前调用
	    调用begin()时,default graph会被创建,
	    可在此处向default graph增加新op,begin()调用后,default graph不能再被修改
	    """
      	print('Starting the session.')
      	self.your_tensor = ...
    def after_create_session(self, session, coord):
    	"""tf.Session被创建后调用
  		调用后会指示所有的Hooks有一个新的会话被创建
    	Args:
      		session: A TensorFlow Session that has been created.
      		coord: A Coordinator object which keeps track of all threads.
    	"""
      	# When this is called, the graph is finalized and
      	# ops can no longer be added to the graph.
      	print('Session created.')
    def before_run(self, run_context):
    	"""在每个sess.run()执行之前调用
	    返回一个tf.train.SessRunArgs(fetches, feed_dict),fetches、feed_dict和sess.run()里概念一样。
	    实际上它们会和sess.run()中已定义的fetches和feed_dict合并一起执行。
	    Args:
	      run_context: A `SessionRunContext` object, 包含session的一些信息
	  	"""
      	print('Before calling session.run().')
      	return SessionRunArgs(self.your_tensor)
    def after_run(self, run_context, run_values):
    	"""在每个sess.run()之后调用
	    参数run_values是befor_run()中要求的op/tensor的返回值;
	    可以调用run_context.qeruest_stop()用于停止迭代
	    sess.run抛出任何异常after_run不会被调用
		"""
      	print('Done running one step. The value of my tensor: %s', run_values.results)
      	if you-need-to-stop-loop:
        	run_context.request_stop()
    def end(self, session):
      	print('Done with the session.')

除了自定义Hooks外,estimator有几个预制好的Hooks类:

  • StopAtStepHook: Request stop based on global_step
  • CheckpointSaverHook: saves checkpoint
  • LoggingTensorHook: outputs one or more tensor values to log
  • NanTensorHook: Request stop if given Tensor contains Nans.
  • SummarySaverHook: saves summaries to a summary writer

3.怎么执行Hooks

Hooks由 MonitoredSession.run()调用,具体方式:

hook1 = ExampleHook()
hook2 = CheckpointSaverHook()
your_hooks = [hook1, hook2]
with MonitoredTrainingSession(hooks=your_hooks, ...) as sess:
    while not sess.should_stop():
        sess.run(your_fetches)

其背后大概执行流程是这样的:

call hooks.begin()
sess = tf.compat.v1.Session()
call hooks.after_create_session()
	while not stop is requested:
    call hooks.before_run()
    try:
		results = sess.run(merged_fetches, feed_dict=merged_feeds)
    except (errors.OutOfRangeError, StopIteration):
		break
    call hooks.after_run()
call hooks.end()
sess.close()

给个具体的例子(from qq924178473:https://blog.csdn.net/h_jlwg6688/article/details/117514323):

# 定义自己的hook类,实现每个step执行后打印日志
class YourOwnHook(tf.train.SessionRunHook):
    def __init__(self):
        np.set_printoptions(suppress=True)
        np.set_printoptions(linewidth=400)
 
    def before_run(self, run_context):
        """返回SessionRunArgs和session run一起跑"""
        v1 = tf.get_collection('logis')
        prob = tf.get_collection('prob')
        return tf.train.SessionRunArgs(fetches=[v1, prob])
    def after_run(self, run_context, run_values):
        v1, batch_labels = run_values.results
        logger.info("logis value:{}".format(v1))
        print("prob :",batch_labels)

# 实现estimator
class MyEstimator(tf.estimator.Estimator):
    def __init__(self,
                 model_dir,
                 hidden_units,
                 optimizer,
                 activation_fn,
                 dropout=None,
                 batch_norm=False,
                 weight_column=None,
                 label_vocabulary=None,
                 loss_reduction=tf.losses.Reduction.SUM_OVER_BATCH_SIZE,
                 params=None,
                 config=None,
                 warm_start_from=None):
 
	    def model_fn(features,labels,mode):
	         inputs_layers =tf.feature_column.input_layer(features,feature_columns)
	         # 自定义网络层
	         user_hidden_fn = DNNModel(
	            hidden_units=hidden_units,
	             activation_fn=activation_fn,
	             dropout=dropout,
	             batch_norm=batch_norm,
	             name="user_dnn"
	         )
	
	         user_hidden_net = user_hidden_fn(inputs_layers,mode=mode)
	         with tf.name_scope("logits"):
	             logits = tf.keras.layers.Dense(units=2, activation=None)(user_hidden_net)
	
	         loss = tf.reduce_mean(tf.losses.sparse_softmax_cross_entropy(labels=tf.reshape(labels['label'],[-1]),logits=logits))
	   
	         train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
	         # Compute predictions.
	         predicted_classes = tf.argmax(logits, 1)
	         # 设置模型评价指标
	         accuracy = tf.metrics.accuracy(labels=labels["label"],
	                                predictions=predicted_classes,
	                                name='acc_op')
	         auc = tf.metrics.auc(labels=labels["label"],predictions=predicted_classes,name='auc_op')
	         metrics = {'accuracy': accuracy,'auc':auc}
	         tf.summary.scalar('accuracy', accuracy[1])
	         if mode==tf.estimator.ModeKeys.TRAIN:
	             # 定义自定义钩子函数,并设置要输出的中间值的名称
	             ownhook = YourOwnHook()
	             tf.add_to_collection('logis', logits)
	             tf.add_to_collection('prob',predicted_classes)
	             # 将自定义钩子添加到训练的estimator中
	             return tf.estimator.EstimatorSpec(mode=mode,loss=loss,train_op=train_op,training_hooks=[ownhook])
	
	         if mode == tf.estimator.ModeKeys.EVAL:
	             return tf.estimator.EstimatorSpec(mode=tf.estimator.ModeKeys.EVAL,loss=loss,eval_metric_ops=metrics)
     super(MyEstimator,self).__init__(
         model_fn=model_fn,model_dir=model_dir,params=params,config=config,warm_start_from=warm_start_from
     )

Reference

session_run_hook.py源码
Hook? tf.train.SessionRunHook()介绍【精】
TensorFlow系列——在自定义的标准estimator中使用tensorboard及打印中间数据

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
React Hooks 是 React 16.8 版本引入的新特性,它允许我们在无需编写类组件的情况下使用状态和其他 React 特性。对于二级路由的实现,我们可以使用 React Router 来管理路由。 首先,确保你已经安装了 `react-router-dom` 包。接下来,你需要创建一个包含二级路由的组件。这个组件将作为父级路由的容器,并且可以渲染其他子组件。 首先,在你的项目中导入所需的依赖: ```javascript import React from 'react'; import { BrowserRouter as Router, Route, Switch } from 'react-router-dom'; ``` 然后,创建一个包含二级路由的组件: ```javascript const SubRoutesComponent = () => { return ( <Router> <Switch> <Route exact path="/subroute1"> {/* 渲染子路由1的组件 */} </Route> <Route exact path="/subroute2"> {/* 渲染子路由2的组件 */} </Route> <Route exact path="/subroute3"> {/* 渲染子路由3的组件 */} </Route> </Switch> </Router> ); }; ``` 在上面的示例中,我们创建了三个子路由,分别是 `/subroute1`、`/subroute2` 和 `/subroute3`。你可以根据自己的需求添加更多的子路由。 最后,在你的应用程序中使用该组件: ```javascript const App = () => { return ( <Router> <Switch> <Route exact path="/"> {/* 渲染主页组件 */} </Route> <Route path="/subroutes"> <SubRoutesComponent /> </Route> </Switch> </Router> ); }; ``` 在上面的示例中,我们将 `SubRoutesComponent` 组件作为 `/subroutes` 路径的子路由引入。这意味着当访问 `/subroutes` 时,将渲染 `SubRoutesComponent` 组件,并且根据子路由的路径进行匹配和渲染。 这样就实现了在 React 中使用二级路由。你可以根据自己的需求修改和扩展这个示例来满足你的项目要求。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值