tensorflow中模型的加载与保存
参考资料:
1.https://blog.csdn.net/leo_xu06/article/details/79200634
2.https://blog.csdn.net/b876144622/article/details/79962727
3.https://blog.csdn.net/mieleizhi0522/article/details/80535189
4.https://blog.csdn.net/m0_37870649/article/details/81782036
5.https://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/
Saver
Saver是tensorflow中的一个类,主要用于保存和加载模型的参数,Saver类保存文件的目录结构为:
checkpoint:保存最近保存的模型文件名
model.ckpt.index, model.ckpt.data
model.ckpt.meta:图结构, 保存所有的变量和操作等
tensorflow的所有变量名都只是在session中才是alive的,所以必须在一个session中保存参数。
- 模型的保存与加载
# 保存全部参数
saver = tf.train.Saver()
saver.save(sess, "model.ckpt")
# 加载模型的全部参数
saver.restore(sess, "model.ckpt")
# 获取模型图结构
tf.reset_default_graph()
tf.train.import_meta_graph('model.ckpt.meta')
- 选择保存和加载的参数
–1.参数获取方式
# 参数获取的方式
var_list = tf.trainable_variables()
var_list = tf.global_variables()
var_list = tf.all_variables()
var_list = tf.contrib.frameworf.get_variables_to_restore()
var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="RNNEncoder")
–2. 保存指定参数
# 保存指定参数
var_choose = [var["wc1"], var["wc2"]]
saver = tf.train.Saver(var_choose)
saver.save(sess, "model.ckpt")
–3. 加载预训练的参数
# 加载指定参数
var_choose = [var["wc1"], var["wc2"]]
saver = tf.train.Saver(var_choose)
saver.restore(sess, model_path)
# 手动初始化预训练模型中没有的参数
var = tf.get_variable(name, shape, initializer=tf.contrib.layers.xavier_initializer())
# 最终保留全部参数
saver_out = tf.train.Saver()
saver_out.save(sess, "model.ckpt")
- 判断模型是否存在并加载
# 如果代码中已经构建好了网络图结构
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver()
ckpt = tf.train.get_checkpoint_state(checkpoint_path)
if ckpt and ckpt.model_checkpoint_path:
# saver.restore(sess, ckpt.model_checkpoint_path)
# 加载最新模型
saver.restore(sess, tf.train.latest_checkpoint(save_dir))
# 如果使用别人完整的checkpoint文件,通过.meta来访问文件
sess = tf.Session()
with tf.Graph().as_default():
ckpt = tf.train.get_checkpoint_state(ckpt_path)
if ckpt and ckpt.model_checkpoint_path:
saver = tf.train.import_meta_graph("".join([ckpt.model_checkpoint_path, ".meta"]))
saver.restore(sess, ckpt.model_checkpoint_path)
- 加载模型指定scope_name的参数
train_vars = tf.trainable_variables()
# 恢复模型预训练的部分参数
var_to_restore = [val for val in train_vars if 'conv' in val.name]
saver = tf.train.Saver(var_to_restore)
saver.restore(sess, os.path.join(model_dir, model_name))
- 只有网络和预训练模型权重,借助该模型中某些层的权重来初始化自己的网络
from tensorflow.python import pywrap_tensorflow
checkpoint_path = os.path.join(model_dir, "model.ckpt")
# 从checkpoint中读出数据
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
# reader = tf.train.NewCheckpointReader(checkpoint_path) # 用tf.train中的NewCheckpointReader方法
var_to_shape_map = reader.get_variable_to_shape_map()
with tf.variable_scope("", reuse=True):
sess.run(tf.get_variable(your_var_name).assign(reader.get_tensor(pretrained_var_name)))