Tensorflow 使用slim框架下的分类模型进行分类

Tensorflow的slim框架可以写出像keras一样简单的代码来实现网络结构(虽然现在keras也已经集成在tf.contrib中了),而且models/slim提供了类似之前说过的object detection接口类似的image classification接口,可以很方便的进行fine-tuning利用自己的数据集训练自己所需的模型。

官方文档提供了比较详细的从数据准备,预训练模型的model zoo,fine-tuning,freeze model等一系列流程的步骤,但是缺少了inference的文档,不过tf所有模型的加载方式是通用的,所以调用方法和调用其他pb模型是一样的。

根据TF开发人员是说法Tensorflow对于模型读写的保存和调用的步骤一般如下:Build your graph –> write your graph –> import from written graph –> run compute etc。

以下我们使用slim提供的网络inception-resnet-v2作为例子:

  1. export inference graph
    import tensorflow as tf
    import nets.inception_resnet_v2 as net

slim = tf.contrib.slim

checkpoint path

checkpoint_path = “/your/path/to/inception_resnet_v2.ckpt” # ckpt file obtained during model training or fine-tuning

set up and load session

sess = tf.Session()
arg_scope = net.inception_resnet_v2_arg_scope()

initialize tensor suitable for model input

input_tensor = tf.placeholder(tf.float32, [None, 299, 299, 3])
with slim.arg_scope(arg_scope):
logits, end_points = net.inception_resnet_v2(inputs=input_tensor)

set up model saver

saver = tf.train.Saver()
saver.restore(sess, checkpoint_path)
with tf.gfile.GFile(‘/your/path/to/model_graph.pb’, ‘w’) as f: # save model to given pb file
f.write(sess.graph_def.SerializeToString())
f.close()
2. freeze model
这里用tf提供的tensorflow/python/tools下的freeze_graph工具:

bazelbuildtensorflow/python/tools:freezegraph bazel-bin/tensorflow/python/tools/freeze_graph \
–input_graph=/your/path/to/model_graph.pb \ # obtained above
–input_checkpoint=/your/path/to/inception_resnet_v2.ckpt \
–input_binary=true
–output_graph=/your/path/to/frozen_graph.pb \
–output_node_names=InceptionResnetV2/Logits/Predictions # output node name defined in inception resnet v2 net
(Optional) visualize frozen graph
LOG_DIR = ‘/tmp/graphdeflogdir’
model_filename = ‘/your/path/to/frozen_graph.pb’

with tf.Session() as sess:
with tf.gfile.FastGFile(model_filename, ‘rb’) as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
graph = tf.import_graph_def(graph_def, name=”)
writer = tf.summary.FileWriter(LOG_DIR, graph_def)
writer.close()
然后用tensorborad –logdir=LOG_DIR选择graph就可以查看到frozen后的网络结构。

  1. inference
    import cv2
    import numpy as np

def preprocess_inception(image_np, central_fraction=0.875):
image_height, image_width, image_channel = image_np.shape
if central_fraction:
bbox_start_h = int(image_height * (1 - central_fraction) / 2)
bbox_end_h = int(image_height - bbox_start_h)
bbox_start_w = int(image_width * (1 - central_fraction) / 2)
bbox_end_w = int(image_width - bbox_start_w)
image_np = image_np[bbox_start_h:bbox_end_h, bbox_start_w:bbox_end_w]
# normalize
image_np = 2 * (image_np / 255.) - 1
return image_np

image_np = cv2.imread(“test.jpg”)

preprocess image as inception resnet v2 does

image_np = preprcess_inception(image_np)

resize to model input image size

image_np = cv2.resize(image_np, (299, 299))

expand dims to shape [None, 299, 299, 3]

image_np = np.expand_dims(image_np, 0)

load model

with tf.gfile.GFile(‘/your/path/to/frozen_graph.pb’)
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
graph = tf.import_graph_def(graph_def, name=”)
with tf.Session(graph=graph) as sess:
input tensor = sess.graph.get_tensor_by_name(“input:0”) # get input tensor
output_tensor = sess.graph.get_tensor_by_name(“InceptionResnetV2/Logits/Predictions:0”) # get output tensor
logits = sess.run(output_tensor, feed_dict={input_tensor: image_np})
print “Prediciton label index:”, np.argmax(logits[0], 1)
print “Top 3 Prediciton label index:”, np.argsort(logits[0], 3)

  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 4
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值