Tensorflow教程笔记
-
基础
TensorFlow 基础
TensorFlow 模型建立与训练
基础示例:多层感知机(MLP)
卷积神经网络(CNN)
循环神经网络(RNN)
深度强化学习(DRL)
Keras Pipeline
自定义层、损失函数和评估指标
常用模块 tf.train.Checkpoint :变量的保存与恢复
常用模块 TensorBoard:训练过程可视化
常用模块 tf.data :数据集的构建与预处理
常用模块 TFRecord :TensorFlow 数据集存储格式
常用模块 tf.function :图执行模式
常用模块 tf.TensorArray :TensorFlow 动态数组
常用模块 tf.config:GPU 的使用与分配 -
大规模训练与加速
TensorFlow 分布式训练
使用 TPU 训练 TensorFlow 模型 -
附录
强化学习基础简介
目录
TensorFlow Lite 是 TensorFlow 在移动和 IoT 等边缘设备端的解决方案,提供了 Java、Python 和 C++ API 库,可以运行在 Android、iOS 和 Raspberry Pi 等设备上。2019 年是 5G 元年,万物互联的时代已经来临,作为 TensorFlow 在边缘设备上的基础设施,TFLite 将会是愈发重要的角色。
目前 TFLite 只提供了推理功能,在服务器端进行训练后,经过如下简单处理即可部署到边缘设备上。
- 模型转换:由于边缘设备计算等资源有限,使用 TensorFlow 训练好的模型,模型太大、运行效率比较低,不能直接在移动端部署,需要通过相应工具进行转换成适合边缘设备的格式。
- 边缘设备部署:本节以 android 为例,简单介绍如何在 android 应用中部署转化后的模型,完成 Mnist 图片的识别。
模型转换
转换工具有两种:命令行工具和 Python API
TF2.0 对模型转换工具发生了非常大的变化,推荐大家使用 Python API 进行转换,命令行工具只提供了基本的转化功能。转换后的原模型为 FlatBuffers
格式。 FlatBuffers
原来主要应用于游戏场景,是谷歌为了高性能场景创建的序列化库,相比 Protocol Buffer 有更高的性能和更小的大小等优势,更适合于边缘设备部署。
转换方式有两种:Float 格式和 Quantized 格式
为了熟悉两种方式我们都会使用,针对 Float 格式的,先使用命令行工具 tflite_convert
,其跟随 TensorFlow 一起安装。
在终端执行如下命令:
tflite_convert -h
输出结果如下,即该命令的使用方法:
usage: tflite_convert [-h] --output_file OUTPUT_FILE
(--saved_model_dir SAVED_MODEL_DIR | --keras_model_file KERAS_MODEL_FILE)
--output_file OUTPUT_FILE
Full filepath of the output file.
--saved_model_dir SAVED_MODEL_DIR
Full path of the directory containing the SavedModel.
--keras_model_file KERAS_MODEL_FILE
Full filepath of HDF5 file containing tf.Keras model.
在 TensorFlow 模型导出 中,我们知道 TF2.0 支持两种模型导出方法和格式 SavedModel 和 Keras Sequential。
SavedModel 导出模型转换:
tflite_convert --saved_model_dir=saved/1 --output_file=mnist_savedmodel.tflite
Keras Sequential 导出模型转换:
tflite_convert --keras_model_file=mnist_cnn.h5 --output_file=mnist_sequential.tflite
到此,已经得到两个 TensorFlow Lite 模型。因为两者后续操作基本一致,我们只处理 SavedModel 格式的,Keras Sequential 的转换可以按类似方法处理。
Android 部署
现在开始在 Android 环境部署,为了获取 SDK 和 gradle 编译环境等资源,需要先给 Android Studio 配置 proxy 或者使用镜像。
配置 build.gradle
将 build.gradle
中的 maven 源 google()
和 jcenter()
分别替换为阿里云镜像地址,如下:
buildscript {
repositories {
maven {
url 'https://maven.aliyun.com/nexus/content/repositories/google' }
maven {
url