tensorflow各种模型文件的生成

一,pb文件

import os
import tensorflow as tf
from tensorflow.python.framework import graph_util

def freeze_graph(ckpt, output_graph):
    # 指定输出的节点名称,该节点名称必须是原模型中存在的节点
    output_node_names = "predictions"
    saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path + '.meta',clear_devices=True)

    with tf.Session() as sess:
        saver.restore(sess, ckpt.model_checkpoint_path)
        output_graph_def = graph_util.convert_variables_to_constants(
            sess=sess,
            input_graph_def=sess.graph_def,
            output_node_names=output_node_names.split(","))  # 如果有多个输出节点,以逗号隔开

        with tf.gfile.GFile(output_graph, "wb") as f:  # 保存模型
            f.write(output_graph_def.SerializeToString())  # 序列化输出
        print("%d ops in the final graph." % len(output_graph_def.node))
        # 得到当前图有几个操作节点

if __name__ == "__main__":
    ckpt = tf.train.get_checkpoint_state('../model/')
    # 输出pb模型的路径
    out_pb_path = "../model/tmp.pb"
    freeze_graph(ckpt, out_pb_path)

二.tflite文件

import os
import sys
import argparse

import tensorflow as tf

sys.path.append((os.path.normpath(
                 os.path.join(os.path.dirname(os.path.realpath(__file__)),
                              '..'))))

from model import OpenNsfwModel, InputType

"""Exports a tflite version of tensorflow-open_nsfw
Note: The standard TFLite runtime does not support all required ops when using the base64_jpeg input type.
You will have to implement the missing ones by yourself.
"""
if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--target", default='../data/nsfw.tflite')

    parser.add_argument("-i", "--input_type",
                        default=InputType.TENSOR.name.lower(),
                        help="Input type. Warning: base64_jpeg does not work with the standard TFLite runtime since a lot of operations are not supported",
                        choices=[InputType.TENSOR.name.lower(),
                                 InputType.BASE64_JPEG.name.lower()])

    parser.add_argument("-m", "--model_weights", default="../data/open_nsfw-weights.npy",
                        help="Path to trained model weights file")

    args = parser.parse_args()

    model = OpenNsfwModel()

    export_path = args.target
    input_type = InputType[args.input_type.upper()]

    with tf.Session() as sess:
        model.build(weights_path=args.model_weights,input_type=input_type)

        sess.run(tf.global_variables_initializer())

        converter = tf.lite.TFLiteConverter.from_session(sess, [model.input], [model.predictions])
        tflite_model = converter.convert()

        with open(export_path, "wb") as f:
            f.write(tflite_model)

三.tf-serve线上所有模型

# coding: utf-8
import time
from resnet50 import OpenNsfwModel
import tensorflow
if tensorflow.__version__ > '2.0':
    import tensorflow.compat.v1 as tf
    tf.disable_v2_behavior()
    tf.disable_eager_execution()
else:
    import tensorflow as tf

save_path = '../class3/All-150-9529-8801'

def build_and_saved_wdl():
    model = OpenNsfwModel(trainable=False, classnum=3)  # 我自己的模型结构是在这个类中定义的,基于自己的模型进行替换

    session = tf.Session()
    session.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    saver.restore(sess=session, save_path=save_path)

    savename = time.strftime("%m%d%H%M", time.localtime())
    builder = tf.saved_model.builder.SavedModelBuilder("./porn/{}".format(savename))
    inputs = {"input": tf.saved_model.utils.build_tensor_info(model.input)}
    output = {"conf": tf.saved_model.utils.build_tensor_info(model.conf)}
    prediction_signature = tf.saved_model.signature_def_utils.build_signature_def(
        inputs=inputs,
        outputs=output,
        method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME
    )

    builder.add_meta_graph_and_variables(
        session,
        [tf.saved_model.tag_constants.SERVING],
        {'predict_images': prediction_signature}
    )
    builder.save()

if __name__ == '__main__':
    build_and_saved_wdl()

 

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值