tensorflow 模型保存、加载预训练模型

在训练神经网络时候,最重要的一步是保存训练好网络参数,否者session关闭后,我们辛辛苦苦训练模型会丢失,无法复用,,我们使用别人训练好的模型fineturn同样需要加载模型,tensorflow 提供模型保存和加载工具。

参考:参考博客

0.参考是实例

tensorflow_gpu_1.15 mnist

tensorflow mnist应用实例

1.模型说明

tf版本:tensorflow_gpu 1.15

tensorflow 模型保存为checkpoint主要包括以下4个文件。

1.1 .mate文件

mate文件保存的是图结构TensorFlow计算图的结构,也就是神经网络的结构,其中记录计算图中节点,变量,操作等。

1.2.data-00000-of-00001

保存网络中每个变量值,类似python字典结果,对应每个变量的数值(key,value)。

1.3.index

是对应模型的索引文件

1.4 checkpoint

checkpoint是个文本文件文件保存了一个目录下所有的模型文件列表。

2.模型保存

一般都是在训练完成多少个epoch后保存模型,通过 tf.train.Saver类来保存模型。注意tf是先建图后启动Session。具体如下

tf.train.Saver.save()

函数说明

API说明

save(
    sess,         #需要保存的Session
    save_path,      #模型保存路径
    global_step=None, #设置多少步保存以下,需要自己设置步数
    latest_filename=None,
    meta_graph_suffix='meta',
    write_meta_graph=True, #是否保存mate文件
    write_state=True,
    strip_default_attrs=False,
    save_debug_info=False
)

还可以设置多少就保存一次
keep_checkpoint_every_n_hours=2

设置保存最近几个模型。

max_to_keep=5,

如果tf.train.Saver默认保存Session所有的变量。保存一部分变量,可以通过指定variables/collections。在创建tf.train.Saver实例时,通过将需要保存的变量构造list或者dictionary,传入到Saver中:

''''
使用tensorflow_gpu 1.15版本
'''
model_saver = tf.train.Saver()#建图
    with tf.Session() as sess:
        sess.run(tf.compat.v1.global_variables_initializer())#全局变量初始化
        for i in range(TOTAL_ITERATION +1 ):#开始训练循环
            
            #.........................省略训练代码
            if step % 1000:#设置多少步保存一次
            model_saver.save(sess, MODEL_SAVE_PATH,global_step=step)#模型保存


            #还可以自己设置保存条件
            #设置模型保存条件,一般是n个epoch保存一下,还有最后一个循环保存
            if step % MODEL_SAVE_ITERATION == 0 or i == TOTAL_ITERATION:
                
                model_name = 'model_{}.ckpt'.format(i)#模型命名
                model_saver.save(sess, os.path.join(MODEL_SAVE_PATH, model_name))#模型保存
                print('------->model_{}.ckpt save ok!!!'.format(model_name))

3.在制定模型基础上训练

3.1在上一次基础上继续(断点训练)

有时候训练意外终止或者我们想在上一次训练的基础上继续迭代,这是就需要继续训练模型。

saver = tf.train.Saver()

with tf.Session() as sess:
   sess.run(tf.global_variables_initializer())#初始化函数
   #加入断点续训功能
   ckpt = tf.train.get_checkpoint_state(MODEL_SAVE_PATH)#获取保存模型最新中模型
   if ckpt and ckpt.model_checkpoint_path:
      saver.restore(sess,ckpt.model_checkpoint_path)#如果存在,则加载最新模型

   for i in range(STEPS):#开始训练
      xs,ys = mnist.train.next_batch(BATCH_SIZE)
      sess.run(train_op,feed_dict={x:xs,y_:ys})
      if i % 1000 == 0:
         saver.save(sess,os.path.join(MODEL_SAVE_PATH,MODEL_NAME),global_step=i)

 

3.2指定初始化模型训练

TODO

4.模型加载

模型加载包括两个方面,加载计算图和加载神经网络的参数(权重)。

4.1加载计算图(网络)

使用import_meta_graph()函数

saver=tf.train.import_meta_graph('MODEL_PATH/Model-1000.meta')

4.2加载网络参数

变量值需要依赖于Session,加载变量时需要建图

import tensorflow as tf
with tf.Session() as sess:    
    saver = tf.train.import_meta_graph('./model_path/MyModel-1000.meta')
    saver.restore(sess,tf.train.latest_checkpoint('、model_path'))
    print(sess.run('w1:0')#输出网络w1:的权重

 

 

 

 

 

 

 

  • 0
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值