Tensorflow中模型的加载与保存

tensorflow中模型的加载与保存

参考资料:
1.https://blog.csdn.net/leo_xu06/article/details/79200634
2.https://blog.csdn.net/b876144622/article/details/79962727
3.https://blog.csdn.net/mieleizhi0522/article/details/80535189
4.https://blog.csdn.net/m0_37870649/article/details/81782036
5.https://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/

Saver

Saver是tensorflow中的一个类,主要用于保存和加载模型的参数,Saver类保存文件的目录结构为:
checkpoint:保存最近保存的模型文件名
model.ckpt.index, model.ckpt.data
model.ckpt.meta:图结构, 保存所有的变量和操作等
tensorflow的所有变量名都只是在session中才是alive的,所以必须在一个session中保存参数。

  • 模型的保存与加载
# 保存全部参数
saver = tf.train.Saver()
saver.save(sess, "model.ckpt")
# 加载模型的全部参数
saver.restore(sess, "model.ckpt")
# 获取模型图结构
tf.reset_default_graph()
tf.train.import_meta_graph('model.ckpt.meta')
  • 选择保存和加载的参数

–1.参数获取方式

# 参数获取的方式
var_list = tf.trainable_variables()
var_list = tf.global_variables()
var_list = tf.all_variables()
var_list = tf.contrib.frameworf.get_variables_to_restore()
var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="RNNEncoder")

–2. 保存指定参数

# 保存指定参数
var_choose = [var["wc1"], var["wc2"]]
saver = tf.train.Saver(var_choose)
saver.save(sess, "model.ckpt")

–3. 加载预训练的参数

# 加载指定参数
var_choose = [var["wc1"], var["wc2"]]
saver = tf.train.Saver(var_choose)
saver.restore(sess, model_path)
# 手动初始化预训练模型中没有的参数
var = tf.get_variable(name, shape, initializer=tf.contrib.layers.xavier_initializer())
# 最终保留全部参数
saver_out = tf.train.Saver()
saver_out.save(sess, "model.ckpt")
  • 判断模型是否存在并加载
# 如果代码中已经构建好了网络图结构
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    ckpt = tf.train.get_checkpoint_state(checkpoint_path)
    if ckpt and ckpt.model_checkpoint_path:
        # saver.restore(sess, ckpt.model_checkpoint_path)
        # 加载最新模型
        saver.restore(sess, tf.train.latest_checkpoint(save_dir))
# 如果使用别人完整的checkpoint文件,通过.meta来访问文件
sess = tf.Session()
with tf.Graph().as_default():
    ckpt = tf.train.get_checkpoint_state(ckpt_path)
    if ckpt and ckpt.model_checkpoint_path:
        saver = tf.train.import_meta_graph("".join([ckpt.model_checkpoint_path, ".meta"]))
        saver.restore(sess, ckpt.model_checkpoint_path)
  • 加载模型指定scope_name的参数
train_vars = tf.trainable_variables()
# 恢复模型预训练的部分参数
var_to_restore = [val for val in train_vars if 'conv' in val.name]
saver = tf.train.Saver(var_to_restore)
saver.restore(sess, os.path.join(model_dir, model_name))
  • 只有网络和预训练模型权重,借助该模型中某些层的权重来初始化自己的网络
from tensorflow.python import pywrap_tensorflow
checkpoint_path = os.path.join(model_dir, "model.ckpt")
# 从checkpoint中读出数据
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
# reader = tf.train.NewCheckpointReader(checkpoint_path) # 用tf.train中的NewCheckpointReader方法
var_to_shape_map = reader.get_variable_to_shape_map()
with tf.variable_scope("", reuse=True):
    sess.run(tf.get_variable(your_var_name).assign(reader.get_tensor(pretrained_var_name)))
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值