一、tensorflow中的global_step
1、tf.train.get_or_create_global_step()
bert中的optimization.py中有以下一段代码:
def create_optimizer(loss, init_lr, num_train_steps):
global_step = tf.train.get_or_create_global_step()
"""省略无用代码"""
optimizer = AdamWeightDecayOptimizer(*args)
"""省略无用代码"""
train_op = optimizer.apply_gradients(
zip(grads, tvars), global_step=global_step)
"""global_step正常是无需手动更新的, 但这里手动实现了Adam,所以需要赋值更新。"""
new_global_step = global_step + 1
train_op = tf.group(train_op, [global_step.assign(new_global_step)])
return train_op
从该段代码说明了以下几点:
- tensorflow直接调用的optimizer是包含global_step的更新的,无需自己手动实现
- tensorflow的step不是一个batch_size, 也不是一个epoch, 而是一次梯度更新。
- 通过调用tf.train.get_or_create_global_step(), 可以获得当前优化步数。