直接上代码吧(使用 TensorFlow 1.x 加载 .pb 模型文件并进行预测):
# Load a serialized GraphDef (.pb) into the default graph and run one
# prediction on `testPicArr` (defined elsewhere in the file).
pb_path = 'model.pb'

with tf.Session() as sess:
    # Read and parse the binary GraphDef. `GFile` replaces the deprecated
    # `FastGFile` alias; only the parsing needs the file handle open.
    with gfile.GFile(pb_path, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    # name='' imports the nodes without an "import/" prefix, so the tensor
    # names below match the names stored in the .pb file.
    tf.import_graph_def(graph_def, name='')

    # NOTE(review): import_graph_def does not register variables in the
    # global-variables collection, so this is expected to be a no-op for the
    # imported graph; kept in case variables were created in the default
    # graph elsewhere.
    sess.run(tf.global_variables_initializer())

    # Look up the model's input placeholder and output tensor by name.
    # NOTE(review): 'Placeholder:0' / 'output:0' must match the node names
    # used when the model was exported -- confirm against the .pb.
    input_x = sess.graph.get_tensor_by_name('Placeholder:0')
    output = sess.graph.get_tensor_by_name('output:0')

    # Index of the largest value along axis 1 of the output.
    # `tf.argmax` replaces the deprecated `tf.arg_max` alias.
    # presumably `output` is (batch, num_classes) -- TODO confirm.
    prediction_op = tf.argmax(output, 1)

    # Evaluate the op; the result replaces nothing but the symbolic tensor
    # that previously shared the `preValue` name.
    preValue = sess.run(prediction_op, feed_dict={input_x: testPicArr})

    print(type(preValue))
    print("The prediction number is:", preValue)