[Tensorflow] How to generate a standard tensorflow checkpoint file from a pb file?


I used to feel stuck whenever a model I wanted was distributed only as a pb file. For a long time I assumed I could learn nothing about the details of the network structure from it, and thought "these researchers are really cunning".

However, this thought is only partly right. Even in the worst case, where the model parameters are provided only in pb format, we can still peek a bit into the model, as shown in this blog and this blog, although some of the connection information (the part not tied to parameters) stays invisible unless we have access to the network code.

Here our problem is how to convert a pb parameter file into a standard tensorflow checkpoint file. As a recap, this post showed that a pb file is loaded into a tf.Graph, but only as constant nodes. So how can we turn it into a checkpoint file?

The idea is to first assign the values of the constant nodes to pre-defined variable nodes, and then save them with the usual tf.train.Saver() approach.
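Before the full DeepSpeech example below, the idea can be sketched on a toy graph. All names and paths here are made up for illustration; tf.compat.v1 is used so the sketch also runs under TF 2.x:

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

# Build a stand-in for a frozen pb graph: one constant acting as a trained weight.
frozen = tf.Graph()
with frozen.as_default():
    tf.constant(np.arange(4.0, dtype=np.float32), name='w')
graph_def = frozen.as_graph_def()

ckpt_path = os.path.join(tempfile.mkdtemp(), 'toy_model')

with tf.Graph().as_default() as graph:
    # 1. Import the frozen graph; its nodes arrive as constants under the prefix 'new'.
    tf.import_graph_def(graph_def, name='new')
    # 2. Define a matching variable node under a different prefix.
    with tf.variable_scope('model'):
        w_var = tf.get_variable('w', shape=[4], dtype=tf.float32)
    with tf.Session(graph=graph) as sess:
        # 3. Copy the constant's value into the variable, referencing the constant by name.
        sess.run(w_var.assign(sess.run('new/w:0')))
        # 4. Save a standard checkpoint containing only the variable.
        saver = tf.train.Saver([w_var])
        saver.save(sess, ckpt_path)

# The checkpoint now stores the weight under the variable's graph name.
reader = tf.train.NewCheckpointReader(ckpt_path)
print(reader.get_tensor('model/w'))  # [0. 1. 2. 3.]
```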

The following is an example.

from tensorflow.core.framework.graph_pb2 import *
import numpy as np
import tensorflow as tf

# batch_size, n_steps, n_context, n_input, n_cell_dim and the DeepSpeech
# module (variable_on_worker_level, BiRNN) come from the DeepSpeech
# training code and must be imported/defined before running this script.

deepspeech_prefix = r'model/deepspeech'
newname = 'new'

with tf.Graph().as_default() as graph:
    ref_input_tensor = tf.placeholder(tf.float32,
                                      [batch_size, n_steps if n_steps > 0 else None, 2 * n_context + 1, n_input],
                                      name='input_node')

    with tf.gfile.FastGFile("../models/output_graph.pb", 'rb') as fin:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(fin.read())
    logits, init_state = tf.import_graph_def(graph_def,
                                      return_elements=["logits:0", 'initialize_state'],
                                      input_map={"input_node:0": ref_input_tensor},
                                      name=newname)
                                      
    # Now let's dump these weights into a new copy of the network.
    with tf.Session(graph=graph) as sess:
        def inference_with_real_data(input_tensor):
            seq_length = tf.placeholder(tf.int32, [batch_size], name='input_lengths')
            init_c = tf.zeros_initializer()
            init_h = tf.zeros_initializer()
            previous_state_c = DeepSpeech.variable_on_worker_level('previous_state_c', [batch_size, n_cell_dim], initializer=init_c)
            previous_state_h = DeepSpeech.variable_on_worker_level('previous_state_h', [batch_size, n_cell_dim], initializer=init_h)
            previous_state = tf.contrib.rnn.LSTMStateTuple(previous_state_c, previous_state_h)


            logits, layers = DeepSpeech.BiRNN(batch_x=input_tensor,
                                   seq_length=seq_length,# if FLAGS.use_seq_length else None,
                                   dropout=[ 0.0 ] * 6, # no dropout
                                   batch_size=batch_size,
                                   n_steps=n_steps,
                                   previous_state=previous_state)

            new_state_c, new_state_h = layers['rnn_output_state']
           
            with tf.control_dependencies(
                    [tf.assign(previous_state_c, new_state_c), tf.assign(previous_state_h, new_state_h)]):
                logits = tf.identity(logits, name='logits')

            return {
                    'outputs': logits,
                    'initialize_state': [previous_state_c, previous_state_h],#initialize_state,
                }

        # my model inference graph
        with tf.variable_scope(deepspeech_prefix):
            outputs = inference_with_real_data(ref_input_tensor)
            out_logits = outputs['outputs']
            
        mapping = {v.name: v for v in tf.global_variables() if not v.name.startswith(deepspeech_prefix+'/'+'previous_state_')}
        for name,var in mapping.items():
            name = name[len(deepspeech_prefix)+1:]
            sess.run(var.assign(sess.run(newname + '/' + name))) # tensor(name) = tensor(newname+name)
      
        saver = tf.train.Saver(mapping.values())
        saver.save(sess, "../models/deepspeech_v3")

Thanks go to this repo.
The code does the following:
1. Import the graph defined in the pb file.
As mentioned above, it contains only constant nodes.
2. Construct our own graph that aligns with the pb file.
This creates the variable nodes corresponding to the constant ones.
3. Assign the values.
Note that we can reference a variable not only by its handle but also by its name!
4. Save with tf.train.Saver().
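On point 3, sess.run accepts a tensor's string name just as well as its Python handle, which is what makes the by-name assignment in the script possible. A tiny illustration (toy constant, hypothetical name):

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

with tf.Graph().as_default():
    c = tf.constant(3.0, name='c')
    with tf.Session() as sess:
        by_handle = sess.run(c)      # fetch via the Python handle
        by_name = sess.run('c:0')    # fetch via the tensor's string name
        print(by_handle, by_name)    # 3.0 3.0
```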

The key here is that we have two graphs: one imported from the pb file, the other built from our own graph definition. The imported graph contains only constants, but they preserve the values; the defined graph has variables, but they hold no values. All we do is value assignment. And the reason we can assign between two otherwise identical graphs is that they live under different name prefixes!
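The prefix mechanism can be seen directly by importing the same graph_def twice under different name arguments: the nodes never collide because each copy lives in its own namespace (toy one-node graph for illustration):

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

# A one-node graph_def, imported twice under different prefixes.
g = tf.Graph()
with g.as_default():
    tf.constant(1.0, name='w')
gd = g.as_graph_def()

with tf.Graph().as_default() as graph:
    tf.import_graph_def(gd, name='new')   # node becomes new/w
    tf.import_graph_def(gd, name='copy')  # node becomes copy/w
    print([op.name for op in graph.get_operations()])  # ['new/w', 'copy/w']
```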

©2021 CSDN