keras模型是依赖tensorflow框架的,在恢复模型之前还需要再定义一遍网络结构,这对于部署到生产环境来说非常不方便。而转换为pb文件,可以独立运行,任何语言都可以解析它,同时方便部署到tf serving上。本文提供以下两种转换方法。
方法1(推荐):
适用于tf2.0之后的版本,但是1.0版本生成的hdf5文件也可以用此方法转换,前提必须在tf2.0环境下运行
model_path = 'weights.hdf5' # 模型文件
model = tf.keras.models.load_model(model_path)
model.save('models/1/', save_format='tf')
方法2:
适用于tf2.0之前的版本
import tensorflow as tf
from keras import backend as K
from keras.models import Sequential, Model
from os.path import isfile
import os
def build_model():
model = Sequential()
# 你的模型结构
# ...
return model
def save_model_to_serving(model, export_version, export_path='prod_models'):
print(model.input, model.output)
signature = tf.saved_model.signature_def_utils.predict_signature_def(
inputs={'voice': model.input}, outputs={'scores': model.output})
export_path = os.path.join(
tf.compat.as_bytes(export_path),
tf.compat.as_bytes(str(export_version)))
builder = tf.saved_model.builder.SavedModelBuilder(export_path)
legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')
builder.add_meta_graph_and_variables(
sess=K.get_session(),
tags=[tf.saved_model.tag_constants.SERVING],
signature_def_map={
'voice_classification': signature,
},
legacy_init_op=legacy_init_op)
builder.save()
if __name__ == '__main__':
model = build_model()
model.compile(loss='categorical_crossentropy',
optimizer='xxx', # xxx替换
metrics=['xxx'])
model.summary()
filepath = 'weights.hdf5' # 模型文件
if (isfile(filepath)):
model.load_weights(filepath)
else:
print('No file detected.')
export_path = "test_model"
save_model_to_serving(model, "1", export_path)