def train(args, sess, model):
    """Run the training loop for `model`, checkpointing every 10 epochs.

    Args:
        args: config namespace; reads `learning_rate`, `momentum`,
            `continue_training`, `checkpoints_path`, `train_step`.
        sess: an active tf.Session (TF 1.x), used as the default session
            (initializer/restore ops are invoked via `.run()` on it).
        model: model object exposing `g_loss_all` (scalar loss tensor)
            and `c_vars` (list of trainable variables to optimize).

    Side effects: creates optimizer/saver ops in the default graph,
    restores or initializes variables, and writes checkpoint files
    under `args.checkpoints_path`.
    """
    # Build the optimizer op BEFORE variable initialization so Adam's
    # slot variables are covered by global_variables_initializer().
    optimizer = tf.train.AdamOptimizer(
        args.learning_rate, beta1=args.momentum, name="AdamOptimizer_G"
    ).minimize(model.g_loss_all, var_list=model.c_vars)

    epoch = 0
    # saver
    saver = tf.train.Saver()

    # restore
    if args.continue_training:
        tf.local_variables_initializer().run()
        # NOTE(review): latest_checkpoint returns None if no checkpoint
        # exists, and restore() would then fail — assumes a checkpoint
        # is present when continue_training is set.
        last_ckpt = tf.train.latest_checkpoint(args.checkpoints_path)
        saver.restore(sess, last_ckpt)
        ckpt_name = str(last_ckpt)
        print("Loaded model file from " + ckpt_name)
        # Checkpoint names end in "-<global_step>"; resume from there.
        epoch = int(ckpt_name.split('-')[-1])
    else:
        tf.global_variables_initializer().run()
        tf.local_variables_initializer().run()

    while epoch < args.train_step:
        # TODO: per-epoch training steps go here (e.g. sess.run(optimizer, ...));
        # the original source left this as a placeholder.
        pass
        # save — checkpoint inside the loop so progress is persisted
        # every 10 epochs and the counter actually advances.
        if epoch % 10 == 0:
            saver.save(sess, args.checkpoints_path + "/model", global_step=epoch)
        epoch += 1

    print("Done.")
# Source: blog post "TensorFlow: saving and restoring with tf.train.Saver"
# (latest recommended article published 2021-11-15 20:47:39)