According to the TensorFlow developers, the typical workflow for saving and reusing a model is: build your graph -> write your graph -> import from the written graph -> run compute, etc.
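To make those four steps concrete, here is a minimal toy sketch (not the VGG example below; the graph, node names and the /tmp/tiny_graph.pb path are purely illustrative):
import tensorflow as tf

# 1. build your graph
a = tf.placeholder(tf.float32, name='a')
b = tf.constant(2.0)
c = tf.multiply(a, b, name='c')

# 2. write your graph
with tf.Session() as sess:
    tf.train.write_graph(sess.graph_def, '/tmp', 'tiny_graph.pb', as_text=False)

# 3. import from the written graph
graph_def = tf.GraphDef()
with tf.gfile.GFile('/tmp/tiny_graph.pb', 'rb') as f:
    graph_def.ParseFromString(f.read())
tf.reset_default_graph()
tf.import_graph_def(graph_def, name='')

# 4. run compute
with tf.Session() as sess:
    print(sess.run('c:0', feed_dict={'a:0': 3.0}))  # prints 6.0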
Below we use the VGG-19 network provided by slim as an example:
1. export inference graph
import tensorflow as tf
from tensorflow.python.platform import gfile
from datasets import dataset_factory
from nets import nets_factory
import nets.vgg as net
slim = tf.contrib.slim
tf.app.flags.DEFINE_string(
    'model_name', 'vgg_19', 'The name of the architecture to save.')
tf.app.flags.DEFINE_boolean(
    'is_training', False,
    'Whether to save out a training-focused version of the model.')
tf.app.flags.DEFINE_integer(
    'default_image_size', 224,
    'The image size to use if the model does not define it.')
tf.app.flags.DEFINE_string(
    'dataset_name', 'imagenet',
    'The name of the dataset to use with the model.')
tf.app.flags.DEFINE_integer(
    'labels_offset', 0,
    'An offset for the labels in the dataset. This flag is primarily used to '
    'evaluate the VGG and ResNet architectures which do not use a background '
    'class for the ImageNet dataset.')
tf.app.flags.DEFINE_string(
    'output_file', '/log/model_graph.pb', 'Where to save the resulting file to.')
tf.app.flags.DEFINE_string(
    'dataset_dir', '', 'Directory to save intermediate dataset files to.')
FLAGS = tf.app.flags.FLAGS
def main(_):
    if not FLAGS.output_file:
        raise ValueError('You must supply the path to save to with --output_file')
    tf.logging.set_verbosity(tf.logging.INFO)
    # checkpoint path: the ckpt file obtained during model training or fine-tuning
    checkpoint_path = 'your/ckpt/model/path'
    # set up and load session
    sess = tf.Session()
    arg_scope = net.vgg_arg_scope()
    network_fn = nets_factory.get_network_fn(
        FLAGS.model_name,
        num_classes=(19 - FLAGS.labels_offset),  # number of classes of your trained/fine-tuned model
        is_training=FLAGS.is_training)
    if hasattr(network_fn, 'default_image_size'):
        image_size = network_fn.default_image_size
    else:
        image_size = FLAGS.default_image_size
    placeholder = tf.placeholder(name='input', dtype=tf.float32,
                                 shape=[1, image_size, image_size, 3])
    with slim.arg_scope(arg_scope):
        logits, end_points = network_fn(placeholder)
    probabilities = tf.nn.softmax(logits)
    result = tf.identity(probabilities, name='output')
    saver = tf.train.Saver()
    saver.restore(sess, checkpoint_path)
    with gfile.GFile(FLAGS.output_file, 'wb') as f:
        f.write(sess.graph_def.SerializeToString())
if __name__ == '__main__':
    tf.app.run()
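Before freezing, it is worth checking that the exported graph actually contains the input and output node names we will refer to later. A small sketch, assuming the graph was written to /log/model_graph.pb as above:
import tensorflow as tf

# Load the exported GraphDef and print the nodes needed for freezing.
graph_def = tf.GraphDef()
with tf.gfile.GFile('/log/model_graph.pb', 'rb') as f:
    graph_def.ParseFromString(f.read())

for node in graph_def.node:
    if node.name in ('input', 'output'):
        print(node.name, node.op)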
2. freeze model
You can run this with bazel (--input_graph is the model_graph.pb obtained above, and --output_node_names is the output node name defined in the export script):
bazel build tensorflow/python/tools:freeze_graph
bazel-bin/tensorflow/python/tools/freeze_graph \
    --input_graph=/your/path/to/model_graph.pb \
    --input_checkpoint=/your/path/to/vgg-19.ckpt \
    --input_binary=true \
    --output_graph=/your/path/to/frozen_graph.pb \
    --output_node_names=output
Or you can run the Python script directly, as follows:
python freeze_graph.py \
    --input_graph=/your/path/to/model_graph.pb \
    --input_checkpoint=/your/path/to/vgg-19.ckpt \
    --input_binary=true \
    --output_graph=/your/path/to/frozen_graph.pb \
    --output_node_names=output
Note: model_graph.pb here is the inference graph of the saved model, i.e. the file produced in the previous step.
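Alternatively, you can freeze the graph without leaving Python by folding the restored variables into constants with tf.graph_util.convert_variables_to_constants. A minimal sketch, assuming sess is the session from the export script above (checkpoint already restored) and /log/frozen_graph.pb as an example output path:
# Keep only the subgraph needed to compute the 'output' node,
# with all variables converted to constants.
frozen_graph_def = tf.graph_util.convert_variables_to_constants(
    sess, sess.graph_def, ['output'])
with tf.gfile.GFile('/log/frozen_graph.pb', 'wb') as f:
    f.write(frozen_graph_def.SerializeToString())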
3. inference
import cv2
import numpy as np
import tensorflow as tf
file = r'/data/1.jpg'
eval_image_size = 224  # FLAGS.eval_image_size or network_fn.default_image_size in slim
image_np = cv2.imread(file)
# resize to the model input image size
image_np = cv2.resize(image_np, (eval_image_size, eval_image_size))
image_np = np.expand_dims(image_np, 0)
# load the frozen model
with tf.gfile.GFile('/log/frozen_graph.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
graph = tf.Graph()
with graph.as_default():
    tf.import_graph_def(graph_def, name='')
with tf.Session(graph=graph) as sess:
    input_tensor = sess.graph.get_tensor_by_name('input:0')    # get input tensor
    output_tensor = sess.graph.get_tensor_by_name('output:0')  # get output tensor
    logits = sess.run(output_tensor, feed_dict={input_tensor: image_np})
    print('Prediction label index:', np.argmax(logits, 1))
    print('Top 3 prediction label indices:', np.argsort(logits[0])[::-1][:3])
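Note that the snippet above feeds the raw BGR pixels from cv2.imread straight into the network, whereas slim's vgg_preprocessing converts to RGB and subtracts the per-channel means before inference. A minimal hand-rolled approximation of that preprocessing (the mean values are the standard VGG ones used by slim; adapt them if your checkpoint was trained differently):
import cv2
import numpy as np

# Standard VGG per-channel means, in RGB order (as used by slim's vgg_preprocessing).
_R_MEAN, _G_MEAN, _B_MEAN = 123.68, 116.78, 103.94

def vgg_preprocess(image_bgr, size=224):
    """Resize, convert BGR -> RGB and subtract the VGG channel means."""
    image = cv2.resize(image_bgr, (size, size))
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)
    image -= np.array([_R_MEAN, _G_MEAN, _B_MEAN], dtype=np.float32)
    return np.expand_dims(image, 0)

image_np = vgg_preprocess(cv2.imread(r'/data/1.jpg'))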