在Keras框架下调用TensorFlow代码时的一些技巧
-
InputLayer作为输入时,部分tf框架的模型会卡死(不报错也不继续运行)具体原因不明,可以在调用前保存该Tensor对象的_keras_history,然后进删除_keras_history字段,再作为tf模型的输入进行使用,待该得到该tf模型的输出,再把之前保存的_keras_history重新赋值给InputLayer的Tensor对象,并且补齐_keras_history[0]的_inbound_nodes[0]所需的数据。
例如:
char_ids = Input(batch_shape=(None, None), dtype='int32', name='input_ids') # Input输入 elmo_model = ElmoEmbeddingLayer(self.data_config) # 包装了tf模型的layer p = char_ids._keras_history # 保存输入Tensor的_keras_history del(char_ids._keras_history) # 删除_keras_history elmo_embeddings = elmo_model(char_ids) char_ids._keras_history = p # 重新赋值 # 补全_keras_history[0]的_inbound_nodes的内容(这里只有一个inbound_nodes所以只修改了0位置) elmo_embeddings._keras_history[0]._inbound_nodes[0].inbound_layers[0] = char_ids._keras_history[0] elmo_embeddings._keras_history[0]._inbound_nodes[0].node_indices[0] = 0 elmo_embeddings._keras_history[0]._inbound_nodes[0].tensor_indices[0] = 0
-
包装tf模型到keras layer的时候,一定要注意compute_output_shape方法和compute_mask方法,尤其是内部会修改维度的操作,这辆个方法一定要体现出维度的修改