【编程技术】Tensorflow Forward Outputs &Backwards Gradient API介绍

背景介绍

上一篇 【编程技术】Keras Forward Outputs &Backwards Gradient API介绍 介绍如何通过Keras API 获得Keras 模型(Ex:比如预定义的ResNet50 或者自定义的其它模型)的每一层前向输出(Forward Outputs)和每一层的后向梯度(Backward Gradients)

本文旨在介绍通过Tensorflow API 在Google Official ResNet Mode中获得每一层前向输出(Forward Outputs)和每一层的后向梯度(Backward Gradients)

(至于Keras API和Tensorflow API 优劣对比不在本文讨论范围)

示例代码

Google Official ResNet50 - resnet_run_loop.py

def resnet_model_fn(features, labels, mode, model_class,
                    resnet_size, weight_decay, learning_rate_fn, momentum,
                    data_format, resnet_version, loss_scale,
                    loss_filter_fn=None, dtype=resnet_model.DEFAULT_DTYPE,
                    fine_tune=False):
  """Shared functionality for different resnet model_fns.
  ...
  """

  # Generate a summary node for the images
  tf.summary.image('images', features, max_outputs=6)
  # Checks that features/images have same data type being used for calculations.
  assert features.dtype == dtype

  model = model_class(resnet_size, data_format, resnet_version=resnet_version,
                      dtype=dtype)

  logits = model(features, mode == tf.estimator.ModeKeys.TRAIN)

 ...

  # Create a tensor named train_accuracy for logging purposes
  tf.identity(accuracy[1], name='train_accuracy')
  tf.identity(accuracy_top_5[1], name='train_accuracy_top_5')
  tf.summary.scalar('train_accuracy', accuracy[1])
  tf.summary.scalar('train_accuracy_top_5', accuracy_top_5[1])

  if flag_compute_variable_grads == True:
    #计算loss对所有可以训练的变量的梯度
    grad_targets = tf.gradients(loss, tf.trainable_variables())
  if flag_compute_op_outputs_grads == True:
    g = tf.get_default_graph()
    temp = []
    for op in g.get_operations():
      #计算loss对所有Op的output的梯度
      if len(op.outputs) > 0 and len(op.inputs) > 0:
          if is_outputs:
              for t in op.outputs:
                  temp.append(t)
          else:
              for t in op.inputs:
                  temp.append(t)
    grad_targets = tf.gradients(loss, temp)
    while None in grad_targets:
        grad_targets.remove(None)
  #保存 grad_targets 到DumpSessionHook中

  return tf.estimator.EstimatorSpec(
      mode=mode,
      predictions=predictions,
      loss=loss,
      train_op=train_op,
      eval_metric_ops=metrics)

Google Official ResNet Model - hooks.py

class DumpSessionHook(tf.train.SessionRunHook):
  """Hook to extend calls to MonitoredSession.run()."""

...

  def before_run(self, run_context):  # pylint: disable=unused-argument
    """Called before each call to run().
    You can return from this call a `SessionRunArgs` object indicating ops or
    tensors to add to the upcoming `run()` call.  These ops/tensors will be run
    together with the ops/tensors originally passed to the original run() call.
    The run args you return can also contain feeds to be added to the run()
    call.
    The `run_context` argument is a `SessionRunContext` that provides
    information about the upcoming `run()` call: the originally requested
    op/tensors, the TensorFlow Session.
    At this point graph is finalized and you can not add ops.
    Args:
      run_context: A `SessionRunContext` object.
    Returns:
      None or a `SessionRunArgs` object.
    """
    #将需要求解的grad_targets包装为SessionRunArgs返回
    return tf.train.SessionRunArgs(grad_targets)


  def after_run(self,
                run_context,  # pylint: disable=unused-argument
                run_values):  # pylint: disable=unused-argument
    """Called after each call to run().
    The `run_values` argument contains results of requested ops/tensors by
    `before_run()`.
    The `run_context` argument is the same one send to `before_run` call.
    `run_context.request_stop()` can be called to stop the iteration.
    If `session.run()` raises any exceptions then `after_run()` is not called.
    Args:
      run_context: A `SessionRunContext` object.
      run_values: A SessionRunValues object.
    """
    #输出session计算后的grad_targets值
    print(run_values)

如上面的代码所示,关键的Tensorflow API 都加上了注释,包括

  • 在构建model的函数中通过tf.gradients获取梯度Tensor
  • 在tf.train.SessionRunHook的回调中feed 需求求解的Tensor,以及fetch到具体的结果

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值