
# How to generate a tensorflow checkpoint file given a pb file?

It is usually awkward to discover that a model I want uses a pb file to store its parameters. There was a time when I thought I could learn nothing about the details of the network structure, and that "these researchers are really cunning".

However, this thought is only partly right. Even in the worst case, where nothing but the model parameters in pb format is provided, as shown in this blog and this blog, we can still peek into the model a bit, although some connection information remains invisible unless we have access to the network code.

Our problem here is how to convert a pb parameter file into a standard TensorFlow checkpoint file. To recap, this post has shown that a pb file is loaded into a tf.Graph, but only as constant nodes. So how can we turn it into a checkpoint file?
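To see this concretely, here is a minimal sketch (a toy graph with a made-up node name, not DeepSpeech) showing that a GraphDef of frozen weights, i.e. the same structure a pb file stores, consists purely of constant nodes:

```python
import tensorflow as tf

tf1 = tf.compat.v1  # available on TF 1.14+ and TF 2.x

with tf1.Graph().as_default() as g:
    # Frozen weights are plain constants in the graph.
    tf1.constant([1.0, 2.0], name='weights')

graph_def = g.as_graph_def()  # what a pb file serializes
node_ops = {n.op for n in graph_def.node}
# node_ops is {'Const'}: no Variable ops survive freezing
```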

The idea is to first assign the values of the constant nodes to pre-defined variable nodes, and then save them with the usual tf.train.Saver() approach.
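As a toy illustration of this assign-then-save pattern (the names `new/w` and `w` are hypothetical stand-ins, not taken from any real pb file), the round trip looks like this:

```python
import os
import tempfile
import tensorflow as tf

tf1 = tf.compat.v1  # available on TF 1.14+ and TF 2.x

ckpt_dir = tempfile.mkdtemp()

with tf1.Graph().as_default():
    # Stand-in for a constant imported from a pb file (carries the trained value).
    frozen_w = tf1.constant([[1.0, 2.0]], name='new/w')
    # Our pre-defined variable with a matching shape, initially unset.
    w = tf1.get_variable('w', shape=[1, 2])
    with tf1.Session() as sess:
        sess.run(w.assign(frozen_w))        # copy constant -> variable
        saver = tf1.train.Saver([w])        # save as an ordinary checkpoint
        ckpt = saver.save(sess, os.path.join(ckpt_dir, 'model'))

# Restore into a fresh graph to confirm the value survived as a checkpoint.
with tf1.Graph().as_default():
    w2 = tf1.get_variable('w', shape=[1, 2])
    with tf1.Session() as sess:
        tf1.train.Saver([w2]).restore(sess, ckpt)
        restored = sess.run(w2)
```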

The following is an example.

```python
from tensorflow.core.framework.graph_pb2 import *
import numpy as np
import tensorflow as tf

import DeepSpeech  # provides variable_on_worker_level, BiRNN and the
                   # hyper-parameters batch_size, n_steps, n_context,
                   # n_input, n_cell_dim used below

deepspeech_prefix = r'model/deepspeech'
newname = 'new'

with tf.Graph().as_default() as graph:
    ref_input_tensor = tf.placeholder(
        tf.float32,
        [batch_size, n_steps if n_steps > 0 else None, 2 * n_context + 1, n_input],
        name='input_node')

    with tf.gfile.FastGFile("../models/output_graph.pb", 'rb') as fin:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(fin.read())  # actually parse the pb bytes
        logits, init_state = tf.import_graph_def(
            graph_def,
            return_elements=["logits:0", 'initialize_state'],  # a list, not a set
            input_map={"input_node:0": ref_input_tensor},
            name=newname)

    # Now let's dump these weights into a new copy of the network.
    with tf.Session(graph=graph) as sess:
        def inference_with_real_data(input_tensor):
            seq_length = tf.placeholder(tf.int32, [batch_size], name='input_lengths')
            init_c = tf.zeros_initializer()
            init_h = tf.zeros_initializer()
            previous_state_c = DeepSpeech.variable_on_worker_level(
                'previous_state_c', [batch_size, n_cell_dim], initializer=init_c)
            previous_state_h = DeepSpeech.variable_on_worker_level(
                'previous_state_h', [batch_size, n_cell_dim], initializer=init_h)
            previous_state = tf.contrib.rnn.LSTMStateTuple(previous_state_c, previous_state_h)

            logits, layers = DeepSpeech.BiRNN(batch_x=input_tensor,
                                              seq_length=seq_length,  # if FLAGS.use_seq_length else None
                                              dropout=[0.0] * 6,  # no dropout
                                              batch_size=batch_size,
                                              n_steps=n_steps,
                                              previous_state=previous_state)

            new_state_c, new_state_h = layers['rnn_output_state']

            with tf.control_dependencies(
                    [tf.assign(previous_state_c, new_state_c),
                     tf.assign(previous_state_h, new_state_h)]):
                logits = tf.identity(logits, name='logits')

            return {
                'outputs': logits,
                'initialize_state': [previous_state_c, previous_state_h],  # initialize_state
            }

        # my model inference graph
        with tf.variable_scope(deepspeech_prefix):
            outputs = inference_with_real_data(ref_input_tensor)
            out_logits = outputs['outputs']

        # map each variable to the same-named constant imported from the pb
        mapping = {v.name: v for v in tf.global_variables()
                   if not v.name.startswith(deepspeech_prefix + '/' + 'previous_state_')}
        for name, var in mapping.items():
            name = name[len(deepspeech_prefix) + 1:]  # strip the 'model/deepspeech/' prefix
            sess.run(var.assign(sess.run(newname + '/' + name)))  # tensor(name) = tensor(newname+name)

        saver = tf.train.Saver(list(mapping.values()))
        saver.save(sess, "../models/deepspeech_v3")
```


I should express my thanks to this repo.
The code does the following:
1. Import the graph defined in the pb file. As we mentioned, it contains only constant nodes.
2. Construct our own graph that aligns with the pb file. This creates the variable nodes corresponding to the constant ones.
3. Assign the values. Note that we can reference a tensor not only by its handle but also by its name!
4. Save using tf.train.Saver().
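The reference-by-name trick in step 3 can be shown in a tiny self-contained sketch (the name `new/value` is a toy stand-in, not part of the DeepSpeech graph): sess.run accepts a string like `'scope/tensor:0'` wherever it accepts a tensor handle.

```python
import tensorflow as tf

tf1 = tf.compat.v1  # available on TF 1.14+ and TF 2.x

with tf1.Graph().as_default():
    c = tf1.constant(3.0, name='new/value')  # stand-in for an imported constant
    with tf1.Session() as sess:
        by_handle = sess.run(c)              # fetch via the Python handle
        by_name = sess.run('new/value:0')    # fetch via the tensor's name
# both fetches return 3.0
```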

The key here is that we have two graphs: one imported from the pb file, the other built from our own graph definition. The imported graph contains only constants, but those constants preserve the trained values; the defined graph has variables but no values. All we do is assign values from one to the other. And the reason we can copy between two otherwise identical graphs is that they have different name prefixes!
