How to generate a TensorFlow checkpoint file from a pb file?
I always feel uneasy when I find that a model I want uses a pb file to store its parameters. For a long time I assumed this storage format was chosen precisely to hide the details of the network from others, and I thought "these researchers are really cunning".
However, this thought turns out to be only partly right. Even in the worst case, where nothing but the model parameters in pb format is provided, we can still peek a bit into the model, as shown in this blog and this blog, although the connection information that is unrelated to the parameters remains invisible unless we have access to the network code.
Our problem here is how to convert a pb parameter file into a standard TensorFlow checkpoint file. As a recap, this post showed that a pb file is loaded into a tf.Graph, but everything arrives as constant nodes. How, then, can we turn that into a checkpoint file?
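For concreteness, here is a minimal sketch of that recap (TensorFlow 1.x API; the pb path is a placeholder):

import tensorflow as tf

graph_def = tf.GraphDef()
with tf.gfile.GFile('model/deepspeech.pb', 'rb') as f:  # placeholder path
    graph_def.ParseFromString(f.read())

with tf.Graph().as_default() as graph:
    tf.import_graph_def(graph_def, name='')
    # All of the weights show up as Const ops, not Variables.
    print([op.name for op in graph.get_operations() if op.type == 'Const'])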
The idea is simple: first assign the values of the constant nodes to correspondingly defined variable nodes, and then save them with the usual tf.train.Saver() approach.
The following is an example.
from tensorflow.core.framework.graph_pb2 import GraphDef
import numpy as np
import tensorflow as tf

deepspeech_prefix = r'model/deepspeech'
newname = 'new'

# Input-layer hyperparameters; these values are assumptions and must match
# the ones the pb model was exported with.
batch_size, n_steps, n_context, n_input = 1, 16, 9, 26

with tf.Graph().as_default() as graph:
    # Reference placeholder with the same layout as the model's input node.
    ref_input_tensor = tf.placeholder(
        tf.float32,
        [batch_size, n_steps if n_steps > 0 else None, 2 * n_context + 1, n_input])
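Continuing from there, one possible way to finish the conversion is sketched below. The pb path, the rule for filtering genuine weight constants, and the name mapping under newname are all assumptions, since the original network code is not available:

import tensorflow as tf

# Load the frozen graph; reusing deepspeech_prefix for the path is a guess.
graph_def = tf.GraphDef()
with tf.gfile.GFile(deepspeech_prefix + '.pb', 'rb') as f:
    graph_def.ParseFromString(f.read())

with tf.Graph().as_default() as graph:
    tf.import_graph_def(graph_def, name='')
    with tf.Session() as sess:
        name_to_var = {}
        for op in graph.get_operations():
            # Assumption: every weight is a Const op. Real models also carry
            # small bookkeeping constants that you may want to filter out,
            # e.g. by name prefix.
            if op.type != 'Const':
                continue
            value = sess.run(op.outputs[0])  # materialize the constant
            # Recreate the weight as a Variable under the new name scope.
            name_to_var[op.name] = tf.Variable(value, name=newname + '/' + op.name)
        # Initializing copies the constant values into the variables;
        # saving then produces a standard checkpoint.
        sess.run(tf.variables_initializer(list(name_to_var.values())))
        saver = tf.train.Saver(var_list=name_to_var)
        saver.save(sess, deepspeech_prefix + '.ckpt')

After this, the checkpoint can be restored into any graph whose variable names match the keys of name_to_var, which is exactly where having the network code (or a faithfully rebuilt graph such as the placeholder above) becomes necessary.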