TensorFlow 模型保存/载入

5人阅读 评论(0) 收藏 举报
分类:

TensorFlow 模型保存/载入

我们在上线使用一个算法模型的时候,首先必须将已经训练好的模型保存下来。tensorflow保存模型的方式与sklearn不太一样,sklearn很直接,一个sklearn.externals.joblib的dump与load方法就可以保存与载入使用。而tensorflow由于有graph, operation 这些概念,保存与载入模型稍显麻烦。

一、基本方法

网上搜索tensorflow模型保存,搜到的大多是基本的方法。即

保存

  1. 定义变量
  2. 使用saver.save()方法保存

载入

  1. 定义变量
  2. 使用saver.restore()方法载入

保存 代码如下


import tensorflow as tf  
import numpy as np  

W = tf.Variable([[1,1,1],[2,2,2]],dtype = tf.float32,name='w')  
b = tf.Variable([[0,1,2]],dtype = tf.float32,name='b')  

init = tf.initialize_all_variables()  
saver = tf.train.Saver()  
with tf.Session() as sess:  
        sess.run(init)  
        save_path = saver.save(sess,"save/model.ckpt")  

载入代码如下

import tensorflow as tf  
import numpy as np  

W = tf.Variable(tf.truncated_normal(shape=(2,3)),dtype = tf.float32,name='w')  
b = tf.Variable(tf.truncated_normal(shape=(1,3)),dtype = tf.float32,name='b')  

saver = tf.train.Saver()  
with tf.Session() as sess:  
        saver.restore(sess,"save/model.ckpt")  

这种方法不方便的在于,在使用模型的时候,必须把模型的结构重新定义一遍,然后载入对应名字的变量的值。但是很多时候我们都更希望能够读取一个文件然后就直接使用模型,而不是还要把模型重新定义一遍。所以就需要使用另一种方法。

二、不需重新定义网络结构的方法

tf.train.import_meta_graph

import_meta_graph(
    meta_graph_or_file,
    clear_devices=False,
    import_scope=None,
    **kwargs
)

这个方法可以从文件中将保存的graph的所有节点加载到当前的default graph中,并返回一个saver。也就是说,我们在保存的时候,除了将变量的值保存下来,其实还有将对应graph中的各种节点保存下来,所以模型的结构也同样被保存下来了。

比如我们想要保存计算最后预测结果的y,则应该在训练阶段将它添加到collection中。具体代码如下

保存

### 定义模型
input_x = tf.placeholder(tf.float32, shape=(None, in_dim), name='input_x')
input_y = tf.placeholder(tf.float32, shape=(None, out_dim), name='input_y')

w1 = tf.Variable(tf.truncated_normal([in_dim, h1_dim], stddev=0.1), name='w1')
b1 = tf.Variable(tf.zeros([h1_dim]), name='b1')
w2 = tf.Variable(tf.zeros([h1_dim, out_dim]), name='w2')
b2 = tf.Variable(tf.zeros([out_dim]), name='b2')
keep_prob = tf.placeholder(tf.float32, name='keep_prob')
hidden1 = tf.nn.relu(tf.matmul(self.input_x, w1) + b1)
hidden1_drop = tf.nn.dropout(hidden1, self.keep_prob)
### 定义预测目标
y = tf.nn.softmax(tf.matmul(hidden1_drop, w2) + b2)
# 创建saver
saver = tf.train.Saver(...variables...)
# 假如需要保存y,以便在预测时使用
tf.add_to_collection('pred_network', y)
sess = tf.Session()
for step in xrange(1000000):
    sess.run(train_op)
    if step % 1000 == 0:
        # 保存checkpoint, 同时也默认导出一个meta_graph
        # graph名为'my-model-{global_step}.meta'.
        saver.save(sess, 'my-model', global_step=step)

载入

with tf.Session() as sess:
  new_saver = tf.train.import_meta_graph('my-save-dir/my-model-10000.meta')
  new_saver.restore(sess, 'my-save-dir/my-model-10000')
  # tf.get_collection() 返回一个list. 但是这里只要第一个参数即可
  y = tf.get_collection('pred_network')[0]

  graph = tf.get_default_graph()

  # 因为y中有placeholder,所以sess.run(y)的时候还需要用实际待预测的样本以及相应的参数来填充这些placeholder,而这些需要通过graph的get_operation_by_name方法来获取。
  input_x = graph.get_operation_by_name('input_x').outputs[0]
  keep_prob = graph.get_operation_by_name('keep_prob').outputs[0]

  # 使用y进行预测  
  sess.run(y, feed_dict={input_x:....,  keep_prob:1.0})

这里有两点需要注意的:
一、 saver.restore()时填的文件名,因为在saver.save的时候,每个checkpoint会保存三个文件,如
my-model-10000.meta, my-model-10000.index, my-model-10000.data-00000-of-00001
import_meta_graph时填的就是meta文件名,我们知道权值都保存在my-model-10000.data-00000-of-00001这个文件中,但是如果在restore方法中填这个文件名,就会报错,应该填的是前缀,这个前缀可以使用tf.train.latest_checkpoint(checkpoint_dir)这个方法获取。

二、模型的y中有用到placeholder,在sess.run()的时候肯定要feed对应的数据,因此还要根据具体placeholder的名字,从graph中使用get_operation_by_name方法获取。

checkpoint 文件解析:

这里写图片描述

saver.restore函数给出 model.ckpt-n 的路径后会自动寻找参数名-值文件进行加载:

saver.restore(sess,’./model/model.ckpt-0’)
saver.restore(sess,ckpt.model_checkpoint_path)

if ckpt and ckpt.model_checkpoint_path:
    print(ckpt.model_checkpoint_path)
    saver.restore(sess,'./model/model.ckpt-0')
    #加载最新的模型
    global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
    res = sess.run(y, feed_dict={x: img})
    print(global_step,sess.run(tf.argmax(res,1)))

转载: https://blog.csdn.net/thriving_fcl/article/details/71423039

查看评论

tensorflow保存 和 加载模型

1、 import tensorflow as tf import numpy as np # save to file W = tf.Variable([[1,2,3],[4,5,6]],dty...
  • lujiandong1
  • lujiandong1
  • 2016-11-22 22:12:47
  • 2991

TensorFlow 模型保存/载入的两种方法

TensorFlow 模型保存/载入方法记录
  • thriving_fcl
  • thriving_fcl
  • 2017-05-08 16:02:05
  • 15310

tensorflow从0开始(6)——保存加载模型

目的 学习tensorflow的目的是能够训练的模型,并且利用已经训练好的模型对新数据进行预测。下文就是一个简单的保存模型加载模型的过程。 保存模型 import tenso...
  • searobbers_duck
  • searobbers_duck
  • 2016-06-20 16:52:12
  • 52109

Tensorflow深度学习笔记(十)--模型保存与重新载入

在深度学习过程中我们会训练很多的模型,有些模型的训练很费时间。是否可以保存已经训练好的模型应用于后续的图像识别呢?答案自然是肯定的,本节我们来讲述模型的保存与载入。 1.模型的保存 模型的保存有...
  • juyin2015
  • juyin2015
  • 2017-12-14 22:25:16
  • 202

TensorFlow模型的保存和持久化

前言在TensorFlow中,一旦模型训练完成,就需要对其进行持久化操作,也就是将其保存起来,在需要进行对新样本进行测试时,程序加载已经持久化后的模型。在这个过程中就涉及到了模型的持久化操作,在这里简...
  • LoseInVain
  • LoseInVain
  • 2017-10-15 15:18:40
  • 1452

【tensorflow】保存模型、再次加载模型等操作

由于经常要使用tensorflow进行网络训练,但是在用的时候每次都要把模型重新跑一遍,这样就比较麻烦;另外由于某些原因程序意外中断,也会导致训练结果拿不到,而保存中间训练过程的模型可以以便下次训练时...
  • liuxiao214
  • liuxiao214
  • 2018-01-12 21:04:16
  • 1703

使用tensorflow保存、加载和使用模型

使用Tensorflow进行深度学习训练的时候,需要对训练好的网络模型和各种参数进行保存,以便在此基础上继续训练或者使用。介绍这方面的博客有很多,我发现写的最好的是这一篇官方英文介绍: http:/...
  • LordofRobots
  • LordofRobots
  • 2017-08-30 17:23:33
  • 4906

tensorflow学习笔记六:保存和加载训练模型

对于机器学习,尤其是深度学习DL的算法,模型训练可能很耗时,几个小时或者几天,所以如果是测试模块出了问题,每次都要重新运行就显得很浪费时间,所以如果训练部分没有问题,那么可以直接将训练的模型保存起来,...
  • xiaopihaierletian
  • xiaopihaierletian
  • 2017-03-13 22:02:25
  • 2286

tensorflow 模型保存与加载

http://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/ 什么...
  • spylyt
  • spylyt
  • 2017-05-11 10:17:30
  • 7462

tensorflow 加载部分变量

tensorflow模型保存为saver = tf.train.Saver()函数,saver.save()保存模型,代码如下: import tensorflow as tf v1= tf.Va...
  • u011961856
  • u011961856
  • 2017-08-07 16:01:44
  • 2350
    个人资料
    等级:
    访问量: 6357
    积分: 255
    排名: 31万+
    文章分类