代码:
import tensorflow as tf
from keras import backend as K
from keras.models import Sequential, Model
from os.path import isfile
from yolo3.model import preprocess_true_boxes, yolo_body, tiny_yolo_body, yolo_loss
from yolo3.utils import get_random_data
from keras.layers import Input
import time
import os
def save_model_to_serving(model, export_version, export_path='prod_models'):
    """Export a Keras model as a TensorFlow 1.x SavedModel for TF Serving.

    The export is written to ``<export_path>/<export_version>`` with a single
    ``serving_default`` predict signature.

    Args:
        model: Keras model to export. ``model.input`` is exposed under the
            signature key ``inputs``; the i-th tensor of ``model.outputs`` is
            exposed as ``outputs_<i>``.
        export_version: Name of the version sub-directory (converted with
            ``str``).
        export_path: Parent directory that holds the versioned exports.
    """
    # Enumerate model.outputs (always a list) instead of indexing
    # model.output[0..2]: the old form only worked for models with at least
    # three outputs (for a single-output model, model.output[0] slices the
    # tensor) and silently dropped any extra outputs. For a 3-output model
    # such as the YOLOv3 body, the keys are still outputs_0..outputs_2.
    outputs = {'outputs_%d' % i: tensor for i, tensor in enumerate(model.outputs)}
    signature = tf.saved_model.signature_def_utils.predict_signature_def(
        inputs={'inputs': model.input}, outputs=outputs)

    # Target directory <export_path>/<export_version>, built as a bytes path.
    export_path = os.path.join(
        tf.compat.as_bytes(export_path),
        tf.compat.as_bytes(str(export_version)))
    builder = tf.saved_model.builder.SavedModelBuilder(export_path)

    # Extra op stored with the graph; it groups the lookup-table initializers.
    legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')
    builder.add_meta_graph_and_variables(
        # The Keras backend session holds the model's current variable values.
        sess=K.get_session(),
        tags=[tf.saved_model.tag_constants.SERVING],
        signature_def_map={
            'serving_default': signature,  # default signature name
        },
        legacy_init_op=legacy_init_op)
    builder.save()
if __name__ == '__main__':
    # Rebuild the network graph. The two trailing arguments are presumably
    # (num_anchors, num_classes) -- confirm against yolo3.model.yolo_body.
    yolo_model = yolo_body(Input(shape=(416, 416, 3)), 3, 2)
    yolo_model.summary()  # print the architecture

    weights_path = 'trained_weights.h5'
    if not isfile(weights_path):
        # NOTE(review): the export below still runs, using freshly
        # initialized weights.
        print('No checkpoint file detected. Starting from scratch.')
    else:
        print('Checkpoint file detected. Loading weights.')
        yolo_model.load_weights(weights_path)

    # The integer Unix timestamp (seconds) becomes the version sub-directory
    # name under "output_model".
    version = str(int(time.time()))
    save_model_to_serving(yolo_model, version, "output_model")
注意事项:
转换其他 h5 模型时，一般只需修改网络模型的构建与权重加载部分，以及导出签名中输入、输出的定义即可。