一. 导出 pb 文件(冻结 checkpoint)
import os
import tensorflow as tf
from tensorflow.python.framework import graph_util
def freeze_graph(ckpt, output_graph, output_node_names="predictions"):
    """Freeze the checkpoint described by *ckpt* into a single serialized .pb file.

    The meta graph saved next to the checkpoint is imported into a fresh graph,
    its variables are restored, and variables feeding the requested output nodes
    are replaced by constants holding their restored values.

    Args:
        ckpt: checkpoint state as returned by ``tf.train.get_checkpoint_state``;
            ``ckpt.model_checkpoint_path`` (and its ``.meta`` file) is loaded.
        output_graph: destination path of the serialized, frozen ``GraphDef``.
        output_node_names: output node name(s), either a comma-separated string
            or a list of strings. Every name must exist in the restored graph.
            Defaults to ``"predictions"``.

    Returns:
        The frozen ``GraphDef`` that was written to *output_graph*.

    Raises:
        ValueError: if *ckpt* is ``None``/has no checkpoint path (no checkpoint
            was found), or if no non-empty output node name is given.
    """
    # get_checkpoint_state returns None when the directory has no checkpoint;
    # fail with a clear message instead of an AttributeError further down.
    if ckpt is None or not getattr(ckpt, "model_checkpoint_path", None):
        raise ValueError("no checkpoint found; cannot freeze graph")

    # Accept "a,b" as well as ["a", "b"]; strip stray spaces around names.
    if isinstance(output_node_names, str):
        output_node_names = output_node_names.split(",")
    node_names = [name.strip() for name in output_node_names if name.strip()]
    if not node_names:
        raise ValueError("at least one output node name is required")

    checkpoint_path = ckpt.model_checkpoint_path
    # Import into a private graph so repeated calls don't clash with (or pollute)
    # the process-wide default graph.
    with tf.Graph().as_default() as graph:
        saver = tf.train.import_meta_graph(checkpoint_path + '.meta', clear_devices=True)
        with tf.Session(graph=graph) as sess:
            saver.restore(sess, checkpoint_path)
            output_graph_def = graph_util.convert_variables_to_constants(
                sess=sess,
                input_graph_def=sess.graph_def,
                output_node_names=node_names,
            )

    # Serialize the frozen graph to disk.
    with tf.gfile.GFile(output_graph, "wb") as f:
        f.write(output_graph_def.SerializeToString())
    # Report how many op nodes the frozen graph contains.
    print("%d ops in the final graph." % len(output_graph_def.node))
    return output_graph_def
if __name__ == "__main__":
    # Directory containing the training checkpoints (resolved relative to the
    # current working directory, not to this script).
    ckpt = tf.train.get_checkpoint_state('../model/')
    if ckpt is None:
        # get_checkpoint_state returns None when no checkpoint exists there.
        raise SystemExit("No checkpoint found under '../model/'.")
    # Path of the frozen .pb model to write.
    out_pb_path = "../model/tmp.pb"
    freeze_graph(ckpt, out_pb_path)
二. 导出 tflite 文件
import os
import sys
import argparse
import tensorflow as tf
# Put the parent directory of this script on the import path so that the
# local `model` module imported below can be found.
sys.path.append(os.path.normpath(
    os.path.join(os.path.dirname(os.path.realpath(__file__)), os.pardir)))
from model import OpenNsfwModel, InputType
# Export a TFLite version of tensorflow-open_nsfw.
#
# Note: the standard TFLite runtime does not support all ops required when the
# base64_jpeg input type is used; the missing ones have to be implemented by
# yourself (e.g. as custom ops).
if __name__ == "__main__":
    tensor_choice = InputType.TENSOR.name.lower()
    jpeg_choice = InputType.BASE64_JPEG.name.lower()

    cli = argparse.ArgumentParser()
    cli.add_argument("--target", default='../data/nsfw.tflite')
    cli.add_argument(
        "-i", "--input_type",
        default=tensor_choice,
        help="Input type. Warning: base64_jpeg does not work with the standard TFLite runtime since a lot of operations are not supported",
        choices=[tensor_choice, jpeg_choice],
    )
    cli.add_argument(
        "-m", "--model_weights",
        default="../data/open_nsfw-weights.npy",
        help="Path to trained model weights file",
    )
    options = cli.parse_args()

    nsfw_model = OpenNsfwModel()
    tflite_path = options.target
    # The CLI value is the lower-cased enum member name; map it back.
    selected_input = InputType[options.input_type.upper()]

    with tf.Session() as session:
        nsfw_model.build(weights_path=options.model_weights,
                         input_type=selected_input)
        # Variables must be initialised before the converter reads their values.
        session.run(tf.global_variables_initializer())
        tflite_bytes = tf.lite.TFLiteConverter.from_session(
            session, [nsfw_model.input], [nsfw_model.predictions]
        ).convert()
        with open(tflite_path, "wb") as out_file:
            out_file.write(tflite_bytes)
三. 导出 TF Serving 线上部署用模型(SavedModel)
# coding: utf-8
import time
from resnet50 import OpenNsfwModel
import tensorflow
# Select a TF1-style API. On TensorFlow 2.x use `compat.v1` with all v2
# behaviour (eager execution included) switched off, so the graph/session code
# below runs unchanged; on TensorFlow 1.x use the package directly.
# Compare the major version numerically: comparing version strings
# lexicographically is unreliable (e.g. '10.0' > '2.0' is False).
if int(tensorflow.__version__.split(".", 1)[0]) >= 2:
    import tensorflow.compat.v1 as tf
    tf.disable_v2_behavior()
    tf.disable_eager_execution()
else:
    import tensorflow as tf
# Checkpoint prefix passed to tf.train.Saver.restore (relative to the CWD).
save_path = '../class3/All-150-9529-8801'


def build_and_saved_wdl(checkpoint_path=None, export_base_dir="./porn", version=None):
    """Restore the 3-class OpenNsfwModel and export it as a TF Serving SavedModel.

    The model graph is built, the checkpoint is restored into a session, and a
    SavedModel tagged ``serve`` is written with one predict signature,
    ``'predict_images'`` (input key ``"input"``, output key ``"conf"``).

    Args:
        checkpoint_path: checkpoint prefix to restore; defaults to the
            module-level ``save_path`` (read at call time).
        export_base_dir: model base directory; the SavedModel is written to
            ``<export_base_dir>/<version>``. Defaults to ``"./porn"``.
        version: name of the version sub-directory. Defaults to the local time
            formatted as ``%Y%m%d%H%M%S`` (numeric and increasing over time).

    Returns:
        The directory the SavedModel was written to.
    """
    if checkpoint_path is None:
        checkpoint_path = save_path
    if version is None:
        # TF Serving treats numeric sub-directories as versions and serves the
        # highest one. Including the year keeps versions increasing across year
        # boundaries; seconds avoid clashes between quick re-exports (the
        # builder refuses to write into an existing directory).
        version = time.strftime("%Y%m%d%H%M%S", time.localtime())
    export_dir = "{}/{}".format(export_base_dir, version)

    # The model architecture is defined by this class; swap in your own model.
    model = OpenNsfwModel(trainable=False, classnum=3)
    # Context manager guarantees the session is closed after export.
    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        saver.restore(sess=session, save_path=checkpoint_path)

        builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
        inputs = {"input": tf.saved_model.utils.build_tensor_info(model.input)}
        outputs = {"conf": tf.saved_model.utils.build_tensor_info(model.conf)}
        prediction_signature = tf.saved_model.signature_def_utils.build_signature_def(
            inputs=inputs,
            outputs=outputs,
            method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME,
        )
        builder.add_meta_graph_and_variables(
            session,
            [tf.saved_model.tag_constants.SERVING],
            {'predict_images': prediction_signature},
        )
        builder.save()
    return export_dir
if __name__ == "__main__":
    # Running this script exports the SavedModel for TF Serving.
    build_and_saved_wdl()