3. Android TensorFlow Lite: classify Baidu Netdisk images in 3 minutes (AI)

As we all know, Google has an open-source library called TensorFlow that can be used on Android to implement machine learning. In other words, TensorFlow is Google's open-source software library for machine intelligence.
 
 
TensorFlow covers two jobs:
1. Model computation and training
2. Inference
 
 
The main question:
When we already have a trained TF model, how do we call it and get it running on the Android platform?
Roughly, it takes three steps:
1. Save the trained TF model.
2. Import the TF model into the Android project, along with the jar and .so files the Android platform needs to call it (they handle parsing and running the TF model).
3. Define variables, store the input data, and invoke the model through the interfaces the jar provides, as sketched below.
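
A minimal sketch of step 3, assuming the org.tensorflow:tensorflow-android artifact is on the classpath and that the frozen graph uses the hypothetical tensor names "input" and "output" (substitute the real names from your own model):

import android.content.res.AssetManager;
import org.tensorflow.contrib.android.TensorFlowInferenceInterface;

public class TfModelRunner {
  // Hypothetical file and tensor names; replace with those of your frozen graph.
  private static final String MODEL_FILE = "file:///android_asset/frozen_model.pb";
  private static final String INPUT_NODE = "input";
  private static final String OUTPUT_NODE = "output";

  private final TensorFlowInferenceInterface inference;

  public TfModelRunner(AssetManager assets) {
    // The jar parses the protobuf graph; the bundled .so runs the math.
    inference = new TensorFlowInferenceInterface(assets, MODEL_FILE);
  }

  /** Feeds one 224x224 RGB float image and returns the raw output scores. */
  public float[] run(float[] pixels, int numClasses) {
    float[] outputs = new float[numClasses];
    inference.feed(INPUT_NODE, pixels, 1, 224, 224, 3);
    inference.run(new String[] {OUTPUT_NODE});
    inference.fetch(OUTPUT_NODE, outputs);
    return outputs;
  }
}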
 
 
 

Using a custom image classifier

You need a pre-trained model file and a label file for classification.

mobilenet_v1_1.0_224.ckpt.index    mobilenet_v1_1.0_224_frozen.pb
mobilenet_v1_1.0_224.ckpt.meta     mobilenet_v1_1.0_224_info.txt
 
Source code:
 
Summary:
1. Build the .so library yourself, or work directly from source.
2. Put the model and label files in place.
3. Call the native library through JNI.
 

Official tutorial:

 

1. Building and running the example project (with Bazel)

In the downloaded tensorflow directory, under tensorflow/tensorflow/examples/android, you will find Google's example project for Android. Building and running it is a good way to see examples of several typical use cases implemented on the Android platform.

Building with Android Studio

Open the corresponding project directory in Android Studio, adjust the NDK settings, then click build and run.

Building from the command line

From the command line, enter the tensorflow directory, i.e. the project root where the WORKSPACE file lives.

1. Build the project

bazel build -c opt //tensorflow/examples/android:tensorflow_demo

 
2. Installing Bazel

Bazel is Google's open-source build tool, reportedly several times faster than Maven; fast incremental builds are its signature feature, and it currently supports Java, C++, Go, and other languages. Building TensorFlow depends on it. For more about Bazel, see the Bazel website.

  1. Install Bazel via Homebrew:
     brew install bazel
  2. After installation, verify that it runs:
     bazel version
  3. To check for version updates:
     brew upgrade bazel

 

Installing the build tool Bazel

Official documentation: https://bazel.build/versions/master/docs/install.html

Install JDK 8, add the APT repository, and then install bazel.

 

2. Building and running the example project (with Android Studio)

You can open tensorflow/examples/android directly in Android Studio, but your Gradle, SDK, and NDK must be configured first:

  1. Gradle must be at version 3.3 or above.
  2. buildToolsVersion must be 25 or above.
  3. Configure your NDK.
  4. Important: set your bazel path in the example's android/build.gradle. The default is:
def bazelLocation = '/usr/local/bin/bazel'

If bazel is not actually at that path, you will get an error like:

Error:Execution failed for task ':buildNativeBazel'.
A problem occurred starting process 'command '/usr/local/bin/bazel''

In that case, just change the path to where your bazel lives; on my machine, for example:

def bazelLocation = '/usr/bin/bazel'

Once this is configured, you can run the project directly from Android Studio.

 

 

Note: set the NDK version to 14b; building with NDK 16 leads to unexplained problems!

 

Install Bazel and Android Prerequisites

Bazel is the primary build system for TensorFlow. To build with Bazel, it and the Android NDK and SDK must be installed on your system.

  1. Install the latest version of Bazel as per the instructions on the Bazel website.
  2. The Android NDK is required to build the native (C/C++) TensorFlow code. The current recommended version is 14b, which may be found here.
  3. The Android SDK and build tools may be obtained here, or alternatively as part of Android Studio. Build tools API >= 23 is required to build the TF Android demo (though it will run on API >= 21 devices).
 
 
The CMake build script (CMakeLists.txt) for the demo's native library:
#
# Copyright (C) 2016 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

project(TENSORFLOW_DEMO)
cmake_minimum_required(VERSION 3.4.1)

set(CMAKE_VERBOSE_MAKEFILE on)

get_filename_component(TF_SRC_ROOT ${CMAKE_SOURCE_DIR}/../../../..  ABSOLUTE)
get_filename_component(SAMPLE_SRC_DIR  ${CMAKE_SOURCE_DIR}/..  ABSOLUTE)

if (ANDROID_ABI MATCHES "^armeabi-v7a$")
    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mfloat-abi=softfp -mfpu=neon")
elseif(ANDROID_ABI MATCHES "^arm64-v8a")
    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O2 -ftree-vectorize")
endif()

set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DSTANDALONE_DEMO_LIB \
                    -std=c++11 -fno-exceptions -fno-rtti -O2 -Wno-narrowing \
                    -fPIE")
set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} \
                              -Wl,--allow-multiple-definition \
                              -Wl,--whole-archive -fPIE -v")

file(GLOB_RECURSE tensorflow_demo_sources ${SAMPLE_SRC_DIR}/jni/*.*)
add_library(tensorflow_demo SHARED
            ${tensorflow_demo_sources})
target_include_directories(tensorflow_demo PRIVATE
                           ${TF_SRC_ROOT}
                           ${CMAKE_SOURCE_DIR})

target_link_libraries(tensorflow_demo
                      android
                      log
                      jnigraphics
                      m
                      atomic
                      z)
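
The add_library() target above produces libtensorflow_demo.so. A minimal sketch, assuming standard jniLibs packaging, of how the Java side would pick it up (the real demo wraps this in its own helper classes):

public class NativeLibLoader {
  static {
    try {
      // The name must match the add_library() target: libtensorflow_demo.so.
      System.loadLibrary("tensorflow_demo");
    } catch (UnsatisfiedLinkError e) {
      // The .so for this device's ABI was not packaged into the APK.
      throw new RuntimeException("Failed to load native TensorFlow demo library", e);
    }
  }
}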
 
Dependencies when using the prebuilt libraries directly:
dependencies {
    implementation fileTree(dir: 'libs', include: ['*.jar'])
    implementation 'androidx.appcompat:appcompat:1.1.0'
    implementation 'androidx.coordinatorlayout:coordinatorlayout:1.1.0'
    implementation 'com.google.android.material:material:1.0.0'

    implementation 'org.tensorflow:tensorflow-lite:2.0.0'
    implementation 'org.tensorflow:tensorflow-lite-gpu:2.0.0'
    implementation 'org.tensorflow:tensorflow-lite-support:0.0.0-nightly'
}
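
Since tensorflow-lite-gpu is on the list, here is a minimal sketch of enabling the GPU delegate; the model buffer is assumed to come from code like the loadModelFile() shown further down:

import java.nio.MappedByteBuffer;
import org.tensorflow.lite.Interpreter;
import org.tensorflow.lite.gpu.GpuDelegate;

public class GpuInterpreterFactory {
  /** Builds an Interpreter that runs on the GPU where the device supports it. */
  public static Interpreter create(MappedByteBuffer modelBuffer) {
    GpuDelegate delegate = new GpuDelegate();
    Interpreter.Options options = new Interpreter.Options().addDelegate(delegate);
    // Note: close the delegate (and the Interpreter) when inference is done.
    return new Interpreter(modelBuffer, options);
  }
}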
 
 

Step 1: add the dependency compile 'org.tensorflow:tensorflow-android:+' to build.gradle.

Step 2: call the TensorFlow interface and use it (see the TensorFlowInferenceInterface sketch near the top of this article).

 
The demo I built myself:
 
The .so libraries:
 
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

package com.example.android.tflitecamerademo;

import android.app.Activity;
import android.content.res.AssetFileDescriptor;
import android.graphics.Bitmap;
import android.os.SystemClock;
import android.util.Log;
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import org.tensorflow.lite.Interpreter;

/** Classifies images with Tensorflow Lite. */
public class ImageClassifier {

  /** Tag for the {@link Log}. */
  private static final String TAG = "TfLiteCameraDemo";

  /** Name of the model file stored in Assets. */
  private static final String MODEL_PATH = "mobilenet_quant_v1_224.tflite";

  /** Name of the label file stored in Assets. */
  private static final String LABEL_PATH = "labels.txt";

  /** Number of results to show in the UI. */
  private static final int RESULTS_TO_SHOW = 3;

  /** Dimensions of inputs. */
  private static final int DIM_BATCH_SIZE = 1;

  private static final int DIM_PIXEL_SIZE = 3;

  static final int DIM_IMG_SIZE_X = 224;
  static final int DIM_IMG_SIZE_Y = 224;

  /* Preallocated buffers for storing image data in. */
  private int[] intValues = new int[DIM_IMG_SIZE_X * DIM_IMG_SIZE_Y];

  /** An instance of the driver class to run model inference with Tensorflow Lite. */
  private Interpreter tflite;

  /** Labels corresponding to the output of the vision model. */
  private List<String> labelList;

  /** A ByteBuffer to hold image data, to be fed into Tensorflow Lite as inputs. */
  private ByteBuffer imgData = null;

  /** An array to hold inference results, to be fed into Tensorflow Lite as outputs. */
  private byte[][] labelProbArray = null;

  private PriorityQueue<Map.Entry<String, Float>> sortedLabels =
      new PriorityQueue<>(
          RESULTS_TO_SHOW,
          new Comparator<Map.Entry<String, Float>>() {
            @Override
            public int compare(Map.Entry<String, Float> o1, Map.Entry<String, Float> o2) {
              return (o1.getValue()).compareTo(o2.getValue());
            }
          });

  /** Initializes an {@code ImageClassifier}. */
  ImageClassifier(Activity activity) throws IOException {
    tflite = new Interpreter(loadModelFile(activity)); // Load the model
    labelList = loadLabelList(activity); // Load the class labels
    imgData =
        ByteBuffer.allocateDirect(
            DIM_BATCH_SIZE * DIM_IMG_SIZE_X * DIM_IMG_SIZE_Y * DIM_PIXEL_SIZE); // 1 * 224 * 224 * 3
    imgData.order(ByteOrder.nativeOrder());
    labelProbArray = new byte[1][labelList.size()];
    Log.d(TAG, "Created a Tensorflow Lite Image Classifier.");
  }

  /** Classifies a frame from the preview stream. */
  String classifyFrame(Bitmap bitmap) {
    if (tflite == null) {
      Log.e(TAG, "Image classifier has not been initialized; Skipped.");
      return "Uninitialized Classifier.";
    }
    convertBitmapToByteBuffer(bitmap);
    // Here's where the magic happens!!!
    long startTime = SystemClock.uptimeMillis();
    tflite.run(imgData, labelProbArray);
    long endTime = SystemClock.uptimeMillis();
    Log.d(TAG, "Timecost to run model inference: " + Long.toString(endTime - startTime));
    String textToShow = printTopKLabels();
    textToShow = Long.toString(endTime - startTime) + "ms" + textToShow;
    return textToShow;
  }

  /** Closes tflite to release resources. */
  public void close() {
    tflite.close();
    tflite = null;
  }

  /** Reads label list from Assets. */
  private List<String> loadLabelList(Activity activity) throws IOException {
    List<String> labelList = new ArrayList<String>();
    BufferedReader reader =
        new BufferedReader(new InputStreamReader(activity.getAssets().open(LABEL_PATH)));
    String line;
    while ((line = reader.readLine()) != null) {
      labelList.add(line);
    }
    reader.close();
    return labelList;
  }

  /** Memory-map the model file in Assets. */
  private MappedByteBuffer loadModelFile(Activity activity) throws IOException {
    AssetFileDescriptor fileDescriptor = activity.getAssets().openFd(MODEL_PATH);
    FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
    FileChannel fileChannel = inputStream.getChannel();
    long startOffset = fileDescriptor.getStartOffset();
    long declaredLength = fileDescriptor.getDeclaredLength();
    return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
  }

  /** Writes Image data into a {@code ByteBuffer}. */
  private void convertBitmapToByteBuffer(Bitmap bitmap) {
    if (imgData == null) {
      return;
    }
    imgData.rewind();
    bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
    // Copy each pixel's RGB bytes; this quantized model takes raw uint8 values.
    int pixel = 0;
    long startTime = SystemClock.uptimeMillis();
    for (int i = 0; i < DIM_IMG_SIZE_X; ++i) {
      for (int j = 0; j < DIM_IMG_SIZE_Y; ++j) {
        final int val = intValues[pixel++];
        imgData.put((byte) ((val >> 16) & 0xFF));
        imgData.put((byte) ((val >> 8) & 0xFF));
        imgData.put((byte) (val & 0xFF));
      }
    }
    long endTime = SystemClock.uptimeMillis();
    Log.d(TAG, "Timecost to put values into ByteBuffer: " + Long.toString(endTime - startTime));
  }

  /** Prints top-K labels, to be shown in UI as the results. */
  private String printTopKLabels() {
    for (int i = 0; i < labelList.size(); ++i) {
      sortedLabels.add(
          new AbstractMap.SimpleEntry<>(labelList.get(i), (labelProbArray[0][i] & 0xff) / 255.0f));
      if (sortedLabels.size() > RESULTS_TO_SHOW) {
        sortedLabels.poll();
      }
    }
    String textToShow = "";
    final int size = sortedLabels.size();
    for (int i = 0; i < size; ++i) {
      Map.Entry<String, Float> label = sortedLabels.poll();
      textToShow = "\n" + label.getKey() + ":" + Float.toString(label.getValue()) + textToShow;
    }
    return textToShow;
  }
}
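
A minimal usage sketch from the hosting Activity (hypothetical calling code, not part of the demo; srcBitmap stands for whatever image you are classifying):

// Inside an Activity in the same package as ImageClassifier:
try {
  ImageClassifier classifier = new ImageClassifier(this);
  // convertBitmapToByteBuffer() reads exactly 224x224 pixels,
  // so scale the source bitmap first.
  Bitmap scaled = Bitmap.createScaledBitmap(
      srcBitmap, ImageClassifier.DIM_IMG_SIZE_X, ImageClassifier.DIM_IMG_SIZE_Y, false);
  String result = classifier.classifyFrame(scaled);
  Log.d("Demo", result);
  classifier.close(); // Release the TFLite interpreter when done.
} catch (IOException e) {
  Log.e("Demo", "Failed to initialize TFLite classifier", e);
}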
 