该方法只适用于通过 `save_model` 保存的 Keras 模型文件（.h5），不适用于通过 `save_weights` 保存的文件——后者只包含权重、不包含模型结构，`load_model` 无法据此重建模型。代码如下：
import keras
from keras.models import load_model
import tensorflow as tf
import os.path as osp
import os
from keras import backend as K
# 转换函数
def h5_to_pb(h5_model, output_dir, new_name, out_prefix="output_", log_tensorboard=True):
if not osp.exists(output_dir):
os.mkdir(output_dir)
out_nodes = []
for i in range(len(h5_model.outputs)):
out_nodes.append(out_prefix + str(i + 1))
tf.identity(h5_model.output[i], out_prefix + str(i + 1))
sess = K.get_session()
from tensorflow.python.framework import graph_util, graph_io
init_graph = sess.graph.as_graph_def()
main_graph = graph_util.convert_variables_to_constants(sess, init_graph, out_nodes)
graph_io.write_graph(main_graph, output_dir, name=model_name, as_text=False)
if log_tensorboard:
from tenso