tensorflow的模型/graph包含两部分:模型定义和tensor数值。相应的,训练好的模型也会存储于两个文件:由tf.train.write_graph()输出的protobuf文件,如graph.pb;以及saver输出的checkpoint文件,如model.ckpt。
载入时,saver和import_graph_def不能同时使用。这是因为saver使用特殊的collection存储variables,而使用import_graph_def时,这个collection并没有初始化。
所以,为了载入训练好的模型,有两种可选方案:
1)只使用saver。你需要手动构建一张节点名称与原图完全相同的图,然后用saver载入weights
2)或者,使用import_graph_def载入图定义,手工创建variables,并对每个variable使用tf.add_to_collection,然后使用saver
如果是做app开发,需要在移动端载入PC端训练好的模型;由于没有C++ API可以载入saver存储的variables,这时必须使用一些小技巧(tricks)。
代码如下:
这份代码生成的pb文件在运行时会报错。猜测原因是:将second_graph中的W和b等variables转换为constant之后,原有的variables并没有被删除,也就是说second_graph中仍然保存着这些variables。移动端载入这张graph时,由于没有执行变量初始化(tf.initialize_all_variables),就会报错,错误信息如下:
E/native (27277): tensorflow_jni.cc:312 Error during inference: Invalid argument: No OpKernel was registered to support Op 'Mod' with these attrs
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Import data
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
# Export script.
#
# 1. Train a single-layer softmax classifier on MNIST inside `first_graph`.
# 2. Read the trained values of the trainable variables out of the session.
# 3. Re-import the graph definition into `second_graph`, substituting the
#    variable reads with constants holding the trained values.
# 4. Write the resulting graph definition to /tmp/graph/graph.pb.
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_string('data_dir', '/tmp/data/', 'Directory for storing data')
mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)

# Trained values keyed by the name of each variable's value tensor
# (`v.value().name`); filled in by the first graph, consumed by the second.
trained_values = {}

first_graph = tf.Graph()
with first_graph.as_default():
    with tf.Session() as sess:
        print("# build graph and run")
        x = tf.placeholder(tf.float32, shape=[None, 784], name="input")
        y_ = tf.placeholder(tf.float32, shape=[None, 10], name="y_")
        W = tf.Variable(tf.zeros([784, 10]), name="W")
        b = tf.Variable(tf.zeros([10]), name="b")
        sess.run(tf.initialize_all_variables())
        y = tf.nn.softmax(tf.matmul(x, W) + b, name="output")

        print ("all tensors: ")
        print ("x : ", x)
        print ("y_ : ", y_)
        print ("y : ", y)
        print ("W : ", W)
        print ("b : ", b)

        cross_entropy = tf.reduce_mean(
            -tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))
        train_step = tf.train.GradientDescentOptimizer(0.5).minimize(
            cross_entropy)

        # 1000 gradient-descent steps on mini-batches of 50 examples.
        for _ in range(1000):
            batch = mnist.train.next_batch(50)
            train_step.run(feed_dict={x: batch[0], y_: batch[1]})

        print("push input_map & input_node")
        # Snapshot every trainable variable while the session is still open.
        for v in tf.trainable_variables():
            trained_values[v.value().name] = sess.run(v)
        print(trained_values)

        correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        print(accuracy.eval(
            feed_dict={x: mnist.test.images, y_: mnist.test.labels}))

second_graph = tf.Graph()
with second_graph.as_default():
    with tf.Session() as persisted_sess:
        print("# input map")
        # Each trained value becomes a constant in the new graph; the
        # constants are created in the dict's iteration order.
        input_map = {
            name: tf.constant(value)
            for name, value in trained_values.items()
        }
        # Fresh placeholders replace the originals so that the new graph can
        # be fed through "input:0" / "y_:0".
        input_map["input:0"] = tf.placeholder(
            tf.float32, shape=[None, 784], name="input")
        input_map["y_:0"] = tf.placeholder(
            tf.float32, shape=[None, 10], name="y_")
        print(input_map)

        print("# load graph")
        # name='' imports the nodes without a name prefix.
        # NOTE(review): the complete first_graph definition is passed in here,
        # including whatever nodes minimize() added for training. Presumably
        # some of them end up in the exported file - confirm.
        output_list = tf.import_graph_def(
            first_graph.as_graph_def(),
            input_map=input_map,
            return_elements=["output:0"],
            name='')
        print(output_list)

        print("# map variables")
        c = persisted_sess.graph.get_tensor_by_name("Const:0")
        print(c)
        x = persisted_sess.graph.get_tensor_by_name("input:0")
        print(x)
        y_ = persisted_sess.graph.get_tensor_by_name("y_:0")
        print(y_)
        y = persisted_sess.graph.get_tensor_by_name("output:0")
        print(y)

        print ("# strip extra nodes: ")
        # TODO: nothing is stripped yet; the graph is exported as-is. Reducing
        # it to the sub-graph that feeds "output" is the intended next step.

        # Sanity check: the constant-backed graph is evaluated on the test
        # set, the same way the trained graph was above.
        correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        print(accuracy.eval(
            feed_dict={x: mnist.test.images, y_: mnist.test.labels}))

# Final positional argument False is the as_text flag (binary output).
tf.train.write_graph(
    second_graph.as_graph_def(), '/tmp/graph/', 'graph.pb', False)
运行(草稿):
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Import data
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
from tensorflow.python.platform import gfile
# Loader script (draft).
#
# Reads a serialized GraphDef from disk, imports it into the default graph
# and evaluates classification accuracy on the MNIST test set.
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_string('data_dir', '/tmp/data/', 'Directory for storing data')
mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)

# load graph
# NOTE(review): the export script earlier in this file writes
# /tmp/graph/graph.pb, while this script reads /tmp/load/test.pb. Presumably
# the file is copied/renamed in between - confirm.
with gfile.FastGFile("/tmp/load/test.pb", 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

# name='' imports the nodes without a prefix, so they can be looked up below
# under the names stored in the file.
tf.import_graph_def(graph_def, name='')

# run
with tf.Session() as sess:
    x = sess.graph.get_tensor_by_name("input:0")
    y_ = sess.graph.get_tensor_by_name("y_:0")
    # "output:0" is a tensor name (op name + output index), so it has to be
    # fetched with get_tensor_by_name; tf.argmax below needs the tensor.
    y = sess.graph.get_tensor_by_name("output:0")
    print(y)

    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    print(accuracy.eval(
        feed_dict={x: mnist.test.images, y_: mnist.test.labels}))