tensorflow hook架构

介绍

所有的hook都继承自SessionRunHook,它定义在session_run_hook.py文件里,包含五个通用接口:

  1. def begin(self)
  2. def after_create_session(self, session, coord)
  3. def before_run(self, run_context)
  4. def after_run(self, run_context, run_values)
  5. def end(self, session)

源码

所以先要详细看下每个接口函数的介绍,源码不是很多,如下:

class SessionRunHook(object):
  """Base class for hooks that extend `MonitoredSession.run()` calls."""

  def begin(self):
    """Called a single time, before the session is used.

    This is the callback in which a hook may still add new operations to the
    graph. Once `begin()` has returned, the graph is finalized and none of the
    other callbacks can modify it any more.
    """

  def after_create_session(self, session, coord):  # pylint: disable=unused-argument
    """Called whenever a new TensorFlow session has been created.

    Signals to the hook that a session now exists.

    * By the time this runs the graph is already finalized, so no ops can be
      added to it.
    * It is invoked not only when the overall session is first set up, but
      also each time a wrapped session is recovered (for example after an
      error). In every case the session has already been created when this
      callback fires.

    Args:
      session: The TensorFlow Session that has been created.
      coord: A Coordinator object which keeps track of all threads.
    """

  def before_run(self, run_context):  # pylint: disable=unused-argument
    """Called before every call to `run()`.

    A hook may return a `SessionRunArgs` object naming ops or tensors to add
    to the upcoming `run()` call. They are run together with the ops/tensors
    that were passed to the original `run()` call. The returned run args may
    also carry feeds to be added to the call.

    `run_context` is a `SessionRunContext` describing the upcoming `run()`
    call: the originally requested ops/tensors and the TensorFlow Session.

    Args:
      run_context: A `SessionRunContext` object.

    Returns:
      None or a `SessionRunArgs` object. `SessionRunArgs` is a namedtuple
      with (among others) `fetches` and `feed_dict` members; the fetches and
      feeds of the returned object are added to the fetch and feed arguments
      of the underlying `sess.run(fetches, feed_dict)` call.
    """
    return None

  def after_run(self, run_context, run_values):  # pylint: disable=unused-argument
    """Called after every call to `run()`, once the run has finished.

    `run_values` holds the results of the ops/tensors requested by
    `before_run()`. The results have the same structure as the requested
    fetches (see `SessionRunValues`):

    1. a single tensor fetch yields the concrete value of that tensor;
    2. a list of tensors yields a list of the same size;
    3. a dict of fetches yields a dict.

    Args:
      run_context: A `SessionRunContext` object.
      run_values: A SessionRunValues object.
    """

  def end(self, session):  # pylint: disable=unused-argument
    """Called at the end of the session.

    `session` can be used when the hook wants to run final ops, such as
    saving a last checkpoint.

    If `session.run()` raises an exception other than OutOfRangeError or
    StopIteration, `end()` is not called. Note how this differs from
    `after_run()`: when `session.run()` raises OutOfRangeError or
    StopIteration, `end()` is called but `after_run()` is not.

    Args:
      session: A TensorFlow Session that will be soon closed.
    """
   


----------


class SessionRunArgs(
    collections.namedtuple("SessionRunArgs",
                           ("fetches", "feed_dict", "options"))):
  """Arguments that a hook wants added to a `Session.run()` call.

  Args:
    fetches: Same meaning as the `fetches` argument of `Session.run()`.
      May be a single tensor or op, a list of fetches, or a dictionary of
      fetches, e.g.:
        fetches = global_step_tensor
        fetches = [train_op, summary_op, global_step_tensor]
        fetches = {'step': global_step_tensor, 'summ': summary_op}
      Nesting works as expected:
        fetches = {'step': global_step_tensor,
                   'ops': [train_op, check_nan_op]}
    feed_dict: Exactly like the `feed_dict` argument to `Session.run()`.
      According to
      https://stackoverflow.com/questions/37849322/how-to-understand-the-term-tensor-in-tensorflow/37870634#37870634
      a feed key may be either a tensor or the name of a tensor.
    options: Exactly like the `options` argument to `Session.run()`, i.e., a
      config_pb2.RunOptions proto.
  """

  def __new__(cls, fetches, feed_dict=None, options=None):
    # Only `fetches` is mandatory; the other two fields default to None.
    base = super(SessionRunArgs, cls)
    return base.__new__(cls, fetches, feed_dict, options)

实际运行run的地方

在estimator.py里

  def _train_model(self, input_fn, hooks, saving_listeners):
    """Build the training graph and run the training loop (abridged excerpt).

    The `...` lines mark code that was elided from this excerpt.
    `worker_hooks` and `chief_hooks` are not defined in the visible part and
    are presumably set up in the elided code.

    Args:
      input_fn: Passed to `_get_features_and_labels_from_input_fn` to obtain
        the features, labels and input hooks.
      hooks: Not referenced in the visible part of this excerpt.
      saving_listeners: Not referenced in the visible part of this excerpt.

    Returns:
      The loss fetched by the last `mon_sess.run()` call, or None if the loop
      body never executed.
    """
    ...
    with ops.Graph().as_default() as g, g.device(self._device_fn):
      random_seed.set_random_seed(self._config.tf_random_seed)
      global_step_tensor = self._create_and_assert_global_step(g)
      training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
      features, labels, input_hooks = (
          self._get_features_and_labels_from_input_fn(
              input_fn, model_fn_lib.ModeKeys.TRAIN))
      # Hooks returned by the input pipeline are run alongside the worker hooks.
      worker_hooks.extend(input_hooks)
      estimator_spec = self._call_model_fn(
          features, labels, model_fn_lib.ModeKeys.TRAIN, self.config)
      ...

      # All hooks are handed to the monitored session here.
      with training.MonitoredTrainingSession(
          master=self._config.master,
          is_chief=self._config.is_chief,
          checkpoint_dir=self._model_dir,
          scaffold=estimator_spec.scaffold,
          hooks=worker_hooks,
          chief_only_hooks=(
              tuple(chief_hooks) + tuple(estimator_spec.training_chief_hooks)),
          save_checkpoint_secs=0,  # Saving is handled by a hook.
          save_summaries_steps=self._config.save_summary_steps,
          config=self._session_config,
          log_step_count_steps=self._config.log_step_count_steps) as mon_sess:
        loss = None
        # Keep running training steps until the session reports it should stop.
        while not mon_sess.should_stop():
          _, loss = mon_sess.run([estimator_spec.train_op, estimator_spec.loss])
      return loss

实际运行hook的地方

知道了每个API的作用后,接下来的问题就是:执行hook的地方在哪呢?
查看源代码最终发现,是在monitored_session.py文件中_HookedSession(_WrappedSession)类
定义的run函数里。

def run(self, fetches, feed_dict=None, options=None, run_metadata=None):
    """Run the session once, giving every hook its before/after callbacks.

    Args:
      fetches: The caller's fetches, as for a regular session `run()`.
      feed_dict: The caller's feeds, or None.
      options: Run options forwarded to the wrapped session.
      run_metadata: Run metadata forwarded to the wrapped session.

    Returns:
      Only the results for the caller's own `fetches`; values fetched on
      behalf of hooks are delivered to the hooks and not returned.
    """
    # The caller's fetches live under the 'caller' key; each hook's extra
    # fetches are added later under the hook object itself as key.
    actual_fetches = {'caller': fetches}

    # Record the original fetches/feeds in a SessionRunArgs tuple so hooks
    # can inspect them through the run context.
    original_args = session_run_hook.SessionRunArgs(fetches, feed_dict)
    run_context = session_run_hook.SessionRunContext(
        original_args=original_args,
        session=self._sess)

    # Let every hook contribute fetches and feeds; the fetches are merged
    # into actual_fetches in place and the combined feeds are returned.
    feed_dict = self._call_hook_before_run(
        run_context, actual_fetches, feed_dict, options)

    # The actual run.
    outputs = _WrappedSession.run(
        self,
        fetches=actual_fetches,
        feed_dict=feed_dict,
        options=options,
        run_metadata=run_metadata)

    # Hand each hook the values it asked for (None if it asked for nothing).
    for hook in self._hooks:
      if hook in outputs:
        hook_results = outputs[hook]
      else:
        hook_results = None
      hook.after_run(
          run_context,
          session_run_hook.SessionRunValues(
              results=hook_results,
              options=options,
              run_metadata=run_metadata))

    # Remember a stop request raised through the run context.
    self._should_stop = self._should_stop or run_context.stop_requested

    # Only the entry under 'caller' goes back to the caller.
    return outputs['caller']

  def _call_hook_before_run(self, run_context, fetch_dict, user_feed_dict,
                            options):
    hook_feeds = {}
    for hook in self._hooks:
      #request 是一个SessionRunArgs tuple
      request = hook.before_run(run_context)
      if request is not None:
        if request.fetches is not None:
          #将hook作为key,request.fetches作为value保存到fetch_dict里
          fetch_dict[hook] = request.fetches
        if request.feed_dict:
          hook_feeds.update(request.feed_dict)
        if request.options:
          self._merge_run_options(options, request.options)

    if not hook_feeds:
      return user_feed_dict

    if not user_feed_dict:
      return hook_feeds
      
    hook_feeds.update(user_feed_dict)
    return hook_feeds

hook实例

class IteratorInitializerHook(tf.train.SessionRunHook):
    """Session hook that initialises a data iterator once a session exists.

    `iterator_initializer_func` starts out as None; it is expected to be
    replaced with a callable taking the session before
    `after_create_session()` runs.
    """

    def __init__(self):
        super(IteratorInitializerHook, self).__init__()
        # Placeholder until the input pipeline assigns the real initializer.
        self.iterator_initializer_func = None

    def after_create_session(self, session, coord):
        """Run the stored initializer against the newly created session."""
        initialise = self.iterator_initializer_func
        initialise(session)


# Define the training inputs
def get_train_inputs(batch_size, mnist_data):
    """Build the training input function and its iterator-initialising hook.

    Args:
        batch_size (int): Batch size of the training iterator created by the
                          input function.
        mnist_data (Object): Object holding the loaded mnist data.

    Returns:
        (Input function, IteratorInitializerHook):
            - Function that returns (features, labels) when called.
            - Hook to initialise input iterator.
    """
    # The hook's iterator_initializer_func member is called once the session
    # has been created, which makes it the place to initialise the dataset
    # iterator.
    init_hook = IteratorInitializerHook()

    def train_inputs():
        """Create the training-set ops.

        Returns:
            (features, labels) Operations that iterate over the dataset
            on every evaluation
        """
        with tf.name_scope('Training_data'):
            # Raw MNIST arrays.
            train_images = mnist_data.train.images.reshape([-1, 28, 28, 1])
            train_labels = mnist_data.train.labels

            # Placeholders through which the arrays are fed at init time.
            images_ph = tf.placeholder(train_images.dtype, train_images.shape)
            labels_ph = tf.placeholder(train_labels.dtype, train_labels.shape)

            # Dataset pipeline: repeat forever, shuffle, then batch.
            dataset = (
                tf.data.Dataset.from_tensor_slices((images_ph, labels_ph))
                .repeat(None)
                .shuffle(buffer_size=10000)
                .batch(batch_size))

            # An initializable iterator only becomes usable after its
            # initializer has been run through session.run().
            iterator = dataset.make_initializable_iterator()
            next_example, next_label = iterator.get_next()

            def _initialise_iterator(sess):
                return sess.run(
                    iterator.initializer,
                    feed_dict={images_ph: train_images,
                               labels_ph: train_labels})

            # Give the hook the function it will call with the session.
            init_hook.iterator_initializer_func = _initialise_iterator

            # Batched (features, labels).
            return next_example, next_label

    return train_inputs, init_hook

另一个实例

import tensorflow as tf

class trainhook(tf.train.SessionRunHook):
  """Demo hook that fetches extra ops on every run and prints their values.

  Args:
    ops: Fetches to add to each `run()` call, in any structure accepted by
      `tf.train.SessionRunArgs`.
  """

  def __init__(self, ops):
    super(trainhook, self).__init__()
    self.ops = ops

  def before_run(self, run_context):
    """Request `self.ops` as additional fetches for the upcoming run."""
    return tf.train.SessionRunArgs(self.ops)

  def after_run(self, run_context, run_values):
    """Print the values fetched for `self.ops`, then exit the process.

    Raises:
      SystemExit: Always, with exit code 0, after the first run.
    """
    # A single pre-formatted argument prints the same text whether `print`
    # is the Python 2 statement or the Python 3 function; the original
    # `print 'a', b` form is a SyntaxError on Python 3.
    print('after run :  %s' % (run_values.results,))
    # Same effect as exit(0), without relying on the `site` module helper.
    raise SystemExit(0)

self.ops和run_values.results的结构完全一致,具体可以参考前面after_run文档中关于SessionRunValues的说明。

  • 5
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

yiqingyang2012

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值