Saving and Loading Models in TensorFlow

1. Saving a TF model

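import tensorflow as tf  # TensorFlow 1.x API

tf.summary.scalar('accuracy', acc)
merge_summary = tf.summary.merge_all()
# ...... (definitions of the cross-entropy loss, optimizer, etc.)
saver = tf.train.Saver()
with tf.Session() as sess:
    # create the FileWriter inside the session so that sess.graph exists
    train_writer = tf.summary.FileWriter(dir, sess.graph)
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    for step in range(training_step):
        # ...... (run one training step here)
        if step % 1000 == 0:
            # here checkpoint_dir is used as a file prefix such as "model/mnist_model";
            # global_step appends "-<step>" to the saved file names
            saver.save(sess, checkpoint_dir, global_step=step)
            train_summary = sess.run(merge_summary, feed_dict={...})
            train_writer.add_summary(train_summary, step)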
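By default tf.train.Saver keeps only the five most recent checkpoints (its max_to_keep argument defaults to 5) and deletes older ones as new ones are saved. The plain-Python sketch below is a toy model of that bookkeeping, not TensorFlow code; the function name simulate_saves is hypothetical and it assumes max_to_keep is at least 1. For instance, saves at steps 1001, 2001, and so on up to 49001 would leave exactly the five paths 45001 to 49001 that appear in the printed output in section 3.

```python
def simulate_saves(save_steps, max_to_keep=5):
    """Toy model of which checkpoints tf.train.Saver keeps on disk.

    Each save appends a checkpoint; once more than `max_to_keep` exist,
    the oldest one is dropped. Assumes max_to_keep >= 1.
    Returns (latest_step, kept_steps) with kept_steps ordered oldest first.
    """
    kept = []
    for step in save_steps:
        kept.append(step)
        if len(kept) > max_to_keep:
            kept.pop(0)  # the oldest checkpoint is deleted
    latest = kept[-1] if kept else None
    return latest, kept


# Saving every 1000 steps, starting at step 1001
latest, kept = simulate_saves(range(1001, 49002, 1000))
print(latest)  # 49001
print(kept)    # [45001, 46001, 47001, 48001, 49001]
```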

2. The saved model files

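Saving produces three main files per checkpoint: a .data file (the values of the variables, such as the network's weights and biases), an .index file (for each tensor, which "data" file holds its contents, its offset within that file, a checksum, and some auxiliary data), and a .meta file (the graph structure, including its operations). The same directory also gets a small text file named checkpoint, which is the one we look at here.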

It stores only path names, and by default its first line (model_checkpoint_path) records the path of the latest model.
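The checkpoint file holds one `key: "value"` entry per line. As a minimal sketch, it can be read with plain Python, assuming that line format and paths without embedded quotes; parse_checkpoint_file is a hypothetical helper, and in practice tf.train.get_checkpoint_state (section 3) does this parsing for you.

```python
import re

# Matches lines such as: model_checkpoint_path: "model/mnist_model-49001"
_ENTRY = re.compile(r'^\s*(\w+)\s*:\s*"(.*)"\s*$')


def parse_checkpoint_file(text):
    """Return (latest_path, all_paths) from the text of a `checkpoint` file.

    Lines that are not `key: "value"` entries are ignored.
    """
    latest, all_paths = None, []
    for line in text.splitlines():
        match = _ENTRY.match(line)
        if not match:
            continue
        key, value = match.groups()
        if key == "model_checkpoint_path":
            latest = value
        elif key == "all_model_checkpoint_paths":
            all_paths.append(value)
    return latest, all_paths


example = '''model_checkpoint_path: "model/mnist_model-49001"
all_model_checkpoint_paths: "model/mnist_model-48001"
all_model_checkpoint_paths: "model/mnist_model-49001"
'''
latest, all_paths = parse_checkpoint_file(example)
print(latest)     # model/mnist_model-49001
print(all_paths)  # ['model/mnist_model-48001', 'model/mnist_model-49001']
```

To read a real file, pass it the text of the checkpoint file in the save directory (a directory named model is assumed here), e.g. `parse_checkpoint_file(open("model/checkpoint").read())`.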

3. Loading a TF model

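import tensorflow as tf  # TensorFlow 1.x API

def checkpoint_load(path):
    # path is the directory that contains the "checkpoint" file, e.g. "model"
    print('Reading Checkpoints... .. .\n')
    ckpt = tf.train.get_checkpoint_state(path)
    print(ckpt)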

The printed output is:

model_checkpoint_path: "model/mnist_model-49001"
all_model_checkpoint_paths: "model/mnist_model-45001"
all_model_checkpoint_paths: "model/mnist_model-46001"
all_model_checkpoint_paths: "model/mnist_model-47001"
all_model_checkpoint_paths: "model/mnist_model-48001"
all_model_checkpoint_paths: "model/mnist_model-49001"

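So tf.train.get_checkpoint_state(path) returns a CheckpointState object (or None if no checkpoint file is found) with two fields:

ckpt.model_checkpoint_path         (the most recent checkpoint)
ckpt.all_model_checkpoint_paths    (every checkpoint still kept, oldest first)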

When resuming training from a checkpoint, we usually only need to check ckpt.model_checkpoint_path and load the latest model:

if ckpt and ckpt.model_checkpoint_path:
    ckpt_path = str(ckpt.model_checkpoint_path)
    self.saver.restore(self.sess, os.path.join(os.getcwd(), ckpt_path))
    # the step is the number after the last '-' in the file name
    step = int(os.path.basename(ckpt_path).split('-')[-1])
    print("\nCheckpoint Loading Success! %s\n" % ckpt_path)

Here ckpt.model_checkpoint_path is "model/mnist_model-49001", so parsing the number after the '-' in the file name gives step = 49001, the number of training steps completed.
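Taking the last '-'-separated piece, rather than the second one, keeps the step parsing correct even when the model name itself contains a hyphen. A minimal sketch; step_from_checkpoint_path is a hypothetical helper name:

```python
import os


def step_from_checkpoint_path(ckpt_path):
    """Return the global step that saver.save appended after the last '-'.

    Raises ValueError if the file name has no numeric suffix.
    """
    name = os.path.basename(ckpt_path)    # e.g. "mnist_model-49001"
    return int(name.rsplit('-', 1)[-1])   # split once, from the right


print(step_from_checkpoint_path("model/mnist_model-49001"))  # 49001
print(step_from_checkpoint_path("model/my-net-1200"))        # 1200
```

With split('-')[1], the second path would yield 'net', and int() would raise a ValueError.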

Usage example:

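import os
import tensorflow as tf  # TensorFlow 1.x API

tf.summary.scalar('accuracy', acc)
merge_summary = tf.summary.merge_all()
# ...... (definitions of the cross-entropy loss, optimizer, etc.)
saver = tf.train.Saver()
# checkpoint_dir is the directory (e.g. "model") read by checkpoint_load;
# saver.save needs a file prefix inside it (e.g. "model/mnist_model")
checkpoint_prefix = os.path.join(checkpoint_dir, 'mnist_model')
with tf.Session() as sess:
    train_writer = tf.summary.FileWriter(dir, sess.graph)
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    # restore after initialization, and keep the returned step
    start_step = checkpoint_load(sess, saver, checkpoint_dir)
    # continue counting from the restored step instead of overwriting it
    for step in range(start_step + 1, training_step + 1):
        # ...... (run one training step here)
        if step % 1000 == 0:
            saver.save(sess, checkpoint_prefix, global_step=step)
            train_summary = sess.run(merge_summary, feed_dict={...})
            train_writer.add_summary(train_summary, step)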

checkpoint_load is defined as follows:

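import os
import tensorflow as tf  # TensorFlow 1.x API

def checkpoint_load(sess, saver, path):
    """Restore the latest checkpoint in directory `path`; return its step (0 if none)."""
    print('Reading Checkpoints... .. .\n')
    ckpt = tf.train.get_checkpoint_state(path)
    if ckpt and ckpt.model_checkpoint_path:
        ckpt_path = ckpt.model_checkpoint_path
        saver.restore(sess, os.path.join(os.getcwd(), ckpt_path))
        # the step is the number after the last '-' in the file name
        step = int(os.path.basename(ckpt_path).split('-')[-1])
    # if loading the model fails, return step = 0
    else:
        step = 0
        print('Checkpoint load failed')
    return step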

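At test time you usually don't need the step count: after building the network, just call checkpoint_load to load the saved parameters into the current graph and ignore its return value. Note that the parameters must be restored after the global initializer runs; if you initialize after restoring, the initializer overwrites the loaded values and the restore has no effect. (Restoring also initializes the variables it loads, so the initializer is strictly needed only for variables that are not in the checkpoint.)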
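The ordering rule can be seen with a toy model in plain Python (not TensorFlow): a dict stands in for the graph's variables, "init" plays the role of running tf.global_variables_initializer(), and "restore" the role of saver.restore. The function name run_session and the zero initial values are assumptions for illustration; a real initializer usually draws random values.

```python
def run_session(actions, saved):
    """Apply "init"/"restore" actions in order to a toy set of variables."""
    variables = dict.fromkeys(saved, None)  # variables exist but hold no values yet
    for action in actions:
        if action == "init":
            variables = {name: 0.0 for name in variables}  # stands in for the initializer
        elif action == "restore":
            variables.update(saved)  # stands in for saver.restore
        else:
            raise ValueError("unknown action: %r" % action)
    return variables


saved = {"w": 0.5, "b": -1.0}
print(run_session(["init", "restore"], saved))  # {'w': 0.5, 'b': -1.0}  trained values kept
print(run_session(["restore", "init"], saved))  # {'w': 0.0, 'b': 0.0}   trained values lost
```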

For the complete training process, see: https://blog.csdn.net/Li_haiyu/article/details/80846657
