keras转tensorflow

参考:https://juejin.im/post/5b7bccc6e51d453887102e0a

代码如下:

# coding: utf-8

# In[ ]:

'''
Input arguments:
-num_outputs: this value has nothing to do with the number of classes, batch_size, etc.,
    and it is mostly equal to 1. If the network is a **multi-stream network**
    (forked network with multiple outputs), set the value to the number of outputs.
    [default: 1]
-quantize: if set to True, use the quantize feature of Tensorflow
    (https://www.tensorflow.org/performance/quantization) [default: False]
-theano_backend: Theano and Tensorflow implement convolution in different ways.
    When using Keras with Theano backend, the order is set to 'channels_first'.
    This feature is not fully tested, and doesn't work with quantization [default: False]
-input_fld: directory holding the keras model file [default: .]
-output_fld: destination directory to save the tensorflow files
    [default: same directory as -input_fld]
-input_model_file: name of the input model file; it must be saved with model.save()
    so it holds the architecture as well as the weights [default: 'model.h5']
-output_model_file: name of the output frozen graph file
    [default: file name of -input_model_file + '.pb']
-graph_def: if set to True, will write the graph definition as an ascii file [default: False]
-output_graphdef_file: if -graph_def is set to True, the file name of the
    graph definition [default: model.ascii]
-output_node_prefix: the prefix to use for output nodes. [default: output_node]

Boolean options accept true/false, yes/no, t/f, y/n or 1/0 (case-insensitive),
e.g. "-quantize True" or "-quantize False".
'''


# Parse input arguments

# In[ ]:

import argparse


def _str2bool(value):
    """Convert a command-line string into a bool.

    argparse's ``type=bool`` cannot be used for this: ``bool('False')`` is True,
    so every non-empty string (including "False" and "0") would switch the
    option on.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value (true/false, yes/no, 1/0), got {!r}'.format(value))


parser = argparse.ArgumentParser(description='set input arguments')
parser.add_argument('-input_fld', action="store",
                    dest='input_fld', type=str, default='.')
# An empty string means "use -input_fld" (resolved after argument parsing).
parser.add_argument('-output_fld', action="store",
                    dest='output_fld', type=str, default='')
parser.add_argument('-input_model_file', action="store",
                    dest='input_model_file', type=str, default='model.h5')
# An empty string means "<input model file name>.pb" (resolved after argument parsing).
parser.add_argument('-output_model_file', action="store",
                    dest='output_model_file', type=str, default='')
parser.add_argument('-output_graphdef_file', action="store",
                    dest='output_graphdef_file', type=str, default='model.ascii')
parser.add_argument('-num_outputs', action="store",
                    dest='num_outputs', type=int, default=1)
parser.add_argument('-graph_def', action="store",
                    dest='graph_def', type=_str2bool, default=False)
parser.add_argument('-output_node_prefix', action="store",
                    dest='output_node_prefix', type=str, default='output_node')
parser.add_argument('-quantize', action="store",
                    dest='quantize', type=_str2bool, default=False)
parser.add_argument('-theano_backend', action="store",
                    dest='theano_backend', type=_str2bool, default=False)
# '-f' is never read below. Presumably it exists so parse_args() does not fail
# when this notebook export runs inside a Jupyter kernel, which is started with
# "-f <connection file>".
parser.add_argument('-f')
args = parser.parse_args()
parser.print_help()
print('input args: ', args)

if args.theano_backend and args.quantize:
    raise ValueError("Quantize feature does not work with theano backend.")


# initialize

# In[ ]:

from keras.models import load_model
import tensorflow as tf
from pathlib import Path
from keras import backend as K
from keras.applications import mobilenet
from keras.utils.generic_utils import CustomObjectScope

# Resolve where to read from and where to write to:
#   * an empty -output_fld falls back to -input_fld;
#   * an empty -output_model_file becomes "<input model file name>.pb".
if not args.output_fld:
    output_fld = args.input_fld
else:
    output_fld = args.output_fld

if not args.output_model_file:
    args.output_model_file = '{}.pb'.format(Path(args.input_model_file).name)

# Create the destination directory (and any missing parents) up front.
output_dir = Path(output_fld)
output_dir.mkdir(parents=True, exist_ok=True)

weight_file_path = str(Path(args.input_fld).joinpath(args.input_model_file))

# Load keras model and rename output

# In[ ]:

# Learning phase 0 = test/inference phase, so the exported graph is built for prediction.
K.set_learning_phase(0)
# Theano lays convolution tensors out channels-first; TensorFlow uses channels-last.
K.set_image_data_format('channels_first' if args.theano_backend else 'channels_last')

# Main modification compared with the referenced script: this scope is required,
# otherwise loading fails with an error. It registers MobileNet's custom 'relu6'
# activation and 'DepthwiseConv2D' layer so load_model can deserialize them.
# load_model needs a file written by model.save() (architecture + weights). A file
# written by model.save_weights() holds only the weights and cannot be loaded this way.
with CustomObjectScope({'relu6': mobilenet.relu6,
                        'DepthwiseConv2D': mobilenet.DepthwiseConv2D}):
    net_model = load_model(weight_file_path)

num_output = args.num_outputs
available_outputs = len(net_model.outputs)
# Fail early with a readable message instead of an IndexError below, or
# freezing a graph with no output nodes at all when num_output is 0.
if not 1 <= num_output <= available_outputs:
    raise ValueError(
        '-num_outputs must be between 1 and the number of model outputs ({}), got {}'
        .format(available_outputs, num_output))

# Wrap each Keras output in a tf.identity with a predictable name
# ("<prefix>0", "<prefix>1", ...) so the output nodes can be found by name
# when freezing and when running the exported graph.
pred_node_names = [args.output_node_prefix + str(i) for i in range(num_output)]
pred = [tf.identity(output, name=name)
        for output, name in zip(net_model.outputs, pred_node_names)]
print('output nodes names are: ', pred_node_names)


# [optional] write graph definition in ascii

# In[ ]:

sess = K.get_session()

if args.graph_def:
    # Dump the live (not yet frozen) graph as a human-readable text protobuf.
    f = args.output_graphdef_file
    tf.train.write_graph(
        sess.graph.as_graph_def(),
        output_fld,
        f,
        as_text=True,
    )
    ascii_graph_path = Path(output_fld) / f
    print('saved the graph definition in ascii format at: ', str(ascii_graph_path))


# convert variables to constants and save

# In[ ]:

from tensorflow.python.framework import graph_util
from tensorflow.python.framework import graph_io

# Freeze: replace the variables feeding the named output nodes with Const
# nodes that hold their current session values.
constant_graph = graph_util.convert_variables_to_constants(
    sess, sess.graph.as_graph_def(), pred_node_names)

if args.quantize:
    # The graph_transforms quantization passes rewrite Const nodes, so they must
    # run on the frozen graph. Before freezing, the weights are still variables,
    # and "quantize_weights" would leave them untouched.
    from tensorflow.tools.graph_transforms import TransformGraph
    transforms = ["quantize_weights", "quantize_nodes"]
    transformed_graph_def = TransformGraph(constant_graph, [], pred_node_names, transforms)
    constant_graph = transformed_graph_def

# Write a binary GraphDef ready for inference.
graph_io.write_graph(constant_graph, output_fld, args.output_model_file, as_text=False)
print('saved the freezed graph (ready for inference) at: ', str(Path(output_fld) / args.output_model_file))

转换成功:

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值