Tensorflow: 保存和复原模型(save and restore)

报错:

is not valid checkpoint

解决:

module_file = tf.train.latest_checkpoint(diag_obj.save_path)
saver.restore(sess, module_file) 


Tensorflow: 保存和复原模型(save and restore)

目前我主要看到了两种方法来保存和复原tensorflow model,先总结一下:

MetaGraph

这种就是我们经常看到的 tf.train.Saver 对应的东西。使用这种方法保存模型,会产生两种文件。

  • meta: 里面存储的是整个graph的定义
  • checkpoint: 这里保存的是 variable 的状态。 
    这里通过如下的方式保存一个模型
checkpoint_dir = "mysaver"

# first creat a simple graph
graph = tf.Graph()

#define a simple graph
with graph.as_default():
    x = tf.placeholder(tf.float32,shape=[],name='input')
    y = tf.Variable(initial_value=0,dtype=tf.float32,name="y_variable")
    update_y = y.assign(x)
    saver = tf.train.Saver(max_to_keep=3)
    init_op = tf.global_variables_initializer()

# train the model and save the model every 4000 iterations.
sess = tf.Session(graph=graph)
sess.run(init_op)
for i in range(1,10000):
    y_result = sess.run(update_y,feed_dict={x:i})
    if i %4000 == 0:
        saver.save(sess,checkpoint_dir,global_step=i)       
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20

这些是产生的文件

checkpoint
mysaver-4000.data-00000-of-00001
mysaver-4000.index
mysaver-4000.meta
mysaver-8000.data-00000-of-00001
mysaver-8000.index
mysaver-8000.meta
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

稍后我们可以复原model

tf.reset_default_graph()
restore_graph = tf.Graph()
with tf.Session(graph=restore_graph) as restore_sess:
    restore_saver = tf.train.import_meta_graph('mysaver-8000.meta')
    restore_saver.restore(restore_sess,tf.train.latest_checkpoint('./'))
    print(restore_sess.run("y_variable:0"))
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

上面这段python代码的输出如下:

INFO:tensorflow:Restoring parameters from ./mysaver-8000
8000.0
  • 1
  • 2

因为最新的checkpoint文件是在 8000th iterations保存的,所以当model复原后 y_variable的值是 80000

SavedModel

还有一种保存模型的方法就是 SavedModel 。 
这种方法我是在看tensorflow servicing的时候看到的,个人的感觉,这是一种更适合部署的方法。暂时没有去研究tensorflow servicing。但是我看很多代码都使用到了通过这种方式保存的文件。比如imagenet example。所以这里着重介绍怎么使用从别的地方拿到的SavedModel文件。

建立 SavedModel

主要分为三部 
* 建立一个 tf.saved_model.builder.SavedModelBuilder
* 使用刚刚建立的 builder把当前的graph和variable添加进去:SavedModelBuilder.add_meta_graph_and_variables(...) 
* 可以使用 SavedModelBuilder.add_meta_graph 添加多个meta graph

复原 SavedModel

这个需要通过这个 model 来完成的:tf.saved_model.loader

通过命令来查看和执行SavedModel

上面的通过编程的方式来建立和复原SavedModel, 我现在基本上不需要发布模型给别人用,但是经常想使用一下别人已经训练好的模型。当拿到别人的模型的时候,需要知道怎么使用。官方提供了一个工具:saved_model_cli,这个工具包含了 show 和 run 两类命令

感兴趣的同学可以查看官方文档 或者这篇博客对应的 jupyter notebook

可视化 SavedModel

我们知道google提供 TensorBoard给我们可视化的调试tensorflow, tensorboard一个最基本的功能就是把graph展示出来。但是有时候我们拿到别人 SavedModel, 我们需要把这个model跑一遍,产生summary文件才能在tensorboard里面看。google deepdream 参考代码里面提供了一个很方便的代码可以让我们快速的把graph展示出来。代码如下, 这个代码是我也放到我的github了,大家也可以直接去看google deepdram 参考代码

# these function is copied from google deepdream example code
import numpy as np
from IPython.display import clear_output, Image, display, HTML
def strip_consts(graph_def, max_const_size=32):
    """Strip large constant values from graph_def."""
    strip_def = tf.GraphDef()
    for n0 in graph_def.node:
        n = strip_def.node.add() 
        n.MergeFrom(n0)
        if n.op == 'Const':
            tensor = n.attr['value'].tensor
            size = len(tensor.tensor_content)
            if size > max_const_size:
                tensor.tensor_content = tf.compat.as_bytes("<stripped %d bytes>"%size)
    return strip_def

def rename_nodes(graph_def, rename_func):
    res_def = tf.GraphDef()
    for n0 in graph_def.node:
        n = res_def.node.add() 
        n.MergeFrom(n0)
        n.name = rename_func(n.name)
        for i, s in enumerate(n.input):
            n.input[i] = rename_func(s) if s[0]!='^' else '^'+rename_func(s[1:])
    return res_def
def show_graph(graph_def, max_const_size=32):
    """Visualize TensorFlow graph."""
    if hasattr(graph_def, 'as_graph_def'):
        graph_def = graph_def.as_graph_def()
    strip_def = strip_consts(graph_def, max_const_size=max_const_size)
    code = """
        <script>
          function load() {{
            document.getElementById("{id}").pbtxt = {data};
          }}
        </script>
        <link rel="import" href="https://tensorboard.appspot.com/tf-graph-basic.build.html" onload=load()>
        <div style="height:600px">
          <tf-graph-basic id="{id}"></tf-graph-basic>
        </div>
    """.format(data=repr(str(strip_def)), id='graph'+str(np.random.rand()))

    iframe = """
        <iframe seamless style="width:800px;height:620px;border:0" srcdoc="{}"></iframe>
    """.format(code.replace('"', '&quot;'))
    display(HTML(iframe))
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

AI算法网奇

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值