[Android Machine Learning in Practice] 2. Object Detection with TensorFlow Lite

TensorFlow Lite is a cross-platform machine learning library optimized for running ML models on edge devices, including Android and iOS phones. It consists of the following two components:

  • Model Maker is a Python library that lets you train a TensorFlow Lite model on your own data with just a few lines of code, with no machine learning expertise required.
  • The Task Library is a cross-platform library that lets you deploy TensorFlow Lite models in your mobile app with just a few lines of code (see the sketch just below).
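
As a taste of how compact the Task Library API is, here is a minimal Kotlin detection sketch using the same ObjectDetector API that section 2 walks through (the model filename and option values are placeholders):

// Minimal sketch: load a detector from assets and run it on a bitmap.
// "model.tflite" and the option values below are placeholders.
val options = ObjectDetector.ObjectDetectorOptions.builder()
    .setMaxResults(5)
    .setScoreThreshold(0.5f)
    .build()
val detector = ObjectDetector.createFromFileAndOptions(context, "model.tflite", options)
val results = detector.detect(TensorImage.fromBitmap(bitmap))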

1. Adding the TFLite Library

Create the TFLiteDetectionTest project; the full source is available on the project's GitHub page.

Add the following layout to activity_main.xml:

<?xml version="1.0" encoding="utf-8"?>
<RelativeLayout xmlns:android="http://schemas.android.com/apk/res/android"
    xmlns:tools="http://schemas.android.com/tools"
    android:layout_width="match_parent"
    android:layout_height="match_parent"
    android:orientation="vertical"
    tools:context="org.tensorflow.codelabs.objectdetection.MainActivity">

    <FrameLayout
        android:layout_width="match_parent"
        android:layout_height="match_parent"
        android:layout_above="@+id/tvDescription">

        <TextView
            android:id="@+id/tvPlaceholder"
            android:layout_width="match_parent"
            android:layout_height="wrap_content"
            android:layout_gravity="center"
            android:text="@string/tv_placeholder"
            android:textAlignment="center"
            android:textSize="40sp" />

        <ImageView
            android:id="@+id/imageView"
            android:layout_width="match_parent"
            android:layout_height="match_parent"
            android:adjustViewBounds="true"
            android:contentDescription="@null"
            android:scaleType="fitCenter" />
    </FrameLayout>

    <TextView
        android:id="@+id/tvDescription"
        android:layout_width="match_parent"
        android:layout_height="wrap_content"
        android:layout_above="@id/llSampleImage"
        android:layout_marginStart="16dp"
        android:layout_marginTop="10dp"
        android:layout_marginBottom="10dp"
        android:text="@string/tv_description" />

    <LinearLayout
        android:id="@+id/llSampleImage"
        android:layout_width="match_parent"
        android:layout_height="wrap_content"
        android:layout_above="@+id/captureImageFab"
        android:gravity="center_horizontal"
        android:orientation="horizontal">

        <ImageView
            android:id="@+id/imgSampleOne"
            android:layout_width="80dp"
            android:layout_height="80dp"
            android:adjustViewBounds="true"
            android:contentDescription="@null"
            android:scaleType="centerCrop"
            android:src="@drawable/img_meal_one" />

        <ImageView
            android:id="@+id/imgSampleTwo"
            android:layout_width="80dp"
            android:layout_height="80dp"
            android:layout_marginStart="16dp"
            android:layout_marginEnd="16dp"
            android:adjustViewBounds="true"
            android:contentDescription="@null"
            android:scaleType="centerCrop"
            android:src="@drawable/img_meal_two" />

        <ImageView
            android:id="@+id/imgSampleThree"
            android:layout_width="80dp"
            android:layout_height="80dp"
            android:adjustViewBounds="true"
            android:contentDescription="@null"
            android:scaleType="centerCrop"
            android:src="@drawable/img_meal_three" />
    </LinearLayout>

    <Button
        android:id="@+id/captureImageFab"
        android:layout_width="wrap_content"
        android:layout_height="wrap_content"
        android:layout_alignParentBottom="true"
        android:layout_centerHorizontal="true"
        android:layout_gravity="center|bottom"
        android:layout_margin="16dp"
        android:drawableLeft="@drawable/ic_camera"
        android:text="@string/tv_take_photo"
        android:textAllCaps="false" />

</RelativeLayout>

The activity_main.xml layout consists of a large image view at the top, a row of three sample-image thumbnails below it, and a photo-capture button at the bottom:

[Figure: preview of the activity_main.xml layout]

Add the TensorFlow Lite dependency to build.gradle (app):

implementation 'org.tensorflow:tensorflow-lite-task-vision:0.3.1'
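
If you bundle the model in the assets folder (as we do below), you may also need to stop AAPT from compressing .tflite files so they can be memory-mapped at runtime; a minimal sketch for build.gradle (app):

android {
    aaptOptions {
        // Keep .tflite files uncompressed so they can be memory-mapped.
        noCompress "tflite"
    }
}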

The object-detection model file can be downloaded here; it was trained on the COCO 2017 dataset. Name the file model.tflite and place it in the assets directory:

[Figure: model.tflite in the assets directory]

Next, implement the basic image display and button click handling in MainActivity:

/**
 * Copyright 2021 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tensorflow.codelabs.objectdetection

import android.app.Activity
import android.content.ActivityNotFoundException
import android.content.Intent
import android.graphics.*
import android.net.Uri
import android.os.Bundle
import android.os.Environment
import android.provider.MediaStore
import android.util.Log
import android.view.View
import android.widget.Button
import android.widget.ImageView
import android.widget.TextView
import androidx.appcompat.app.AppCompatActivity
import androidx.core.content.FileProvider
import androidx.exifinterface.media.ExifInterface
import androidx.lifecycle.lifecycleScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.launch
import java.io.File
import java.io.IOException
import java.text.SimpleDateFormat
import java.util.*
import kotlin.math.max
import kotlin.math.min

class MainActivity : AppCompatActivity(), View.OnClickListener {
    companion object {
        const val TAG = "TFLite - ODT"
        const val REQUEST_IMAGE_CAPTURE: Int = 1
        private const val MAX_FONT_SIZE = 96F
    }

    private lateinit var captureImageFab: Button
    private lateinit var inputImageView: ImageView
    private lateinit var imgSampleOne: ImageView
    private lateinit var imgSampleTwo: ImageView
    private lateinit var imgSampleThree: ImageView
    private lateinit var tvPlaceholder: TextView
    private lateinit var currentPhotoPath: String

    override fun onCreate(savedInstanceState: Bundle?) {
        super.onCreate(savedInstanceState)
        setContentView(R.layout.activity_main)

        captureImageFab = findViewById(R.id.captureImageFab)
        inputImageView = findViewById(R.id.imageView)
        imgSampleOne = findViewById(R.id.imgSampleOne)
        imgSampleTwo = findViewById(R.id.imgSampleTwo)
        imgSampleThree = findViewById(R.id.imgSampleThree)
        tvPlaceholder = findViewById(R.id.tvPlaceholder)

        captureImageFab.setOnClickListener(this)
        imgSampleOne.setOnClickListener(this)
        imgSampleTwo.setOnClickListener(this)
        imgSampleThree.setOnClickListener(this)
    }

    override fun onActivityResult(requestCode: Int, resultCode: Int, data: Intent?) {
        super.onActivityResult(requestCode, resultCode, data)
        if (requestCode == REQUEST_IMAGE_CAPTURE &&
            resultCode == Activity.RESULT_OK
        ) {
            setViewAndDetect(getCapturedImage())
        }
    }

    /**
     * onClick(v: View?)
     *      Detect touches on the UI components
     */
    override fun onClick(v: View?) {
        when (v?.id) {
            R.id.captureImageFab -> {
                try {
                    dispatchTakePictureIntent()
                } catch (e: ActivityNotFoundException) {
                    Log.e(TAG, e.message.toString())
                }
            }
            R.id.imgSampleOne -> {
                setViewAndDetect(getSampleImage(R.drawable.img_meal_one))
            }
            R.id.imgSampleTwo -> {
                setViewAndDetect(getSampleImage(R.drawable.img_meal_two))
            }
            R.id.imgSampleThree -> {
                setViewAndDetect(getSampleImage(R.drawable.img_meal_three))
            }
        }
    }

    /**
     * runObjectDetection(bitmap: Bitmap)
     *      TFLite Object Detection function
     */
    private fun runObjectDetection(bitmap: Bitmap) {
        //TODO: Add object detection code here
    }

    /**
     * setViewAndDetect(bitmap: Bitmap)
     *      Set image to view and call object detection
     */
    private fun setViewAndDetect(bitmap: Bitmap) {
        // Display capture image
        inputImageView.setImageBitmap(bitmap)
        tvPlaceholder.visibility = View.INVISIBLE

        // Run ODT and display result
        // Note that we run this in the background thread to avoid blocking the app UI because
        // TFLite object detection is a synchronous operation.
        lifecycleScope.launch(Dispatchers.Default) { runObjectDetection(bitmap) }
    }

    /**
     * getCapturedImage():
     *      Decodes the captured photo, scales it toward the view size, and fixes its rotation.
     */
    private fun getCapturedImage(): Bitmap {
        // Get the dimensions of the View
        val targetW: Int = inputImageView.width
        val targetH: Int = inputImageView.height

        val bmOptions = BitmapFactory.Options().apply {
            // Get the dimensions of the bitmap
            inJustDecodeBounds = true

            BitmapFactory.decodeFile(currentPhotoPath, this)

            val photoW: Int = outWidth
            val photoH: Int = outHeight

            // Determine how much to scale down the image
            val scaleFactor: Int = max(1, min(photoW / targetW, photoH / targetH))

            // Decode the image file into a Bitmap sized to fill the View
            inJustDecodeBounds = false
            inSampleSize = scaleFactor
            inMutable = true
        }
        val exifInterface = ExifInterface(currentPhotoPath)
        val orientation = exifInterface.getAttributeInt(
            ExifInterface.TAG_ORIENTATION,
            ExifInterface.ORIENTATION_UNDEFINED
        )

        val bitmap = BitmapFactory.decodeFile(currentPhotoPath, bmOptions)
        return when (orientation) {
            ExifInterface.ORIENTATION_ROTATE_90 -> {
                rotateImage(bitmap, 90f)
            }
            ExifInterface.ORIENTATION_ROTATE_180 -> {
                rotateImage(bitmap, 180f)
            }
            ExifInterface.ORIENTATION_ROTATE_270 -> {
                rotateImage(bitmap, 270f)
            }
            else -> {
                bitmap
            }
        }
    }

    /**
     * getSampleImage():
     *      Get image from drawable and convert to bitmap.
     */
    private fun getSampleImage(drawable: Int): Bitmap {
        return BitmapFactory.decodeResource(resources, drawable, BitmapFactory.Options().apply {
            inMutable = true
        })
    }

    /**
     * rotateImage():
     *     Rotates the given bitmap by the given angle.
     */
    private fun rotateImage(source: Bitmap, angle: Float): Bitmap {
        val matrix = Matrix()
        matrix.postRotate(angle)
        return Bitmap.createBitmap(
            source, 0, 0, source.width, source.height,
            matrix, true
        )
    }

    /**
     * createImageFile():
     *     Generates a temporary image file for the Camera app to write to.
     */
    @Throws(IOException::class)
    private fun createImageFile(): File {
        // Create an image file name
        val timeStamp: String = SimpleDateFormat("yyyyMMdd_HHmmss").format(Date())
        val storageDir: File? = getExternalFilesDir(Environment.DIRECTORY_PICTURES)
        return File.createTempFile(
            "JPEG_${timeStamp}_", /* prefix */
            ".jpg", /* suffix */
            storageDir /* directory */
        ).apply {
            // Save a file: path for use with ACTION_VIEW intents
            currentPhotoPath = absolutePath
        }
    }

    /**
     * dispatchTakePictureIntent():
     *     Start the Camera app to take a photo.
     */
    private fun dispatchTakePictureIntent() {
        Intent(MediaStore.ACTION_IMAGE_CAPTURE).also { takePictureIntent ->
            // Ensure that there's a camera activity to handle the intent
            takePictureIntent.resolveActivity(packageManager)?.also {
                // Create the File where the photo should go
                val photoFile: File? = try {
                    createImageFile()
                } catch (e: IOException) {
                    Log.e(TAG, e.message.toString())
                    null
                }
                // Continue only if the File was successfully created
                photoFile?.also {
                    val photoURI: Uri = FileProvider.getUriForFile(
                        this,
                        "org.tensorflow.codelabs.objectdetection.fileprovider",
                        it
                    )
                    takePictureIntent.putExtra(MediaStore.EXTRA_OUTPUT, photoURI)
                    startActivityForResult(takePictureIntent, REQUEST_IMAGE_CAPTURE)
                }
            }
        }
    }

    /**
     * drawDetectionResult(bitmap: Bitmap, detectionResults: List<DetectionResult>)
     *      Draw a box around each detected object and show the object's name.
     */
    private fun drawDetectionResult(bitmap: Bitmap, detectionResults: List<DetectionResult>): Bitmap {
        val outputBitmap = bitmap.copy(Bitmap.Config.ARGB_8888, true)
        val canvas = Canvas(outputBitmap)
        val pen = Paint()
        pen.textAlign = Paint.Align.LEFT

        detectionResults.forEach {
            // draw bounding box
            pen.color = Color.RED
            pen.strokeWidth = 8F
            pen.style = Paint.Style.STROKE
            val box = it.boundingBox
            canvas.drawRect(box, pen)


            val tagSize = Rect(0, 0, 0, 0)

            // calculate the right font size
            pen.style = Paint.Style.FILL_AND_STROKE
            pen.color = Color.YELLOW
            pen.strokeWidth = 2F

            pen.textSize = MAX_FONT_SIZE
            pen.getTextBounds(it.text, 0, it.text.length, tagSize)
            val fontSize: Float = pen.textSize * box.width() / tagSize.width()

            // adjust the font size so texts are inside the bounding box
            if (fontSize < pen.textSize) pen.textSize = fontSize

            var margin = (box.width() - tagSize.width()) / 2.0F
            if (margin < 0F) margin = 0F
            canvas.drawText(
                it.text, box.left + margin,
                box.top + tagSize.height().times(1F), pen
            )
        }
        return outputBitmap
    }
}

/**
 * DetectionResult
 *      A class to store the visualization info of a detected object.
 */
data class DetectionResult(val boundingBox: RectF, val text: String)
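
One prerequisite for the camera flow above: the FileProvider authority passed to getUriForFile() must be declared in AndroidManifest.xml. A sketch of the matching entry, where @xml/file_paths is an assumed resource that exposes the external-files path used by createImageFile():

<!-- inside <application> -->
<provider
    android:name="androidx.core.content.FileProvider"
    android:authorities="org.tensorflow.codelabs.objectdetection.fileprovider"
    android:exported="false"
    android:grantUriPermissions="true">
    <!-- @xml/file_paths is an assumed resource declaring an <external-files-path> entry -->
    <meta-data
        android:name="android.support.FILE_PROVIDER_PATHS"
        android:resource="@xml/file_paths" />
</provider>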

After running, you can view the sample images, switch between them, and take a photo to replace the displayed image:

[Figure: the running app with sample images and the Take photo button]
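
As an aside, startActivityForResult()/onActivityResult() used above are deprecated in recent AndroidX releases. If you want to modernize the capture flow, a hypothetical equivalent with the Activity Result API (requires androidx.activity.result.contract.ActivityResultContracts) would look roughly like:

    // Hypothetical replacement for startActivityForResult/onActivityResult.
    // photoURI is the FileProvider Uri created in dispatchTakePictureIntent().
    private val takePicture =
        registerForActivityResult(ActivityResultContracts.TakePicture()) { saved ->
            if (saved) setViewAndDetect(getCapturedImage())
        }
    // ...then call takePicture.launch(photoURI) instead of startActivityForResult().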

2. Running Object Detection on Images with TFLite

In MainActivity, initialize the TFLite detector, feed it the image, and read back the list of detection results:

    /**
     * runObjectDetection(bitmap: Bitmap)
     *      TFLite Object Detection function
     */
    private fun runObjectDetection(bitmap: Bitmap) {
        val image = TensorImage.fromBitmap(bitmap)
        // Initialize the detector object
        val options = ObjectDetector.ObjectDetectorOptions.builder()
            .setMaxResults(5)
            .setScoreThreshold(0.5f)
            .build()
        val detector = ObjectDetector.createFromFileAndOptions(
            this, // the application context
            "model.tflite", // must be same as the filename in assets folder
            options
        )
        val results = detector.detect(image)
        debugPrint(results)
    }

    private fun debugPrint(results : List<Detection>) {
        for ((i, obj) in results.withIndex()) {
            val box = obj.boundingBox

            Log.d(TAG, "Detected object: $i ")
            Log.d(TAG, "  boundingBox: (${box.left}, ${box.top}) - (${box.right},${box.bottom})")

            for ((j, category) in obj.categories.withIndex()) {
                Log.d(TAG, "    Label $j: ${category.label}")
                val confidence: Int = category.score.times(100).toInt()
                Log.d(TAG, "    Confidence: ${confidence}%")
            }
        }
    }
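
These snippets use the Task Library classes, so MainActivity also needs the corresponding imports:

import org.tensorflow.lite.support.image.TensorImage
import org.tensorflow.lite.task.vision.detector.Detection
import org.tensorflow.lite.task.vision.detector.ObjectDetector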

Run the app and the detection results are printed to Logcat.


In this output, boundingBox holds the top-left and bottom-right corners of the detection box, Label is the detected category, and Confidence is the confidence score. For example, the log below says that an object was found inside the rectangle (0, 15) – (2223, 1645), its label is dining table, and the model is 77% confident that this first object is a dining table:

D/TFLite-ODT: Detected object: 0
D/TFLite-ODT:   boundingBox: (0.0, 15.0) - (2223.0,1645.0)
D/TFLite-ODT:     Label 0: dining table
D/TFLite-ODT:     Confidence: 77%

Next, render the detected bounding boxes and category labels onto the image in MainActivity:

    /**
     * runObjectDetection(bitmap: Bitmap)
     *      TFLite Object Detection function
     */
    private fun runObjectDetection(bitmap: Bitmap) {
        val image = TensorImage.fromBitmap(bitmap)
        // Initialize the detector object
        val options = ObjectDetector.ObjectDetectorOptions.builder()
            .setMaxResults(5)
            .setScoreThreshold(0.5f)
            .build()
        val detector = ObjectDetector.createFromFileAndOptions(
            this, // the application context
            "model.tflite", // must be same as the filename in assets folder
            options
        )
        val results = detector.detect(image)
        debugPrint(results)
        val resultToDisplay = results.map {
            // Get the top-1 category and craft the display text
            val category = it.categories.first()
            val text = "${category.label}, ${category.score.times(100).toInt()}%"

            // Create a data object to display the detection result
            DetectionResult(it.boundingBox, text)
        }
        // Draw the detection result on the bitmap and show it.
        val imgWithResult = drawDetectionResult(bitmap, resultToDisplay)
        runOnUiThread { inputImageView.setImageBitmap(imgWithResult) }
    }
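
One caveat: runObjectDetection() above builds a new ObjectDetector (and re-reads the model) on every call. For repeated detections it is cheaper to create the detector once and reuse it, for example with a lazy property (a minimal sketch; the property name is arbitrary):

    // Build the detector once on first use and reuse it afterwards.
    private val detector by lazy {
        ObjectDetector.createFromFileAndOptions(
            this,
            "model.tflite",
            ObjectDetector.ObjectDetectorOptions.builder()
                .setMaxResults(5)
                .setScoreThreshold(0.5f)
                .build()
        )
    }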

After running, a detection box and category label are drawn on every image:

[Figure: detection boxes and labels drawn on the sample images]

To improve the results, you can download a newer model from the model download page and place it in the assets directory; running with the new model gives finer-grained classification of the food items in the images:

[Figure: more detailed detections with the new model]
