# Build the training step.
# NOTE(review): the positional `1.` is forwarded unchanged to create_optimizer,
# whose signature is not visible here -- confirm what it controls before
# touching it.
train_op = create_optimizer(
    total_loss, lr, optimizer_params, 1., variables_to_train,
    use_fp16=FLAGS.use_fp16)

# Ops registered in the UPDATE_OPS collection (e.g. batch-norm moving
# mean/variance updates) are separate graph ops. Group them with the optimizer
# step so that every training step executes both.
# Per the TF1 docs, tf.get_collection returns a new list on each call, so the
# append below does not modify the graph's UPDATE_OPS collection itself.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
update_ops.append(train_op)
update_op = tf.group(*update_ops)

# Evaluating train_tensor forces update_op (optimizer step + update ops) to run
# first. It then yields the current total_loss value.
with tf.control_dependencies([update_op]):
    train_tensor = tf.identity(total_loss, name='train_op')

# Build the logging hook.
# NOTE(review): the meaning of `400` (presumably a logging interval in steps)
# depends on LogSessionRunHook, which is defined elsewhere -- confirm.
avg_logging_hooks = LogSessionRunHook(FLAGS.train_batch_size, 400)

# Build the training EstimatorSpec. train_tensor is used as train_op so that
# running the training step also runs the grouped update ops above.
# NOTE(review): EstimatorSpec's `scaffold` expects a tf.train.Scaffold (or
# None); confirm scaffold_fn holds such an object and not a factory function.
output_spec = tf.estimator.EstimatorSpec(
    mode=mode,
    loss=total_loss,
    train_op=train_tensor,
    scaffold=scaffold_fn,
    training_hooks=[avg_logging_hooks])
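# Source article: "tensorflow 在优化器后面添加bn的ops依赖"
# (TensorFlow: adding batch-norm update ops as dependencies of the optimizer step).
# Page residue: "最新推荐文章于 2022-03-20 16:49:07 发布"
# ("latest recommended article published 2022-03-20 16:49:07").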