问题：使用 TensorFlow 导入预训练模型的参数后，epoch 和 global_step 仍然是之前训练时保存的值（它们同样作为变量保存在 checkpoint 中，会随 `saver.restore` 一并恢复）。因此，在导入预训练模型参数之后，需要将 epoch 和 global_step 的值重新置为 0。
部分参考代码：
with tf.Session(config=config) as sess:
    # Initialize all variables first; any checkpoint restored below overwrites them.
    sess.run(tf.global_variables_initializer())
    print("{}: Start training...".format(datetime.datetime.now()))

    # Summary writers for TensorBoard (train and test events in separate subdirectories).
    train_summary_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train', sess.graph)
    test_summary_writer = tf.summary.FileWriter(FLAGS.log_dir + '/test', sess.graph)

    # Set when an in-progress training run is resumed, so that the pre-trained
    # weights below do not clobber the resumed weights and counters.
    resumed = False

    # Resume from the latest training checkpoint, if requested and present.
    if FLAGS.restore_training:
        # check if checkpoint exists
        if os.path.exists(checkpoint_prefix + "-latest"):
            print("{}: Last checkpoint found at {}, loading...".format(datetime.datetime.now(),
                                                                     FLAGS.checkpoint_dir))
            latest_checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_dir,
                                                                latest_filename="checkpoint-latest")
            saver.restore(sess, latest_checkpoint_path)
            resumed = True

    # Start a fresh run from pre-trained weights. saver.restore also brings back
    # global_step and start_epoch from the run that produced the pre-trained
    # checkpoint, so both counters are reset to 0 afterwards.
    # NOTE(review): an alternative is a separate tf.train.Saver whose var_list
    # excludes these two counters, which avoids restoring them at all.
    if FLAGS.pre_training and not resumed:
        print("{}: pre_train checkpoint found at {}, loading...".format(datetime.datetime.now(),
                                                                      FLAGS.pre_checkpoint_path))
        saver.restore(sess, FLAGS.pre_checkpoint_path)
        # start_epoch is indexed with [0] below, so it is presumably a shape-[1]
        # variable; hence it is assigned [0] rather than a scalar 0.
        reset_global_step = tf.assign(global_step, 0)
        reset_start_epoch = tf.assign(start_epoch, [0])
        sess.run([reset_global_step, reset_start_epoch])

    print("{}: Last checkpoint epoch: {}".format(datetime.datetime.now(), start_epoch.eval()[0]))
    print("{}: Last checkpoint global step: {}".format(datetime.datetime.now(),
                                                       tf.train.global_step(sess, global_step)))
其中，将 epoch 和 global_step 重新设置为 0 的代码如下：
# Reset the training counters restored from the pre-trained checkpoint.
# start_epoch is presumably a shape-[1] variable, hence the [0] value.
op1 = tf.assign(global_step, 0)
op2 = tf.assign(start_epoch, [0])
sess.run([op1, op2])