如何使用tensorflow加载keras训练好的模型

感谢

How to convert your Keras models to Tensorflow

前言

最近实验室碰到一个奇怪的需求,大家分别构建不同的NLP模型,最后需要进行整合,可是由于有的同学使用的是keras,有的同学喜欢使用TensorFlow,这样导致在构建接口时无法统一不同模型load的方式,每一个模型单独使用一种load的方式的话导致了很多重复开发,效率不高的同时也对项目的可扩展性造成了巨大的破坏。于是需要一种能够统一TensorFlow和keras模型的load过程的方法。

正文

1.构建keras模型
首先假设我们build了一个非常简单的keras模型,如下所示:

x = np.vstack((np.random.rand(1000,10),-np.random.rand(1000,10)))
y = np.vstack((np.ones((1000,1)),np.zeros((1000,1))))
print(x.shape)
print(y.shape)

model = Sequential()
model.add(Dense(units = 32, input_shape=(10,), activation ='relu'))
model.add(Dense(units = 16, activation ='relu'))
model.add(Dense(units = 1, activation ='sigmoid'))

model.compile(loss='binary_crossentropy', optimizer='Adam', metrics=['binary_accuracy'])
model.fit(x = x, y=y, epochs = 2, validation_split=0.2) 

2.将keras模型保存为Protocol Buffers的格式
由于TensorFlow是支持将模型保存为Protocol Buffers(.pb)格式的,如果我们有一种方法能将keras模型保存为(.pb)格式的话,那我们的问题就解决了。可是天不遂人愿,keras没有直接提供这样一个将模型保存为(.pb)格式的方法,所以我们必须自己实现这样一个方法,如果你看过keras的源码的话,你会发现keras backend提供了一个get_session()的函数(只有基于TensorFlow的backend有),该函数会返回一个TensorFlow Session,这样一来我们就另辟蹊径,使用这个Session来保存keras模型,而不使用keras已经提供的保存模型的函数,方法如下:

def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
    """
    将输入的Session保存为静态的计算图结构.
    创建一个新的计算图,其中的节点以及权重和输入的Session相同. 新的计算图会将输入Session中不参与计算的部分删除。
    @param session 需要被保存的Session.
    @param keep_var_names 一个记录了需要被保存的变量名的list,若为None则默认保存所有的变量.
    @param output_names 计算图相关输出的name list.
    @param clear_devices 若为True的话会删除不参与计算的部分,这样更利于移植,否则可能移植失败
    @return The frozen graph definition.
    """
    from tensorflow.python.framework.graph_util import convert_variables_to_constants
    graph = session.graph
    with graph.as_default():
        freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
        output_names = output_names or []
        output_names += [v.op.name for v in tf.global_variables()]
        input_graph_def = graph.as_graph_def()
        if clear_devices:
            for node in input_graph_def.node:
                node.device = ""
        frozen_graph = convert_variables_to_constants(session, input_graph_def,
                                                      output_names, freeze_var_names)
        return frozen_graph

我们通过如下方法调用上述函数来保存模型:

from keras import backend as K
frozen_graph = freeze_session(K.get_session(),
                              output_names=[out.op.name for out in model.outputs])
tf.train.write_graph(frozen_graph, wkdir, pb_filename, as_text=False)

3.在TensorFlow中载入保存的模型
载入保存模型的例子如下:

from tensorflow.python.platform import gfile
with tf.Session() as sess:
    # load model from pb file
    with gfile.FastGFile(wkdir+'/'+pb_filename,'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        sess.graph.as_default()
        g_in = tf.import_graph_def(graph_def)
    # write to tensorboard (check tensorboard for each op names)
    writer = tf.summary.FileWriter(wkdir+'/log/')
    writer.add_graph(sess.graph)
    writer.flush()
    writer.close()
    # print all operation names 
    print('\n===== ouptut operation names =====\n')
    for op in sess.graph.get_operations():
      print(op)
    # inference by the model (op name must comes with :0 to specify the index of its output)
    tensor_output = sess.graph.get_tensor_by_name('import/dense_3/Sigmoid:0')
    tensor_input = sess.graph.get_tensor_by_name('import/dense_1_input:0')
    predictions = sess.run(tensor_output, {tensor_input: x})
    print('\n===== output predicted results =====\n')
    print(predictions)
  • 3
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值