TensorFlow code for generating a pb file

A TensorFlow model/graph has two parts: the graph definition and the tensor values. Accordingly, a trained model is stored in two files: a protobuf file written by tf.train.write_graph(), e.g. graph.pb, and a checkpoint file written by the saver, e.g. model.ckpt.
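As a minimal sketch of how the two files are written (a generic example; the paths are placeholders, not from this post):

import tensorflow as tf

# assuming `sess` is a Session whose graph holds the trained variables
saver = tf.train.Saver()
tf.train.write_graph(sess.graph_def, '/tmp/model', 'graph.pb', False)  # graph definition only
saver.save(sess, '/tmp/model/model.ckpt')                              # variable values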

When loading, the saver and import_graph_def cannot be used together. The saver keeps variables in a special collection, and that collection is not populated when the graph is brought in through import_graph_def.

So there are two ways to load a trained model:

1) Use only the saver. You rebuild, by hand, a graph whose nodes have the same names as the original graph, then let the saver load the weights (see the sketch after this list).

2) Or load the graph definition with import_graph_def, create the variables by hand, call tf.add_to_collection on each one, and then use the saver.
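A minimal sketch of option 1, assuming the training graph used the node names below and the checkpoint was saved to /tmp/model/model.ckpt (both are placeholders):

with tf.Graph().as_default(), tf.Session() as sess:
    # rebuild nodes with exactly the same names as in the training graph
    x = tf.placeholder(tf.float32, [None, 784], name="input")
    W = tf.Variable(tf.zeros([784, 10]), name="W")
    b = tf.Variable(tf.zeros([10]), name="b")
    y = tf.nn.softmax(tf.matmul(x, W) + b, name="output")
    # the saver matches variables by name and restores the trained values
    tf.train.Saver().restore(sess, "/tmp/model/model.ckpt")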

For app development, where a model trained on the PC has to be loaded on the device, there is no C++ API for loading the variables stored by the saver, so a few small tricks are needed.

The code is below:

The pb file generated by this code fails at runtime. My guess is that after the variables in second_graph (W, b, etc.) were replaced by constants, the original variable nodes were never removed; in other words, second_graph still carries variables. When the mobile side loads this graph, tf.initialize is never run, so it errors out:

E/native  (27277): tensorflow_jni.cc:312 Error during inference: Invalid argument: No OpKernel was registered to support Op 'Mod' with these attrs

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Import data
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_string('data_dir', '/tmp/data/', 'Directory for storing data')
mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)

vars = {}
first_graph = tf.Graph()
with first_graph.as_default():
  with tf.Session() as sess:
    print("# build graph and run")
    x = tf.placeholder(tf.float32, shape=[None, 784], name="input")
    y_ = tf.placeholder(tf.float32, shape=[None, 10], name="y_")
    
    W = tf.Variable(tf.zeros([784,10]), name = "W")
    b = tf.Variable(tf.zeros([10]), name = "b")
    sess.run(tf.initialize_all_variables())
    
    y = tf.nn.softmax(tf.matmul(x,W) + b, name="output")
    
    print ("all tensors: ")
    print ("x : ", x)
    print ("y_ : ", y_)
    print ("y : ", y)
    print ("W : ", W)
    print ("b : ", b)
    
    cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))
    
    train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
    
    for i in range(1000):
        batch = mnist.train.next_batch(50)
        train_step.run(feed_dict={x: batch[0], y_: batch[1]})
        
    print("push input_map & input_node")
    for v in tf.trainable_variables():
        vars[v.value().name] = sess.run(v)
        #print(vars[v.value().name])
        
    print(vars)
    
    correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    print(accuracy.eval(feed_dict={x: mnist.test.images, y_: mnist.test.labels}))
  
  
  
second_graph = tf.Graph()
with second_graph.as_default():
  with tf.Session() as persisted_sess:
    print("# input map")
    map = {}
    for key in vars.keys():
      map[key] = tf.constant(vars[key])
    
    map["input:0"] = tf.placeholder(tf.float32, shape=[None, 784], name="input")
    map["y_:0"] = tf.placeholder(tf.float32, shape=[None, 10], name="y_")
    #map["output:0"] = tf.placeholder(tf.float32, shape=[None, 10], name="output")
    print(map)
    
    print("# load graph")
    #tf.import_graph_def(first_graph.as_graph_def(), input_map = map)
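    # Re-import first_graph's definition, remapping each trained variable tensor
    # to the constant created above; "output:0" is returned so the remapped
    # softmax output can be used directly in this graph.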
    output_list = tf.import_graph_def(first_graph.as_graph_def(), input_map = map, return_elements = ["output:0"], name = '')
    #tf.import_graph_def(sess.graph_def, input_map={name:consts[name] for name in consts.keys()}, name = '')
    print(output_list)
    
    print("# map variables")
    c = persisted_sess.graph.get_tensor_by_name("Const:0")
    print(c)
    x = persisted_sess.graph.get_tensor_by_name("input:0")
    print(x)
    y_ = persisted_sess.graph.get_tensor_by_name("y_:0")
    print(y_)
    y = persisted_sess.graph.get_tensor_by_name("output:0")
    print(y)
     
    print ("# strip extra nodes: ")
    #graph_util.extract_sub_graph(persisted_sess.graph, b)
  
    correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    print(accuracy.eval(feed_dict={x: mnist.test.images, y_: mnist.test.labels}))
    
    tf.train.write_graph(second_graph.as_graph_def(),'/tmp/graph/','graph.pb',False)
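
An alternative that avoids the leftover variable nodes entirely (a sketch, not part of the original post) is to freeze the first graph directly: graph_util.convert_variables_to_constants folds the trained values of W and b into Const nodes and keeps only the nodes reachable from the requested outputs, so nothing in the written pb needs initialization. Depending on the TensorFlow version, the helper is in tensorflow.python.framework.graph_util (very old releases kept it in tensorflow.python.client.graph_util).

from tensorflow.python.framework import graph_util

# inside the first session, right after training:
frozen_def = graph_util.convert_variables_to_constants(
    sess, sess.graph_def, ["output"])  # keep only nodes that "output" depends on
tf.train.write_graph(frozen_def, '/tmp/graph/', 'graph.pb', False)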

 

Running it (draft):

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Import data
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
from tensorflow.python.platform import gfile
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_string('data_dir', '/tmp/data/', 'Directory for storing data')
mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)


# load graph (point this at the pb written by the script above)
with gfile.FastGFile("/tmp/load/test.pb",'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')

#run
with tf.Session() as sess:
    x = sess.graph.get_tensor_by_name("input:0")
    #print(x)
    y_ = sess.graph.get_tensor_by_name("y_:0")
    y = sess.graph.get_tensor_by_name("output:0")
    print(y)
    
    correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    print(accuracy.eval(feed_dict={x: mnist.test.images, y_: mnist.test.labels}))
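
    # A quick single-image sanity check (a sketch; index 0 is arbitrary), still inside the session:
    import numpy as np
    probs = sess.run(y, feed_dict={x: mnist.test.images[:1]})
    print("predicted digit:", np.argmax(probs))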

 

Reposted from: https://my.oschina.net/hounLeft/blog/720790
