如何将keras训练的模型转换成tensorflow lite模型

背景

keras是一个比较适合初学者上手的高级神经网络API,它能够以TensorFlow, CNTK, 或者 Theano作为后端运行。而keras训练完的模型是.h5文件,如果想要在移动端运行模型需要tflite模型文件

实现

附上从github上找到的一段转换代码,但是要稍作修改

# coding: utf-8

# In[ ]:

'''
Input arguments:
num_output: this value has nothing to do with the number of classes, batch_size, etc., 
and it is mostly equal to 1. If the network is a **multi-stream network** 
(forked network with multiple outputs), set the value to the number of outputs.
quantize: if set to True, use the quantize feature of Tensorflow
(https://www.tensorflow.org/performance/quantization) [default: False]
use_theano: Thaeno and Tensorflow implement convolution in different ways.
When using Keras with Theano backend, the order is set to 'channels_first'.
This feature is not fully tested, and doesn't work with quantizization [default: False]
input_fld: directory holding the keras weights file [default: .]
output_fld: destination directory to save the tensorflow files [default: .]
input_model_file: name of the input weight file [default: 'model.h5']
output_model_file: name of the output weight file [default: args.input_model_file + '.pb']
graph_def: if set to True, will write the graph definition as an ascii file [default: False]
output_graphdef_file: if graph_def is set to True, the file name of the 
graph definition [default: model.ascii]
output_node_prefix: the prefix to use for output nodes. [default: output_node]
'''


# Parse input arguments

# In[ ]:

import argparse
parser = argparse.ArgumentParser(description='set input arguments')
parser.add_argument('-input_fld', action="store",
                    dest='input_fld', type=str, default='.')
parser.add_argument('-output_fld', action="store",
                    dest='output_fld', type=str, default='')
parser.add_argument('-input_model_file', action="store",
                    dest='input_model_file', type=str, default='model.h5')
parser.add_argument('-output_model_file', action="store",
                    dest='output_model_file', type=str, default='')
parser.add_argument('-output_graphdef_file', action="store",
                    dest='output_graphdef_file', type=str, default='model.ascii')
parser.add_argument('-num_outputs', action="store",
                    dest='num_outputs', type=int, default=1)
parser.add_argument('-graph_def', action="store",
                    dest='graph_def', type=bool, default=False)
parser.add_argument('-output_node_prefix', action="store",
                    dest='output_node_prefix', type=str, default='output_node')
parser.add_argument('-quantize', action="store",
                    dest='quantize', type=bool, default=False)
parser.add_argument('-theano_backend', action="store",
                    dest='theano_backend', type=bool, default=False)
parser.add_argument('-f')
args = parser.parse_args()
parser.print_help()
print('input args: ', args)

if args.theano_backend is True and args.quantize is True:
    raise ValueError("Quantize feature does not work with theano backend.")


# initialize

# In[ ]:

from keras.models import load_model
import tensorflow as tf
from pathlib import Path
from keras import backend as K
from keras.applications import mobilenet
from keras.utils.generic_utils import CustomObjectScope

output_fld =  args.input_fld if args.output_fld == '' else args.output_fld
if args.output_model_file == '':
    args.output_model_file = str(Path(args.input_model_file).name) + '.pb'
Path(output_fld).mkdir(parents=True, exist_ok=True)
weight_file_path = str(Path(args.input_fld) / args.input_model_file)

# Load keras model and rename output

# In[ ]:

K.set_learning_phase(0)
if args.theano_backend:
    K.set_image_data_format('channels_first')
else:
    K.set_image_data_format('channels_last')

# try:
# 主要修改在这里,需要加上这行,否则会报错
with CustomObjectScope({'relu6': mobilenet.relu6, 'DepthwiseConv2D': mobilenet.DepthwiseConv2D}):
    net_model = load_model(weight_file_path)
# except ValueError as err:
#     print('''Input file specified ({}) only holds the weights, and not the model defenition.
#     Save the model using mode.save(filename.h5) which will contain the network architecture
#     as well as its weights.
#     If the model is saved using model.save_weights(filename.h5), the model architecture is
#     expected to be saved separately in a json format and loaded prior to loading the weights.
#     Check the keras documentation for more details (https://keras.io/getting-started/faq/)'''
#           .format(weight_file_path))
#     raise err
num_output = args.num_outputs
pred = [None]*num_output
pred_node_names = [None]*num_output
for i in range(num_output):
    pred_node_names[i] = args.output_node_prefix+str(i)
    pred[i] = tf.identity(net_model.outputs[i], name=pred_node_names[i])
print('output nodes names are: ', pred_node_names)


# [optional] write graph definition in ascii

# In[ ]:

sess = K.get_session()

if args.graph_def:
    f = args.output_graphdef_file
    tf.train.write_graph(sess.graph.as_graph_def(), output_fld, f, as_text=True)
    print('saved the graph definition in ascii format at: ', str(Path(output_fld) / f))


# convert variables to constants and save

# In[ ]:

from tensorflow.python.framework import graph_util
from tensorflow.python.framework import graph_io
if args.quantize:
    from tensorflow.tools.graph_transforms import TransformGraph
    transforms = ["quantize_weights", "quantize_nodes"]
    transformed_graph_def = TransformGraph(sess.graph.as_graph_def(), [], pred_node_names, transforms)
    constant_graph = graph_util.convert_variables_to_constants(sess, transformed_graph_def, pred_node_names)
else:
    constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph.as_graph_def(), pred_node_names)
graph_io.write_graph(constant_graph, output_fld, args.output_model_file, as_text=False)
print('saved the freezed graph (ready for inference) at: ', str(Path(output_fld) / args.output_model_file))
复制代码

keras转tensorflow完成后,接下来我们就要将.pb文件转化为.tflite文件。这里查阅了很多资料,记录一下坑的地方

  1. 如果你的tensorflow是1.8的话,先要将tensorflow升级到1.9或者降级到1.7,因为1.8的toco命令不好使。升级方法就是pip3 install -U tensorflow 或者pip3 install --upgrade tensorflow
  2. 升级完成后就可以使用toco命令了,注意:如果之前你是用virtualenv安装的整个环境,那么先source ./bin/activate激活环境,在环境下才能使用
toco --graph_def_file mobilenet_v1_1.0_224_frozen.pb \
  --output_format=TFLITE \
  --output_file=mobilenet_v1_1.0_224_test.tflite \
  --inference_type=FLOAT \
  --input_arrays=input \
  --output_arrays=MobilenetV1/Predictions/Reshape_1 \
  --input_shapes=1,224,224,3
复制代码

这里注意,千万不要按照教程里的命令进行,因为这里有几个坑点:

    1. 1.9的toco命令已经用参数--graph_def_file代替了--input_file
    1. 1.9的toco命令已经将参数--input_type取消掉

所以最后可以运行成功的命令,如上

  1. 上面的命令运行成功后,就可以将自己的pb文件转化成tflite文件了,只要替换graph_def_file后面的pb文件名字和output_file后面的输出文件名字,然后重点是知道你训练的模型的input层的name和output层的name,至于怎么找到这两个层的name,最好用tensorflow中的load_graph函数load一下你的pb模型,遍历graph找到对应层的name既可。分别用input层name和output层name替换input_arrays和output_arrays参数后面的值
  • 0
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值