keras yolo3生成的h5模型转saved_model模型

代码:

import tensorflow as tf
from keras import backend as K
from keras.models import Sequential, Model
from os.path import isfile
from yolo3.model import preprocess_true_boxes, yolo_body, tiny_yolo_body, yolo_loss
from yolo3.utils import get_random_data
from keras.layers import Input
import time
import os

def save_model_to_serving(model, export_version, export_path='prod_models'):
    signature = tf.saved_model.signature_def_utils.predict_signature_def(                                                                        
        inputs={'inputs': model.input}, outputs={'outputs_0': model.output[0], 'outputs_1': model.output[1], 'outputs_2': model.output[2],})  #设置模型的输入和输出
    export_path = os.path.join(
        tf.compat.as_bytes(export_path),
        tf.compat.as_bytes(str(export_version)))
    builder = tf.saved_model.builder.SavedModelBuilder(export_path)
    legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')
    builder.add_meta_graph_and_variables(
        sess=K.get_session(),                                                                                                                    
        tags=[tf.saved_model.tag_constants.SERVING],                                                                                             
        signature_def_map={                                                                                                                      
            'serving_default': signature,   # 默认                                                                                                                  
        },
		legacy_init_op=legacy_init_op)
	builder.save()
	
if __name__ == '__main__':
	model = yolo_body(Input(shape=(416, 416, 3)), 3, 2) #网络结构
    model.summary()  #查看网络结构
    checkpoint_filepath = 'trained_weights.h5'  #模型路径
    if (isfile(checkpoint_filepath)):
        print('Checkpoint file detected. Loading weights.')
        model.load_weights(checkpoint_filepath) # 加载权重
    else:
        print('No checkpoint file detected.  Starting from scratch.')
    export_path = "output_model"
    vers = int(time.time())
    save_model_to_serving(model, str(vers), export_path)

注意事项:

其他h5模型转换时一般只需网络模型的加载和输入输出即可

  • 2
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 4
    评论
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值