参数初始化:
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer()) # 初始化模型的参数
参数保存:
saver = tf.train.Saver()
save_path = saver.save(sess, "E:\\projects\\PycharmProjects\\OCR\\model\\")
print(save_path)
在训练循环中,定期调用 saver.save() 方法,向文件夹中写入包含了当前模型中所有可训练变量的 checkpoint 文件。
saver.save(sess, FLAGS.train_dir, global_step=step)
global_step是训练的第几步
保存参数:
import tensorflow as tf
W = tf.Variable([[1, 2, 3]], dtype=tf.float32)
b = tf.Variable([[1]], dtype=tf.float32)
saver = tf.train.Saver()
sess = tf.InteractiveSession()
tf.global_variables_initializer().run()
# 必须要指定文件夹,保存到ckpt文件
save_path = saver.save(sess, "winycg/1.ckpt")
print(save_path)
参考博文: https://blog.csdn.net/winycg/article/details/78572438