更多内容:
https://blog.csdn.net/lwplwf/article/details/62419087
1、保存模型
import tensorflow as tf
import numpy as np
**xxxxx无关紧要代码**
saver = tf.train.Saver(max_to_keep=3) #关键1 默认max_to_keep=5
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
**网络训练部分** #关键2
save_path = saver.save(sess, r"F:\laboratorie_essays\cbc\WGAN_code\hhhh.ckpt") #关键3
print("Save to path: ", save_path)
运行结果:
Save to path: F:\laboratorie_essays\cbc\WGAN_code\hhhh.ckpt
2、调用模型
import tensorflow as tf
import numpy as np
**xxxxx无关紧要代码**
saver = tf.train.Saver() #关键4
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
saver.restore(sess, r"F:\laboratorie_essays\cbc\WGAN_code\hhhh.ckpt") #关键5
**删除网络训练部分,这些代码不需要了** #关键6
运行结果:
INFO:tensorflow:Restoring parameters from F:\laboratorie_essays\cbc\WGAN_code\hhhh.ckpt
weights: [[1. 2. 3.]
[3. 4. 5.]]
biases: [[1. 2. 3.]]
3、保存的文件名称含义:
https://blog.csdn.net/sinat_36618660/article/details/98665482
checkpoint: 检查点,记录保存了哪些checkpoint;
.ckpt.data: 以字典的形式保存权重偏置项等训练参数;
.ckpt.index: 参数索引;
.ckpt.meta: 图结构以及图中参数数据。
4、查看checkpoint文件中保存的变量名称:
https://www.cnblogs.com/weizhen/p/8451514.html
import tensorflow as tf
# import the inspect_checkpoint library
from tensorflow.python.tools import inspect_checkpoint as chkp
# print all tensors in checkpoint file
chkp.print_tensors_in_checkpoint_file("tmp/model.ckpt", tensor_name=None, all_tensors=True)