Keras freezing: How to save a Keras model as a frozen graph?

I am working with TensorFlow 2.0 and want to store the following Keras model as a frozen graph.

```python
import tensorflow as tf

model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(64, input_shape=[100]))
model.add(tf.keras.layers.Dense(32, activation='relu'))
model.add(tf.keras.layers.Dense(16, activation='relu'))
model.add(tf.keras.layers.Dense(2, activation='softmax'))
model.summary()
model.save('./models/')
```

I can't find any good examples of how to do this in TensorFlow 2.0. I have found the freeze_graph.py file in the TensorFlow GitHub repository, but I find it hard to wrap my head around.

I import the function from that file with:

```python
from tensorflow.python.tools.freeze_graph import freeze_graph
```

But what exactly do I have to pass to the freeze_graph function itself? I have marked the arguments I am unsure about with a question mark.

```python
freeze_graph(input_graph=?,
             input_saver='',
             input_binary=False,
             input_checkpoint=?,
             output_node_names=?,
             restore_op_name='',
             filename_tensor_name='',
             output_graph='./frozen_graph.pb',
             clear_devices=True,
             initializer_nodes='')
```

Can someone provide a simple example that shows how I can store the model above as a frozen graph using the freeze_graph function?

Solution

freeze_graph is no longer supported in TensorFlow 2.0; it was built for TensorFlow 1.x graphs and checkpoints.

For details, see the discussion "TensorFlow 2.0: frozen graph support".
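If you still want to call freeze_graph itself, it has to be fed TensorFlow 1.x artifacts: input_graph is the path to a GraphDef file (binary .pb or text .pbtxt, matching input_binary), input_checkpoint is the checkpoint prefix returned by Saver.save (e.g. model.ckpt, without the .index suffix), and output_node_names is a comma-separated list of op names without the ":0" suffix. The sketch below assumes you are willing to switch to TF1 graph mode via tf.compat.v1, and it uses a small hand-built graph rather than the Keras model. The directory tf1_export and the op names input, w, b and output are hypothetical choices for illustration. Run it as a standalone script, because disable_v2_behavior() affects the whole process.

```python
import os

import tensorflow as tf
from tensorflow.python.tools.freeze_graph import freeze_graph

# freeze_graph works on TF1-style graphs, checkpoints and (ref) variables,
# so switch TensorFlow 2.x back to TF1 graph-mode behaviour first.
tf.compat.v1.disable_v2_behavior()

export_dir = "./tf1_export"  # hypothetical output directory
os.makedirs(export_dir, exist_ok=True)

# Build in a separate tf.Graph: freeze_graph later imports the saved GraphDef
# into the (still empty) global default graph, so names must not clash.
with tf.Graph().as_default() as graph:
    # Hypothetical op names; these are what output_node_names refers to.
    x = tf.compat.v1.placeholder(tf.float32, shape=[None, 100], name="input")
    w = tf.compat.v1.get_variable("w", shape=[100, 2])
    b = tf.compat.v1.get_variable(
        "b", shape=[2], initializer=tf.compat.v1.zeros_initializer())
    tf.nn.softmax(tf.matmul(x, w) + b, name="output")

    saver = tf.compat.v1.train.Saver()
    with tf.compat.v1.Session(graph=graph) as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        # input_checkpoint: the checkpoint *prefix* returned by Saver.save.
        ckpt_prefix = saver.save(sess, os.path.join(export_dir, "model.ckpt"))
        # input_graph: the GraphDef, written here as text (.pbtxt).
        graph_path = tf.io.write_graph(
            sess.graph_def, export_dir, "graph.pbtxt", as_text=True)

freeze_graph(input_graph=graph_path,
             input_saver="",
             input_binary=False,             # graph.pbtxt is text, not binary
             input_checkpoint=ckpt_prefix,
             output_node_names="output",     # op names, comma-separated, no ":0"
             restore_op_name="",
             filename_tensor_name="",
             output_graph=os.path.join(export_dir, "frozen_graph.pb"),
             clear_devices=True,
             initializer_nodes="")
```

The same mapping applies to a Keras model only once it has been exported as a TF1 graph plus checkpoint, which is why the TF2-native route described next is usually simpler.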

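What you can still use is the .save method already in your code: in TensorFlow 2.x, model.save('./models/') writes a SavedModel directory (saved_model.pb plus a variables/ folder) that is ready for inference. Note that this is not a frozen graph; the weights are stored separately from the graph.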
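To illustrate that this SavedModel can be used for inference without the original Python model code, here is a minimal sketch, assuming TensorFlow 2.x with tf.keras (Keras 2). With Keras 3 (TensorFlow 2.16 and later), model.save expects a .keras file and SavedModel export goes through model.export() instead. The sketch reloads the directory with tf.saved_model.load and calls the serving_default signature, reading the input name from the signature rather than guessing it.

```python
import numpy as np
import tensorflow as tf

# Same architecture as in the question (assumes tf.keras / Keras 2).
model = tf.keras.Sequential([
    tf.keras.layers.Dense(64, input_shape=[100]),
    tf.keras.layers.Dense(32, activation='relu'),
    tf.keras.layers.Dense(16, activation='relu'),
    tf.keras.layers.Dense(2, activation='softmax'),
])
model.save('./models/')  # SavedModel: ./models/saved_model.pb + ./models/variables/

# Load without Keras and fetch the default serving signature.
loaded = tf.saved_model.load('./models/')
infer = loaded.signatures['serving_default']

# Signature functions take keyword arguments; look the input name up
# (structured_input_signature is an (args, kwargs) tuple of TensorSpecs).
input_name = next(iter(infer.structured_input_signature[1]))
x = tf.constant(np.random.rand(1, 100).astype(np.float32))
outputs = infer(**{input_name: x})  # dict: output name -> tensor
probs = next(iter(outputs.values()))
print(probs.numpy())
```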

As an alternative, you can use convert_variables_to_constants_v2, which replaces the variables of a traced tf.function with constants and so produces an actual frozen graph.

Below is sample code.

```python
import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2

model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(64, input_shape=(1,)))
model.add(tf.keras.layers.Dense(32, activation='relu'))
model.add(tf.keras.layers.Dense(16, activation='relu'))
# Two units: a softmax over a single unit would always output 1.0
model.add(tf.keras.layers.Dense(2, activation='softmax'))
model.compile(optimizer='adam', loss='mse')
model.summary()

# Convert Keras model to ConcreteFunction
full_model = tf.function(lambda x: model(x))
full_model = full_model.get_concrete_function(
    tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype, name="yourInputName"))

# Get frozen ConcreteFunction (variables are replaced by constants)
frozen_func = convert_variables_to_constants_v2(full_model)

# List the operations in the frozen graph
layers = [op.name for op in frozen_func.graph.get_operations()]
print("-" * 50)
print("Frozen model layers: ")
for layer in layers:
    print(layer)
print("-" * 50)
print("Frozen model inputs: ")
print(frozen_func.inputs)
print("Frozen model outputs: ")
print(frozen_func.outputs)

# Save frozen graph from frozen ConcreteFunction to hard drive
tf.io.write_graph(graph_or_graph_def=frozen_func.graph,
                  logdir="./frozen_models",
                  name="frozen_graph.pb",
                  as_text=False)
```

```python
# Usage: load the frozen graph back and run inference
def wrap_frozen_graph(graph_def, inputs, outputs, print_graph=False):
    def _imports_graph_def():
        tf.compat.v1.import_graph_def(graph_def, name="")

    wrapped_import = tf.compat.v1.wrap_function(_imports_graph_def, [])
    import_graph = wrapped_import.graph

    if print_graph:
        print("-" * 50)
        print("Frozen model layers: ")
        for op in import_graph.get_operations():
            print(op.name)
        print("-" * 50)

    # Keep only the subgraph between the given input and output tensors
    return wrapped_import.prune(
        tf.nest.map_structure(import_graph.as_graph_element, inputs),
        tf.nest.map_structure(import_graph.as_graph_element, outputs))


# Example usage
# Load frozen graph using TensorFlow 1.x functions
with tf.io.gfile.GFile("./frozen_models/frozen_graph.pb", "rb") as f:
    graph_def = tf.compat.v1.GraphDef()
    graph_def.ParseFromString(f.read())

# Wrap frozen graph to ConcreteFunctions; "Identity:0" is the output
# tensor name printed by frozen_func.outputs when freezing
frozen_func = wrap_frozen_graph(graph_def=graph_def,
                                inputs=["yourInputName:0"],
                                outputs=["Identity:0"],
                                print_graph=True)
print("-" * 50)
print("Frozen model inputs: ")
print(frozen_func.inputs)
print("Frozen model outputs: ")
print(frozen_func.outputs)

# Run inference on a sample input of shape (1, 1)
predictions = frozen_func(yourInputName=tf.constant([[3.]]))

# Print the prediction for the first (and only) sample
print("-" * 50)
print("Example prediction reference:")
print(predictions[0].numpy())
```
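To sanity-check the round trip, and to avoid hard-coding "yourInputName:0" and "Identity:0", you can read the tensor names from frozen_func.inputs and frozen_func.outputs at freeze time and compare the re-imported graph's output with the Keras model's output. The following self-contained sketch uses a small stand-in model (input width 4 and placeholder name x are arbitrary choices) and serializes the GraphDef in memory instead of writing it to disk; it assumes TensorFlow 2.x with tf.keras.

```python
import numpy as np
import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2

# Small stand-in model (hypothetical sizes).
model = tf.keras.Sequential([
    tf.keras.layers.Dense(8, activation='relu', input_shape=(4,)),
    tf.keras.layers.Dense(2, activation='softmax'),
])

# Freeze with convert_variables_to_constants_v2 and record the tensor names.
concrete = tf.function(lambda x: model(x)).get_concrete_function(
    tf.TensorSpec([None, 4], tf.float32, name="x"))
frozen = convert_variables_to_constants_v2(concrete)
graph_def = frozen.graph.as_graph_def()
input_names = [t.name for t in frozen.inputs]
output_names = [t.name for t in frozen.outputs]

# Re-parse the serialized bytes, i.e. what would be stored in frozen_graph.pb.
restored_def = tf.compat.v1.GraphDef()
restored_def.ParseFromString(graph_def.SerializeToString())

def _import():
    tf.compat.v1.import_graph_def(restored_def, name="")

wrapped = tf.compat.v1.wrap_function(_import, [])
pruned = wrapped.prune(
    [wrapped.graph.as_graph_element(n) for n in input_names],
    [wrapped.graph.as_graph_element(n) for n in output_names])

x = tf.constant(np.random.rand(3, 4).astype(np.float32))
frozen_out = pruned(x)[0].numpy()  # list of outputs; take the first
keras_out = model(x).numpy()
print("max abs difference:", np.abs(frozen_out - keras_out).max())
```

Reading the names from the frozen function keeps the loader working even if TensorFlow names the output tensor differently in another version.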
