在tensorfow lite中对各op进行单元测试

Tensorflow lite源码中提供了对个op的单元测试源码,但是在官方的tflite Makefile中默认并没有编译该部分代码。本文主要是记录在tflite中对op进行单独测试的方法,平台为ARM嵌入式。

概要

在tflite的源码中单元测试的源码一般在op名后面添加有test,在目录 tensorflow/contrib/lite/kernels下可以看到很多op的单元测试源码,如convolution的实现源码为conv.cc,则对应的单元测试源码为conv_test.cc,查看源码后可以知道单元测试采用googletest来实现。另外,基本上所有op的单元测试都会继承tensorflow/contrib/lite/kernels/test_util.h里面的SingleOpModel类,而官方源码中的Makefile默认是没有编译test_util.cc的。
下面,本文以编译conv_test.cc为例说明怎么使用单元测试。

基本思想为:先修改Makefile把test_util.cc编译进libtensorflow-lite.a,然后对要测试的conv_test.cc源码单独写一个cmake去调用新的libtensorflow-lite.a

由于unit test需要用到Googletest库,所以需要提前编译准备好Googletest,另外还需要用到absl库。

安装Googletest库

git clone https://github.com/google/googletest 
cd googletest
mkdir build
cd build
cmake -DCMAKE_INSTALL_PREFIX=/path/to/yourdir ..
make install 

absl地址:https://github.com/abseil/abseil-cpp 先下载放到制定位置,可以暂时不用编译。

修改Makefile,编译新的tflite库

准备工作做好以后就可以修改lite源码中的Makefile了,修改的地方主要是添加googletest到INCLUDES,以及添加对其他源码的编译。

添加新的INCLUDES

INCLUDES += -I/path/to/googletest/include

修改CORE_CC_EXCLUDE_SRCS变量

CORE_CC_EXCLUDE_SRCS := \
$(wildcard tensorflow/contrib/lite/*test.cc) \
$(wildcard tensorflow/contrib/lite/*/*test.cc) \
$(wildcard tensorflow/contrib/lite/*/*/*test.cc) \
$(wildcard tensorflow/contrib/lite/*/*/*/*test.cc) 
##$(wildcard tensorflow/contrib/lite/kernels/test_util.cc) \   把这行注释掉

另外还需要添加

CORE_CC_ALL_SRCS += \
$(wildcard tensorflow/core/platform/default/logging.cc) \
$(wildcard tensorflow/core/platform/env_time.cc)

到此,对Makefile的修改就完成了,运行./tensorflow/contrib/lite/tools/make/build_rpi_lib.sh编译可生成新的libtensorflow-lite.a库。

Tips: 注意上述编译的库在被conv_test.cc调用时会出错,原因是env_time.cc中的 EnvTime类部分函数还没有实现,这个可以自己把相关函数实现下,可以参考 tensorflow/core/platform/posix/env_time.cc的实现方式。

编译conv_test.cc

接下来编译conv_test.cc,编译CMAKE的时候遇到一些坑,下面是填坑后的完整CMakeLists.txt

cmake_minimum_required(VERSION 3.0)
add_definitions(-std=c++11)  #must use c++11
set(CMAKE_SYSTEM_PROCESSOR aarch64)
set(GCC_COMPILER_VERSION "" STRING "GCC Compiler version")

SET(CMAKE_C_COMPILER   aarch64-linux-gnu-gcc) 
SET(CMAKE_CXX_COMPILER aarch64-linux-gnu-g++) 
find_package(Threads)
SET(CMAKE_BUILD_TYPE "Release")
#set(CMAKE_EXE_LINKER_FLAGS "-lpthread -lrt -ldl") #special for tflite compile
INCLUDE_DIRECTORIES("/path/to/tflite_lib/include")
INCLUDE_DIRECTORIES("/path/to/googletest/include")
INCLUDE_DIRECTORIES("/path/to/absl")

LINK_DIRECTORIES("/path/to/tflite/lib")
LINK_DIRECTORIES("/path/to/googletest/lib")

add_executable(ConvUnitTest conv_test.cc)
target_link_libraries(ConvUnitTest libtensorflow-lite.a libgtest.a libgmock.a ${CMAKE_THREAD_LIBS_INIT} ${CMAKE_DL_LIBS})

如果需要对其他的op进行单元测试,则把对应的op_test.cc替换掉上面的conv_test.cc即可。

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
要在React Native项目加入TensorFlow Lite,需要执行以下步骤: 1. 安装TensorFlow Lite:您可以使用以下命令在React Native项目安装TensorFlow Lite: ``` npm install @tensorflow/tfjs @tensorflow/tfjs-react-native @tensorflow/[email protected] ``` 2. 将TensorFlow Lite模型添加到项目:将TensorFlow Lite模型文件(.tflite)复制到React Native项目的assets文件夹。 3. 在React Native应用程序加载TensorFlow Lite模型:您可以使用以下代码加载TensorFlow Lite模型: ```javascript import { load } from "@tensorflow/tfjs-react-native"; async function loadModel() { const modelJson = require("./assets/model.json"); const modelWeights = require("./assets/model_weights.bin"); const model = await load({ modelUrl: modelJson, weightsUrl: modelWeights, }); return model; } ``` 此代码将加载您的TensorFlow Lite模型文件(model.json和model_weights.bin)并返回一个TensorFlow模型对象。 4. 使用TensorFlow Lite模型进行推理:您可以使用以下代码将输入数据传递给TensorFlow Lite模型进行推理: ```javascript const inputTensor = tf.tensor2d([inputData]); // inputData是您的输入数据 const outputTensor = model.predict(inputTensor); const outputData = outputTensor.dataSync(); outputTensor.dispose(); ``` 此代码将创建一个输入张量对象,将其传递给TensorFlow Lite模型进行推理,并返回一个输出张量对象。然后,您可以使用outputTensor.dataSync()方法从输出张量对象提取结果。 以上是在React Native项目加入TensorFlow Lite的基本步骤,您可以根据自己的应用场景进行调整和优化。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值