网上有很多关于tensorflow lite在安卓端部署的教程,但是大多只讲如何把训练好的模型部署到安卓端,不讲如何训练,而实际上在部署的时候,需要知道训练模型时预处理的细节,这就导致了自己训练的模型在部署到安卓端的时候出现各种问题。因此,本文会记录从PC端训练、导出到安卓端部署的各种细节。欢迎大家讨论、指教。
PC端系统:Ubuntu14
tensorflow版本:tensroflow1.14
安卓版本:9.0
PC端训练过程
数据集:自定义生成
训练框架:tensorflow slim 关于tensorflow slim如何安装,这里不再赘述,大家自行百度解决。
数据生成代码:生成50000张28*28大小三通道的验证码图片,共分10类,0-9,生成的数据保存在datasets/images/里面
#-*- coding: utf-8 -*-
importcv2importnumpy as npfrom captcha.image importImageCaptchadef generate_captcha(text='1'):"""Generate a digit image."""capt= ImageCaptcha(width=28, height=28, font_sizes=[24])
image=capt.generate_image(text)
image= np.array(image, dtype=np.uint8)returnimageif __name__ == '__main__':
output_dir= './datasets/images/'
for i in range(50000):
label= np.random.randint(0, 10)
image=generate_captcha(str(label))
image_name= 'image{}_{}.jpg'.format(i+1, label)
output_path= output_dir +image_name
cv2.imwrite(output_path, image)
训练:本次训练我用tensorflow slim 搭建了一个七层卷积的网络,最后测试准确率在96%~99%左右,模型1.2M,适合在移动端部署。训练的时候我做了两点工作
1、指明了模型的输入和输出节点的名字,PC端部署测试模型的时候要用到,也便于快速确定模型的输出数据到底是什么格式,移动端代码要与其保持一致
inputs = tf.placeholder(tf.float32, shape=[None, 28, 28, 3], name='inputs')
.......
.......
prob_= tf.identity(prob, name='prob')
2、训练结束的时候直接把模型保存成PB格式
constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['inputs','prob']) #训练完毕直接把模型保存为PB格式
with tf.gfile.FastGFile('model3.pb', mode='wb') as f: #模型的名字是model.pb
f.write(constant_graph.SerializeToString())
训练代码如下
#-*- coding: utf-8 -*-
"""Train a CNN model to classifying 10 digits.
Example Usage:
---------------
python3 train.py \
--images_path: Path to the training images (directory).
--model_output_path: Path to model.ckpt."""
importcv2importglobimportnumpy as npimportosimporttensorflow as tfimportmodelfrom tensorflow.python.framework importgraph_util
flags=tf.app.flags
flags.DEFINE_string('images_path', None, 'Path to training images.')
flags.DEFINE_string('model_output_path', None, 'Path to model checkpoint.')
FLAGS=flags.FLAGSdefget_train_data(images_path):"""Get the training images from images_path.
Args:
images_path: Path to trianing images.
Returns:
images: A list of images.
lables: A list of integers representing the classes of images.
Raises:
ValueError: If images_path is not exist."""
if notos.path.exists(images_path):raise ValueError('images_path is not exist.')
images=[]
labels=[]