转自:
https://blog.csdn.net/u014475479/article/details/84709301
前言
最近实验室碰到一个奇怪的需求,大家分别构建不同的NLP模型,最后需要进行整合,可是由于有的同学使用的是keras,有的同学喜欢使用TensorFlow,这样导致在构建接口时无法统一不同模型load的方式,每一个模型单独使用一种load的方式的话导致了很多重复开发,效率不高的同时也对项目的可扩展性造成了巨大的破坏。于是需要一种能够统一TensorFlow和keras模型的load过程的方法。
正文
1.构建keras模型
首先假设我们build了一个非常简单的keras模型,如下所示:
x = np.vstack((np.random.rand(1000,10),-np.random.rand(1000,10)))
y = np.vstack((np.ones((1000,1)),np.zeros((1000,1))))
print(x.shape)
print(y.shape)
model = Sequential()
model.add(Dense(units = 32, input_shape=(10,), activation ='relu'))
model.add(Dense(units = 16, activation ='relu'))
model.add(Dense(units = 1, activation ='sigmoid'))
model.compile(loss='binary_crossentropy', optimizer='Adam', metrics=['binary_accuracy'])
model.fit(x = x, y=y, epochs = 2, validation_split=0.2)
2.将keras模型保存为Protocol Buffers的格式
由于TensorFlow是支持将模型保存为Protocol Buffers(.pb)格式的,如果我们有一种方法能将keras模型保存为(.pb)格式的话,那我们的问题就解决了。可是天不遂人愿,keras没有直接提供这样一个将模型保存为(.pb)格式的方法,所以我们必须自己实现这样一个方法,如果你看过keras的源码的话,你会发现keras backend提供了一个get_session()的函数(只有基于TensorFlow的backend有),该函数会返回一个TensorFlow Session,这样一来我们就另辟蹊径,使用这个Session来保存keras模型,而不使用keras已经提供的保存模型的函数,方法如下:
def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
"""
将输入的Session保存为静态的计算图结构.
创建一个新的计算图,其中的节点以及权重和输入的Session相同. 新的计算图会将输入Session中不参与计算的部分删除。
@param session 需要被保存的Session.
@param keep_var_names 一个记录了需要被保存的变量名的list,若为None则默认保存所有的变量.
@param output_names 计算图相关输出的name list.
@param clear_devices 若为True的话会删除不参与计算的部分,这样更利于移植,否则可能移植失败
@return The frozen graph definition.
"""
from tensorflow.python.framework.graph_util import convert_variables_to_constants
graph = session.graph
with graph.as_default():
freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
output_names = output_names or []
output_names += [v.op.name for v in tf.global_variables()]
input_graph_def = graph.as_graph_def()
if clear_devices:
for node in input_graph_def.node:
node.device = ""
frozen_graph = convert_variables_to_constants(session, input_graph_def,
output_names, freeze_var_names)
return frozen_graph
我们通过如下方法调用上述函数来保存模型:
from keras import backend as K
frozen_graph = freeze_session(K.get_session(),
output_names=[out.op.name for out in model.outputs])
tf.train.write_graph(frozen_graph, wkdir, pb_filename, as_text=False)
3.在TensorFlow中载入保存的模型
载入保存模型的例子如下:
from tensorflow.python.platform import gfile
with tf.Session() as sess:
# load model from pb file
with gfile.FastGFile(wkdir+'/'+pb_filename,'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
sess.graph.as_default()
g_in = tf.import_graph_def(graph_def)
# write to tensorboard (check tensorboard for each op names)
writer = tf.summary.FileWriter(wkdir+'/log/')
writer.add_graph(sess.graph)
writer.flush()
writer.close()
# print all operation names
print('\n===== ouptut operation names =====\n')
for op in sess.graph.get_operations():
print(op)
# inference by the model (op name must comes with :0 to specify the index of its output)
tensor_output = sess.graph.get_tensor_by_name('import/dense_3/Sigmoid:0')
tensor_input = sess.graph.get_tensor_by_name('import/dense_1_input:0')
predictions = sess.run(tensor_output, {tensor_input: x})
print('\n===== output predicted results =====\n')
print(predictions)