Kotlin and Machine Learning in Practice: A Complete Guide to Integrating TensorFlow Lite on Android

This hands-on guide walks through integrating a TensorFlow Lite model into an Android app for on-device inference. Using image classification as the running example, it provides complete, ready-to-run code.


Environment Setup

1. Development Environment Requirements

  • Android Studio Arctic Fox or later
  • AGP 7.0+
  • Kotlin 1.6+
  • Minimum SDK 21

2. Add Gradle Dependencies

// build.gradle.kts
android {
    androidResources {
        noCompress += "tflite" // keep the model file uncompressed in the APK
    }
}

dependencies {
    // TFLite core, Support, and Task libraries (Task vision provides ImageClassifier)
    implementation("org.tensorflow:tensorflow-lite:2.12.0")
    implementation("org.tensorflow:tensorflow-lite-gpu:2.12.0") // GPU delegate
    implementation("org.tensorflow:tensorflow-lite-support:0.4.4")
    implementation("org.tensorflow:tensorflow-lite-task-vision:0.4.4")

    // CameraX libraries (optional, for the live-camera flow below)
    implementation("androidx.camera:camera-core:1.3.0")
    implementation("androidx.camera:camera-lifecycle:1.3.0")
    implementation("androidx.camera:camera-view:1.3.0")

    // Coroutine support
    implementation("org.jetbrains.kotlinx:kotlinx-coroutines-android:1.7.3")
}

Full Implementation Walkthrough

Step 1: Model File Handling

Place the trained .tflite model in the app/src/main/assets directory, ideally alongside a labels.txt label file:

app/src/main/assets/
├── mobilenet_v1_1.0_224_quant.tflite
└── labels.txt
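The labels.txt above is expected to hold one class name per line, in the model's output-index order. A minimal, pure-Kotlin parsing sketch (parseLabels is an illustrative helper name, not a TFLite API):

```kotlin
// Parse label text where each line maps to one output index;
// trims whitespace and drops blank lines
fun parseLabels(raw: String): List<String> =
    raw.lineSequence()
        .map { it.trim() }
        .filter { it.isNotEmpty() }
        .toList()
```
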

Step 2: Core Classifier Implementation

import android.content.Context
import android.graphics.Bitmap
import org.tensorflow.lite.support.common.ops.NormalizeOp
import org.tensorflow.lite.support.image.ImageProcessor
import org.tensorflow.lite.support.image.TensorImage
import org.tensorflow.lite.support.image.ops.ResizeOp
import org.tensorflow.lite.task.core.BaseOptions
import org.tensorflow.lite.task.vision.classifier.ImageClassifier

class TFLiteImageClassifier(
    context: Context,
    modelPath: String = "mobilenet_v1_1.0_224_quant.tflite",
    labelPath: String = "labels.txt",
    private val threadNum: Int = 4
) {
    private var classifier: ImageClassifier? = null
    private val labels: List<String>

    init {
        // Load the label file, one class name per line
        labels = context.assets.open(labelPath).bufferedReader().useLines { it.toList() }

        // Try GPU acceleration first; fall back to the CPU if the delegate fails
        classifier = try {
            ImageClassifier.createFromFileAndOptions(
                context, modelPath, buildOptions(enableGpu = true)
            )
        } catch (e: IllegalStateException) {
            ImageClassifier.createFromFileAndOptions(
                context, modelPath, buildOptions(enableGpu = false)
            )
        }
    }

    // Options are immutable once built, so the GPU and CPU variants
    // are constructed separately instead of mutating a built instance
    private fun buildOptions(enableGpu: Boolean): ImageClassifier.ImageClassifierOptions {
        val baseOptions = BaseOptions.builder()
            .setNumThreads(threadNum)
            .apply { if (enableGpu) useGpu() }
            .build()
        return ImageClassifier.ImageClassifierOptions.builder()
            .setBaseOptions(baseOptions)
            .setMaxResults(3)
            .build()
    }

    fun classify(bitmap: Bitmap): List<Pair<String, Float>> {
        val imageProcessor = ImageProcessor.Builder()
            .add(ResizeOp(224, 224, ResizeOp.ResizeMethod.BILINEAR))
            .add(NormalizeOp(127.5f, 127.5f)) // adjust to match the model (float vs. quantized)
            .build()

        val tensorImage = imageProcessor.process(
            TensorImage.fromBitmap(bitmap)
        )

        val results = classifier?.classify(tensorImage) ?: return emptyList()
        val top = results.firstOrNull() ?: return emptyList()

        return top.categories.map {
            val label = labels.getOrNull(it.index) ?: "Unknown"
            label to it.score
        }
    }

    fun close() {
        classifier?.close()
    }
}
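The NormalizeOp(127.5f, 127.5f) used in classify() maps each pixel channel from [0, 255] to roughly [-1, 1] via (p - mean) / std, which is what float MobileNet variants expect. The arithmetic in isolation (normalize is an illustrative helper, not part of the TFLite API):

```kotlin
// (p - 127.5) / 127.5 maps 0 -> -1.0, 127.5 -> 0.0, 255 -> 1.0
fun normalize(pixel: Int, mean: Float = 127.5f, std: Float = 127.5f): Float =
    (pixel - mean) / std
```

For quantized models that take raw UINT8 input, this normalization step is skipped (equivalently, mean 0 and std 1).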

Step 3: UI Implementation

activity_main.xml
<androidx.constraintlayout.widget.ConstraintLayout
    xmlns:android="http://schemas.android.com/apk/res/android"
    xmlns:app="http://schemas.android.com/apk/res-auto"
    android:layout_width="match_parent"
    android:layout_height="match_parent">

    <androidx.camera.view.PreviewView
        android:id="@+id/cameraPreview"
        android:layout_width="300dp"
        android:layout_height="300dp"
        app:layout_constraintTop_toTopOf="parent"
        app:layout_constraintStart_toStartOf="parent"/>

    <ImageView
        android:id="@+id/ivPreview"
        android:layout_width="300dp"
        android:layout_height="300dp"
        app:layout_constraintTop_toTopOf="parent"
        app:layout_constraintEnd_toEndOf="parent"/>

    <Button
        android:id="@+id/btnCapture"
        android:layout_width="wrap_content"
        android:layout_height="wrap_content"
        android:text="Capture &amp; Classify"
        app:layout_constraintBottom_toBottomOf="parent"
        app:layout_constraintStart_toStartOf="parent"/>

    <Button
        android:id="@+id/btnSelect"
        android:layout_width="wrap_content"
        android:layout_height="wrap_content"
        android:text="Pick from Gallery"
        app:layout_constraintBottom_toBottomOf="parent"
        app:layout_constraintEnd_toEndOf="parent"/>

    <TextView
        android:id="@+id/tvResult"
        android:layout_width="0dp"
        android:layout_height="wrap_content"
        android:padding="16dp"
        android:textSize="18sp"
        app:layout_constraintTop_toBottomOf="@id/cameraPreview"
        app:layout_constraintStart_toStartOf="parent"
        app:layout_constraintEnd_toEndOf="parent"/>

</androidx.constraintlayout.widget.ConstraintLayout>

Step 4: Main Activity (CameraX Integration)

// contentResolver.loadThumbnail (used below) requires API 29+
@RequiresApi(Build.VERSION_CODES.Q)
class MainActivity : AppCompatActivity() {
    private lateinit var classifier: TFLiteImageClassifier
    private lateinit var cameraExecutor: ExecutorService
    private var imageCapture: ImageCapture? = null

    // View references (view binding would work equally well)
    private val cameraPreview by lazy { findViewById<PreviewView>(R.id.cameraPreview) }
    private val ivPreview by lazy { findViewById<ImageView>(R.id.ivPreview) }
    private val tvResult by lazy { findViewById<TextView>(R.id.tvResult) }

    override fun onCreate(savedInstanceState: Bundle?) {
        super.onCreate(savedInstanceState)
        setContentView(R.layout.activity_main)
        cameraExecutor = Executors.newSingleThreadExecutor()
        classifier = TFLiteImageClassifier(this)

        // Request the camera permission
        if (allPermissionsGranted()) {
            startCamera()
        } else {
            ActivityCompat.requestPermissions(
                this, REQUIRED_PERMISSIONS, REQUEST_CODE_PERMISSIONS
            )
        }

        // Capture button
        findViewById<Button>(R.id.btnCapture).setOnClickListener {
            takePhoto()
        }

        // Gallery picker (handle the chosen Uri in onActivityResult, not shown here)
        findViewById<Button>(R.id.btnSelect).setOnClickListener {
            val intent = Intent(Intent.ACTION_GET_CONTENT).apply {
                type = "image/*"
            }
            startActivityForResult(intent, REQUEST_IMAGE_PICK)
        }
    }

    private fun takePhoto() {
        val imageCapture = imageCapture ?: return

        val outputOptions = ImageCapture.OutputFileOptions
            .Builder(File.createTempFile("ML_TEMP", ".jpg", cacheDir))
            .build()

        imageCapture.takePicture(
            outputOptions,
            ContextCompat.getMainExecutor(this),
            object : ImageCapture.OnImageSavedCallback {
                override fun onImageSaved(output: ImageCapture.OutputFileResults) {
                    val uri = output.savedUri ?: return
                    processImage(uri)
                }

                override fun onError(exc: ImageCaptureException) {
                    Log.e(TAG, "Capture failed: ${exc.message}", exc)
                }
            }
        )
    }

    private fun processImage(uri: Uri) {
        lifecycleScope.launch(Dispatchers.IO) {
            try {
                val bitmap = contentResolver.loadThumbnail(
                    uri, Size(224, 224), null
                )
                
                val results = classifier.classify(bitmap)
                
                withContext(Dispatchers.Main) {
                    ivPreview.setImageBitmap(bitmap)
                    showResults(results)
                }
            } catch (e: Exception) {
                Log.e(TAG, "Image processing failed", e)
            }
        }
    }

    private fun showResults(results: List<Pair<String, Float>>) {
        val output = buildString {
            append("Results:\n")
            results.forEach { (label, confidence) ->
                append("${label}: ${"%.2f".format(confidence * 100)}%\n")
            }
        }
        tvResult.text = output
    }

    // CameraX initialization
    private fun startCamera() {
        val cameraProviderFuture = ProcessCameraProvider.getInstance(this)

        cameraProviderFuture.addListener({
            val cameraProvider = cameraProviderFuture.get()
            val preview = Preview.Builder()
                .build()
                .also { it.setSurfaceProvider(cameraPreview.surfaceProvider) }

            imageCapture = ImageCapture.Builder()
                .setCaptureMode(ImageCapture.CAPTURE_MODE_MINIMIZE_LATENCY)
                .build()

            try {
                cameraProvider.unbindAll()
                cameraProvider.bindToLifecycle(
                    this, CameraSelector.DEFAULT_BACK_CAMERA, preview, imageCapture)
            } catch (exc: Exception) {
                Log.e(TAG, "Camera initialization failed", exc)
            }
        }, ContextCompat.getMainExecutor(this))
    }

    // Permission handling
    override fun onRequestPermissionsResult(
        requestCode: Int, 
        permissions: Array<String>, 
        grantResults: IntArray
    ) {
        super.onRequestPermissionsResult(requestCode, permissions, grantResults)
        if (requestCode == REQUEST_CODE_PERMISSIONS) {
            if (allPermissionsGranted()) {
                startCamera()
            } else {
                Toast.makeText(this, "Camera permission is required", Toast.LENGTH_SHORT).show()
                finish()
            }
        }
    }

    // An instance method, not a companion member: it needs a Context receiver
    private fun allPermissionsGranted() = REQUIRED_PERMISSIONS.all {
        ContextCompat.checkSelfPermission(this, it) == PackageManager.PERMISSION_GRANTED
    }

    companion object {
        private const val TAG = "MLDemo"
        private const val REQUEST_CODE_PERMISSIONS = 10
        private const val REQUEST_IMAGE_PICK = 101
        private val REQUIRED_PERMISSIONS = arrayOf(Manifest.permission.CAMERA)
    }
}

Advanced Optimization Techniques

1. Performance Monitoring

class BenchmarkHelper(private val classifier: TFLiteImageClassifier) {
    fun measureInference(bitmap: Bitmap) {
        val warmupRuns = 10
        val benchmarkRuns = 100

        // Warm-up: the first runs include delegate initialization and cache effects
        repeat(warmupRuns) {
            classifier.classify(bitmap)
        }

        // Timed runs
        val start = SystemClock.elapsedRealtime()
        repeat(benchmarkRuns) {
            classifier.classify(bitmap)
        }
        val avgTime = (SystemClock.elapsedRealtime() - start) / benchmarkRuns.toFloat()

        Log.d("Benchmark", "Average inference time: ${avgTime}ms")
    }
}
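The warm-up-then-average pattern above generalizes to any workload; a pure-JVM sketch using System.nanoTime in place of SystemClock (benchmark is a hypothetical helper name):

```kotlin
// Runs `block` `warmup` times unrecorded, then returns the average
// wall-clock milliseconds over `runs` timed iterations
fun benchmark(warmup: Int = 10, runs: Int = 100, block: () -> Unit): Double {
    repeat(warmup) { block() }
    val start = System.nanoTime()
    repeat(runs) { block() }
    return (System.nanoTime() - start) / 1_000_000.0 / runs
}
```
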

2. Dynamic Model Updates

// ModelService is assumed to expose a downloadModel(url) endpoint returning
// Call<ResponseBody>; updateModel() is assumed to recreate the classifier
// from the new file (neither is defined earlier in this article)
private fun downloadAndUpdateModel(modelUrl: String) {
    lifecycleScope.launch(Dispatchers.IO) {
        try {
            val tempFile = File.createTempFile("model", ".tflite")

            // A synchronous execute() is fine here: we are already on Dispatchers.IO
            val response = Retrofit.Builder()
                .baseUrl("https://your-model-server/")
                .build()
                .create(ModelService::class.java)
                .downloadModel(modelUrl)
                .execute()

            response.body()?.byteStream()?.use { input ->
                tempFile.outputStream().use { output ->
                    input.copyTo(output)
                }
            }
            classifier.updateModel(tempFile)
        } catch (e: Exception) {
            Log.e("ModelUpdate", "Update failed", e)
        }
    }
}

Troubleshooting Common Issues

Problem 1: Input Size Mismatch

Solution

// Note: getInputTensor(0) belongs to the Interpreter API
// (org.tensorflow.lite.Interpreter), not the Task-library ImageClassifier above
val inputTensor = interpreter.getInputTensor(0)
val inputShape = inputTensor.shape()   // actual input shape, e.g. [1, 224, 224, 3]
val dataType = inputTensor.dataType()

// Adapt preprocessing to the model's actual input
val resizeOp = when (dataType) {
    DataType.UINT8 -> ResizeWithCropOrPadOp(inputShape[1], inputShape[2])
    DataType.FLOAT32 -> ResizeOp(inputShape[1], inputShape[2], ResizeOp.ResizeMethod.BILINEAR)
    else -> throw IllegalArgumentException("Unsupported input type")
}
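Once the shape and data type are known, the input buffer you feed the model can be sanity-checked: the expected byte count is the product of the tensor dimensions times the element size (1 byte for UINT8, 4 for FLOAT32). A small illustrative helper (inputByteSize is not a TFLite API):

```kotlin
// Product of all tensor dimensions times element size in bytes;
// for [1, 224, 224, 3] UINT8 this gives 1*224*224*3*1 = 150528
fun inputByteSize(shape: IntArray, bytesPerElement: Int): Int =
    shape.fold(1) { acc, dim -> acc * dim } * bytesPerElement
```
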

Problem 2: Memory Leaks

Prevention

override fun onDestroy() {
    super.onDestroy()
    classifier.close()
    cameraExecutor.shutdown()
}

Further Applications

Real-Time Video Stream Processing

class VideoAnalyzer(
    private val classifier: TFLiteImageClassifier,
    private val scope: CoroutineScope, // pass the Activity's lifecycleScope in
    private val onResult: (List<Pair<String, Float>>) -> Unit
) : ImageAnalysis.Analyzer {
    private val frameCounter = AtomicInteger(0)
    private val skipFrame = 3 // process only every 3rd frame

    override fun analyze(imageProxy: ImageProxy) {
        if (frameCounter.getAndIncrement() % skipFrame != 0) {
            imageProxy.close()
            return
        }

        // Convert ImageProxy to Bitmap (toBitmap() ships with recent CameraX versions)
        val bitmap = imageProxy.toBitmap()
        imageProxy.close() // release the frame before the async work so the next one can arrive
        scope.launch(Dispatchers.Default) {
            onResult(classifier.classify(bitmap))
        }
    }
}
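The frame-skipping counter can be factored into a small, independently testable helper (FrameThrottle is a hypothetical name):

```kotlin
import java.util.concurrent.atomic.AtomicInteger

// Lets every `skip`-th call through; safe to call from the analyzer thread
class FrameThrottle(private val skip: Int) {
    private val counter = AtomicInteger(0)
    fun shouldProcess(): Boolean = counter.getAndIncrement() % skip == 0
}
```

The analyzer would then call shouldProcess() instead of touching the counter directly, which keeps the sampling policy swappable.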

Best-Practice Recommendations

  1. Model Optimization

    • Prefer quantized models (such as the _quant MobileNet used above) to reduce model size and inference latency

  2. Performance Balance

    • Choose the inference backend (CPU/GPU/NNAPI) dynamically based on device capability
    • On low-end devices, lean on XNNPACK for CPU acceleration; recent TFLite releases enable it by default, and the Interpreter API can toggle it explicitly:
    Interpreter.Options().apply {
        setUseXNNPACK(true)
    }
  3. Security Hardening

    // Verify model integrity against a precomputed SHA-256 hash
    // (readBytes() loads the whole file; fine for typical model sizes)
    fun verifyModel(file: File): Boolean {
        val expectedHash = "a1b2c3d4..." // precomputed SHA-256
        val actual = MessageDigest.getInstance("SHA-256")
            .digest(file.readBytes())
            .joinToString("") { "%02x".format(it) }
        return actual == expectedHash
    }
    

Pick a model that matches your business requirements, and keep refining the inference pipeline with profiling tools.
