Background
The previous post, 【编程技术】Keras Forward Outputs & Backwards Gradient API介绍, showed how to use the Keras API to obtain the per-layer forward outputs and per-layer backward gradients of a Keras model (e.g. the predefined ResNet50, or any custom model).
This post shows how to do the same with the TensorFlow API in the Google Official ResNet Model: obtaining each layer's forward outputs and each layer's backward gradients.
(A comparison of the relative merits of the Keras API and the TensorFlow API is beyond the scope of this post.)
Example Code
Google Official ResNet Model - resnet_run_loop.py
```python
def resnet_model_fn(features, labels, mode, model_class,
                    resnet_size, weight_decay, learning_rate_fn, momentum,
                    data_format, resnet_version, loss_scale,
                    loss_filter_fn=None, dtype=resnet_model.DEFAULT_DTYPE,
                    fine_tune=False):
  """Shared functionality for different resnet model_fns.
  ...
  """
  # Generate a summary node for the images
  tf.summary.image('images', features, max_outputs=6)

  # Checks that features/images have same data type being used for calculations.
  assert features.dtype == dtype

  model = model_class(resnet_size, data_format, resnet_version=resnet_version,
                      dtype=dtype)

  logits = model(features, mode == tf.estimator.ModeKeys.TRAIN)

  ...

  # Create a tensor named train_accuracy for logging purposes
  tf.identity(accuracy[1], name='train_accuracy')
  tf.identity(accuracy_top_5[1], name='train_accuracy_top_5')
  tf.summary.scalar('train_accuracy', accuracy[1])
  tf.summary.scalar('train_accuracy_top_5', accuracy_top_5[1])

  if flag_compute_variable_grads:
    # Compute the gradients of the loss w.r.t. all trainable variables
    grad_targets = tf.gradients(loss, tf.trainable_variables())

  if flag_compute_op_outputs_grads:
    g = tf.get_default_graph()
    temp = []
    for op in g.get_operations():
      # Collect every op's tensors so we can compute the gradients of the
      # loss w.r.t. each of them; is_outputs chooses between each op's
      # outputs and its inputs
      if len(op.outputs) > 0 and len(op.inputs) > 0:
        if is_outputs:
          for t in op.outputs:
            temp.append(t)
        else:
          for t in op.inputs:
            temp.append(t)
    grad_targets = tf.gradients(loss, temp)
    # tf.gradients returns None for tensors the loss does not depend on;
    # drop those entries before fetching
    while None in grad_targets:
      grad_targets.remove(None)

  # Save grad_targets into the DumpSessionHook
  return tf.estimator.EstimatorSpec(
      mode=mode,
      predictions=predictions,
      loss=loss,
      train_op=train_op,
      eval_metric_ops=metrics)
```
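The comment above says grad_targets is saved into the DumpSessionHook, but the excerpt does not show how. One plausible wiring, sketched below under the assumption that DumpSessionHook accepts the tensors in its constructor (that constructor is not shown in the official code), is to build the hook inside resnet_model_fn and attach it via the training_hooks argument of EstimatorSpec:

```python
# Sketch only: assumes DumpSessionHook stores the tensors passed to its
# constructor and returns them from before_run()
dump_hook = DumpSessionHook(grad_targets)

return tf.estimator.EstimatorSpec(
    mode=mode,
    predictions=predictions,
    loss=loss,
    train_op=train_op,
    training_hooks=[dump_hook],  # the hook then runs on every training step
    eval_metric_ops=metrics)
```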
Google Official ResNet Model - hooks.py
```python
class DumpSessionHook(tf.train.SessionRunHook):
  """Hook to extend calls to MonitoredSession.run()."""
  ...

  def before_run(self, run_context):  # pylint: disable=unused-argument
    """Called before each call to run().

    You can return from this call a `SessionRunArgs` object indicating ops or
    tensors to add to the upcoming `run()` call. These ops/tensors will be run
    together with the ops/tensors originally passed to the original run() call.
    The run args you return can also contain feeds to be added to the run()
    call.

    The `run_context` argument is a `SessionRunContext` that provides
    information about the upcoming `run()` call: the originally requested
    op/tensors, the TensorFlow Session.

    At this point graph is finalized and you can not add ops.

    Args:
      run_context: A `SessionRunContext` object.

    Returns:
      None or a `SessionRunArgs` object.
    """
    # Wrap the grad_targets to be evaluated in a SessionRunArgs so they are
    # fetched together with the original run() fetches
    return tf.train.SessionRunArgs(grad_targets)

  def after_run(self,
                run_context,  # pylint: disable=unused-argument
                run_values):  # pylint: disable=unused-argument
    """Called after each call to run().

    The `run_values` argument contains results of requested ops/tensors by
    `before_run()`.

    The `run_context` argument is the same one sent to `before_run` call.
    `run_context.request_stop()` can be called to stop the iteration.

    If `session.run()` raises any exceptions then `after_run()` is not called.

    Args:
      run_context: A `SessionRunContext` object.
      run_values: A SessionRunValues object.
    """
    # run_values.results holds the grad_targets values computed by the session
    print(run_values.results)
```
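The official excerpt above elides the hook's constructor, so it is not obvious where grad_targets comes from. Below is a minimal self-contained sketch of the same pattern; the class name, constructor, and attribute are hypothetical, not taken from hooks.py:

```python
import tensorflow as tf  # TF 1.x API

class GradDumpHook(tf.train.SessionRunHook):
  """Minimal sketch: fetch a list of tensors on every run() and print them."""

  def __init__(self, tensors_to_fetch):
    # The tensors (e.g. grad_targets) to evaluate alongside the training ops
    self._tensors = tensors_to_fetch

  def before_run(self, run_context):
    # Ask MonitoredSession.run() to also fetch our tensors
    return tf.train.SessionRunArgs(self._tensors)

  def after_run(self, run_context, run_values):
    # run_values.results holds the evaluated numpy arrays, one per tensor
    for tensor, value in zip(self._tensors, run_values.results):
      print(tensor.name, value.shape)
```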
As the code above shows, the key TensorFlow API calls are annotated with comments. They cover:
- Obtaining the gradient tensors via tf.gradients inside the model-building function
- Passing the tensors to be evaluated via SessionRunArgs in the tf.train.SessionRunHook callbacks, and fetching their concrete values
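To see these two pieces work outside the Estimator/hook machinery, here is a small self-contained sketch (a toy one-layer graph with illustrative names, not the official ResNet code) that collects every op's forward output tensors and evaluates them together with their gradients in a plain session:

```python
import numpy as np
import tensorflow as tf  # TF 1.x graph-mode API

# Toy graph standing in for the ResNet model
x = tf.placeholder(tf.float32, [None, 4], name='x')
w = tf.constant(np.random.randn(4, 3).astype(np.float32), name='w')
hidden = tf.nn.relu(tf.matmul(x, w), name='hidden')
loss = tf.reduce_mean(tf.square(hidden), name='loss')

# Forward outputs: every tensor produced by an op that also has inputs,
# mirroring the collection loop in resnet_model_fn above
g = tf.get_default_graph()
fwd = [t for op in g.get_operations()
       if len(op.inputs) > 0 and len(op.outputs) > 0
       for t in op.outputs]

# Backward gradients: tf.gradients returns None where the loss does not
# depend on a tensor, so keep only the (tensor, gradient) pairs that exist
pairs = [(t, gr) for t, gr in zip(fwd, tf.gradients(loss, fwd))
         if gr is not None]

with tf.Session() as sess:
  feed = {x: np.ones((2, 4), np.float32)}
  out_vals, grad_vals = sess.run([[t for t, _ in pairs],
                                  [gr for _, gr in pairs]], feed_dict=feed)
  for (t, _), o, gv in zip(pairs, out_vals, grad_vals):
    print(t.name, 'output', o.shape, 'grad', gv.shape)
```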