怎么将tflite部署在安卓上_tensorflow从训练自定义CNN网络模型到Android端部署tflite...

本文详细介绍了如何从头开始训练一个七层卷积网络模型,使用TensorFlow Slim,处理自定义验证码数据集,达到96%~99%的测试准确率。内容涵盖数据生成、模型训练、模型保存为PB格式,以及如何将模型转换为tflite格式。同时,强调了训练时预处理的重要性,移动端部署时需保持一致。最后,简要提到了Android端部署的关键代码调整。
摘要由CSDN通过智能技术生成

网上有很多关于tensorflow lite在安卓端部署的教程,但是大多只讲如何把训练好的模型部署到安卓端,不讲如何训练,而实际上在部署的时候,需要知道训练模型时预处理的细节,这就导致了自己训练的模型在部署到安卓端的时候出现各种问题。因此,本文会记录从PC端训练、导出到安卓端部署的各种细节。欢迎大家讨论、指教。

PC端系统:Ubuntu14

tensorflow版本:tensroflow1.14

安卓版本:9.0

PC端训练过程

数据集:自定义生成

训练框架:tensorflow slim  关于tensorflow slim如何安装,这里不再赘述,大家自行百度解决。

数据生成代码:生成50000张28*28大小三通道的验证码图片,共分10类,0-9,生成的数据保存在datasets/images/里面

#-*- coding: utf-8 -*-

importcv2importnumpy as npfrom captcha.image importImageCaptchadef generate_captcha(text='1'):"""Generate a digit image."""capt= ImageCaptcha(width=28, height=28, font_sizes=[24])

image=capt.generate_image(text)

image= np.array(image, dtype=np.uint8)returnimageif __name__ == '__main__':

output_dir= './datasets/images/'

for i in range(50000):

label= np.random.randint(0, 10)

image=generate_captcha(str(label))

image_name= 'image{}_{}.jpg'.format(i+1, label)

output_path= output_dir +image_name

cv2.imwrite(output_path, image)

训练:本次训练我用tensorflow slim 搭建了一个七层卷积的网络,最后测试准确率在96%~99%左右,模型1.2M,适合在移动端部署。训练的时候我做了两点工作

1、指明了模型的输入和输出节点的名字,PC端部署测试模型的时候要用到,也便于快速确定模型的输出数据到底是什么格式,移动端代码要与其保持一致

inputs = tf.placeholder(tf.float32, shape=[None, 28, 28, 3], name='inputs')

.......

.......

prob_= tf.identity(prob, name='prob')

2、训练结束的时候直接把模型保存成PB格式

constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['inputs','prob']) #训练完毕直接把模型保存为PB格式

with tf.gfile.FastGFile('model3.pb', mode='wb') as f: #模型的名字是model.pb

f.write(constant_graph.SerializeToString())

训练代码如下

#-*- coding: utf-8 -*-

"""Train a CNN model to classifying 10 digits.

Example Usage:

---------------

python3 train.py \

--images_path: Path to the training images (directory).

--model_output_path: Path to model.ckpt."""

importcv2importglobimportnumpy as npimportosimporttensorflow as tfimportmodelfrom tensorflow.python.framework importgraph_util

flags=tf.app.flags

flags.DEFINE_string('images_path', None, 'Path to training images.')

flags.DEFINE_string('model_output_path', None, 'Path to model checkpoint.')

FLAGS=flags.FLAGSdefget_train_data(images_path):"""Get the training images from images_path.

Args:

images_path: Path to trianing images.

Returns:

images: A list of images.

lables: A list of integers representing the classes of images.

Raises:

ValueError: If images_path is not exist."""

if notos.path.exists(images_path):raise ValueError('images_path is not exist.')

images=[]

labels=[]

对于使用TensorFlow 2.3训练数字识别模型并将其量化为TFLite,然后部署到OpenMV上,你可以按照以下步骤进行操作: 1. 数据集准备:收集和准备用于数字识别的图像数据集。确保数据集具有适当的标签和类别。 2. 模型训练:使用TensorFlow 2.3构建和训练适合数字识别的模型,例如卷积神经网络CNN)。确保在训练过程中使用适当的评估指标和优化算法。 3. 模型量化:在训练完成后,将训练好的模型量化为TFLite格式。TFLite是一种针对移动和嵌入式设备的轻量级模型表示形式。 ```python import tensorflow as tf # 加载训练好的模型 model = tf.keras.models.load_model('trained_model.h5') # 量化模型 converter = tf.lite.TFLiteConverter.from_keras_model(model) converter.optimizations = [tf.lite.Optimize.DEFAULT] tflite_model = converter.convert() # 保存量化后的模型 with open('quantized_model.tflite', 'wb') as f: f.write(tflite_model) ``` 4. OpenMV准备:确保你已经按照OpenMV官方文档的指导,设置并准备好OpenMV开发环境。 5. 部署到OpenMV:将量化后的TFLite模型部署到OpenMV上进行推理。可以使用OpenMV的MicroPython编程语言进行开发。 - 将`quantized_model.tflite`文件复制到OpenMV设备上,例如SD卡。 - 在OpenMV上编写MicroPython代码,加载模型并进行推理。 ```python import sensor import image import lcd # 初始化OpenMV模块 sensor.reset() sensor.set_pixformat(sensor.RGB565) sensor.set_framesize(sensor.QVGA) sensor.run(1) # 加载TFLite模型 import tf model = tf.load('quantized_model.tflite') # 进行推理 while True: img = sensor.snapshot() # 对图像进行预处理 # ... # 进行推理 output = model.forward(img) # 处理推理结果 # ... ``` 这个过程中,你需要根据你的具体需求和OpenMV设备的要求进行适当的调整和修改。上述步骤仅供参考,你可以根据实际情况进行调整。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值