本文将手把手教你如何在Android应用中集成TensorFlow Lite模型,实现端侧机器学习推理能力。我们以图像分类场景为例,提供可直接运行的完整代码示例。
环境准备
1. 开发环境要求
- Android Studio Arctic Fox以上版本
- AGP 7.0+
- Kotlin 1.6+
- Minimum SDK 21
2. 添加Gradle依赖
// build.gradle.kts (app module)
android {
    // Keep .tflite assets uncompressed so the runtime can memory-map them.
    // Kotlin DSL syntax: the Groovy form `noCompress "tflite"` does not compile in a .kts file.
    // AGP 4.1+ already adds "tflite" to this list by default; kept here as an explicit safety net.
    androidResources {
        noCompress += "tflite"
    }
}
dependencies {
    // TFLite runtime
    implementation("org.tensorflow:tensorflow-lite:2.12.0")
    implementation("org.tensorflow:tensorflow-lite-gpu:2.12.0") // GPU delegate
    implementation("org.tensorflow:tensorflow-lite-support:0.4.4")
    // Task Library: provides org.tensorflow.lite.task.vision.classifier.ImageClassifier.
    // The delegate plugin is required for BaseOptions.useGpu().
    implementation("org.tensorflow:tensorflow-lite-task-vision:0.4.4")
    implementation("org.tensorflow:tensorflow-lite-gpu-delegate-plugin:0.4.4")
    // CameraX (optional). camera-camera2 supplies the actual camera implementation;
    // without it ProcessCameraProvider cannot be initialized.
    implementation("androidx.camera:camera-camera2:1.3.0")
    implementation("androidx.camera:camera-core:1.3.0")
    implementation("androidx.camera:camera-lifecycle:1.3.0")
    implementation("androidx.camera:camera-view:1.3.0")
    // Coroutines, plus lifecycleScope used by the Activity code
    implementation("org.jetbrains.kotlinx:kotlinx-coroutines-android:1.7.3")
    implementation("androidx.lifecycle:lifecycle-runtime-ktx:2.6.2")
}
完整实现流程
步骤1:模型文件处理
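将训练好的 `.tflite` 模型文件放入 `app/src/main/assets` 目录,建议同时附带 `labels.txt` 标签文件。注意:Task Library 的 `ImageClassifier` 会读取模型内嵌的 TFLite Metadata 来获取输入归一化参数(浮点模型必需)和标签,建议优先使用带元数据的模型。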
app/src/main/assets/
├── mobilenet_v1_1.0_224_quant.tflite
└── labels.txt
步骤2:核心分类器实现
import android.content.Context
import android.graphics.Bitmap
import org.tensorflow.lite.support.image.TensorImage
import org.tensorflow.lite.task.core.BaseOptions
import org.tensorflow.lite.task.vision.classifier.ImageClassifier
import java.io.IOException
/**
 * Wrapper around the TFLite Task Library [ImageClassifier] that loads a model (and an optional
 * label file) from `assets/` and returns the top categories for a bitmap.
 *
 * Creation tries the GPU delegate first and falls back to CPU if the delegate cannot be
 * created. The Task Library performs resizing/normalization itself (driven by the model's
 * metadata), so [classify] passes the bitmap through without manual preprocessing.
 *
 * [classify] and [close] are synchronized so a close() from the UI thread cannot free the native
 * classifier while a background classify() is still using it.
 *
 * @param context used to open assets.
 * @param modelPath asset path of the `.tflite` model.
 * @param labelPath asset path of a newline-separated label file. Optional: if it cannot be
 *   opened, labels come from the model metadata instead.
 * @param threadNum number of CPU threads used for inference.
 * @param maxResults maximum number of categories returned by [classify].
 */
class TFLiteImageClassifier(
    context: Context,
    modelPath: String = "mobilenet_v1_1.0_224_quant.tflite",
    labelPath: String = "labels.txt",
    private val threadNum: Int = 4,
    maxResults: Int = 3,
) {
    // Loaded before the classifier, mirroring the original order.
    private val labels: List<String> = loadLabels(context, labelPath)

    // Nullable so close() can drop it; classify() after close() then returns an empty list.
    // NOTE(review): some GPU delegate backends require inference on the thread that created the
    // delegate, while callers invoke classify() from background dispatchers — confirm on devices.
    private var classifier: ImageClassifier? = try {
        createClassifier(context, modelPath, maxResults, gpu = true)
    } catch (e: RuntimeException) {
        // GPU delegate unsupported on this device/driver: fall back to CPU.
        // A missing model (IOException) is not caught here and propagates to the caller.
        createClassifier(context, modelPath, maxResults, gpu = false)
    }

    /**
     * Classifies [bitmap] and returns up to `maxResults` `(label, score)` pairs.
     *
     * The label is taken from the label file by category index, then from the model metadata,
     * then "Unknown". Returns an empty list if the classifier was closed or produced no output.
     */
    @Synchronized
    fun classify(bitmap: Bitmap): List<Pair<String, Float>> {
        val activeClassifier = classifier ?: return emptyList()
        // TensorImage.fromBitmap only accepts ARGB_8888; convert other configs (RGB_565, HARDWARE...).
        val input = if (bitmap.config == Bitmap.Config.ARGB_8888) {
            bitmap
        } else {
            bitmap.copy(Bitmap.Config.ARGB_8888, false)
        }
        val classifications = activeClassifier.classify(TensorImage.fromBitmap(input))
        val categories = classifications.firstOrNull()?.categories ?: return emptyList()
        return categories.map { category ->
            val label = labels.getOrNull(category.index)
                ?: category.label.takeUnless { it.isNullOrBlank() }
                ?: "Unknown"
            label to category.score
        }
    }

    /** Releases the native classifier. Safe to call more than once. */
    @Synchronized
    fun close() {
        classifier?.close()
        classifier = null
    }

    /** Builds an [ImageClassifier] using [threadNum] CPU threads and, if [gpu], the GPU delegate. */
    private fun createClassifier(
        context: Context,
        modelPath: String,
        maxResults: Int,
        gpu: Boolean,
    ): ImageClassifier {
        val baseOptionsBuilder = BaseOptions.builder().setNumThreads(threadNum)
        if (gpu) {
            baseOptionsBuilder.useGpu()
        }
        val options = ImageClassifier.ImageClassifierOptions.builder()
            .setBaseOptions(baseOptionsBuilder.build())
            .setMaxResults(maxResults)
            .build()
        return ImageClassifier.createFromFileAndOptions(context, modelPath, options)
    }

    private companion object {
        /**
         * Reads every line of the asset [path] (blank lines kept so indices stay aligned with
         * class ids), or returns an empty list when the asset cannot be opened.
         */
        fun loadLabels(context: Context, path: String): List<String> = try {
            context.assets.open(path).bufferedReader().useLines { it.toList() }
        } catch (e: IOException) {
            emptyList()
        }
    }
}
步骤3:UI界面实现
activity_main.xml
<androidx.constraintlayout.widget.ConstraintLayout
    xmlns:android="http://schemas.android.com/apk/res/android"
    xmlns:app="http://schemas.android.com/apk/res-auto"
    android:layout_width="match_parent"
    android:layout_height="match_parent">
    <!-- Camera preview and captured image share the top row as a horizontal chain of two equal
         squares. Two fixed 300dp views pinned to opposite edges overlap on narrow (~360dp) phones. -->
    <androidx.camera.view.PreviewView
        android:id="@+id/cameraPreview"
        android:layout_width="0dp"
        android:layout_height="0dp"
        app:layout_constraintDimensionRatio="H,1:1"
        app:layout_constraintTop_toTopOf="parent"
        app:layout_constraintStart_toStartOf="parent"
        app:layout_constraintEnd_toStartOf="@+id/ivPreview"/>
    <ImageView
        android:id="@+id/ivPreview"
        android:layout_width="0dp"
        android:layout_height="0dp"
        android:contentDescription="识别图片预览"
        android:scaleType="centerCrop"
        app:layout_constraintDimensionRatio="H,1:1"
        app:layout_constraintTop_toTopOf="parent"
        app:layout_constraintStart_toEndOf="@id/cameraPreview"
        app:layout_constraintEnd_toEndOf="parent"/>
    <Button
        android:id="@+id/btnCapture"
        android:layout_width="wrap_content"
        android:layout_height="wrap_content"
        android:text="拍照识别"
        app:layout_constraintBottom_toBottomOf="parent"
        app:layout_constraintStart_toStartOf="parent"/>
    <Button
        android:id="@+id/btnSelect"
        android:layout_width="wrap_content"
        android:layout_height="wrap_content"
        android:text="图库选择"
        app:layout_constraintBottom_toBottomOf="parent"
        app:layout_constraintEnd_toEndOf="parent"/>
    <!-- Classification results, full width below the preview row. -->
    <TextView
        android:id="@+id/tvResult"
        android:layout_width="0dp"
        android:layout_height="wrap_content"
        android:padding="16dp"
        android:textSize="18sp"
        app:layout_constraintTop_toBottomOf="@id/cameraPreview"
        app:layout_constraintStart_toStartOf="parent"
        app:layout_constraintEnd_toEndOf="parent"/>
</androidx.constraintlayout.widget.ConstraintLayout>
步骤4:主Activity实现(CameraX集成版)
/**
 * Demo screen: CameraX live preview plus two image sources (capture / gallery pick). Each image
 * is classified with [TFLiteImageClassifier] and the top results are shown in `tvResult`.
 *
 * Works on every API level >= minSdk 21. The former `loadThumbnail` call needed API 29, which
 * the previous `@RequiresApi(M)` annotation did not reflect.
 */
class MainActivity : AppCompatActivity() {
    private lateinit var classifier: TFLiteImageClassifier
    private lateinit var cameraExecutor: ExecutorService
    private var imageCapture: ImageCapture? = null

    // Views looked up in onCreate. kotlin-android-extensions synthetics were removed in Kotlin 1.8.
    private lateinit var cameraPreview: androidx.camera.view.PreviewView
    private lateinit var ivPreview: android.widget.ImageView
    private lateinit var tvResult: android.widget.TextView

    // Gallery picker. Launchers must be registered before the Activity is STARTED, hence a
    // property initializer. This replaces startActivityForResult, whose result was never handled.
    private val pickImage = registerForActivityResult(
        androidx.activity.result.contract.ActivityResultContracts.GetContent(),
    ) { uri: Uri? ->
        uri?.let { processImage(it) }
    }

    override fun onCreate(savedInstanceState: Bundle?) {
        super.onCreate(savedInstanceState)
        setContentView(R.layout.activity_main)
        cameraPreview = findViewById(R.id.cameraPreview)
        ivPreview = findViewById(R.id.ivPreview)
        tvResult = findViewById(R.id.tvResult)

        cameraExecutor = Executors.newSingleThreadExecutor()
        classifier = TFLiteImageClassifier(this)

        // Request the camera permission; startCamera() runs once it is granted.
        if (allPermissionsGranted()) {
            startCamera()
        } else {
            ActivityCompat.requestPermissions(
                this, REQUIRED_PERMISSIONS, REQUEST_CODE_PERMISSIONS
            )
        }

        findViewById<android.view.View>(R.id.btnCapture).setOnClickListener { takePhoto() }
        findViewById<android.view.View>(R.id.btnSelect).setOnClickListener { pickImage.launch("image/*") }
    }

    /** Captures a JPEG into cacheDir and classifies it once saved. */
    private fun takePhoto() {
        val capture = imageCapture ?: return
        val photoFile = File.createTempFile("ML_TEMP", ".jpg", cacheDir)
        val outputOptions = ImageCapture.OutputFileOptions.Builder(photoFile).build()
        capture.takePicture(
            outputOptions,
            ContextCompat.getMainExecutor(this),
            object : ImageCapture.OnImageSavedCallback {
                override fun onImageSaved(output: ImageCapture.OutputFileResults) {
                    // savedUri may be null depending on the CameraX version; the file is known anyway.
                    processImage(output.savedUri ?: Uri.fromFile(photoFile))
                }

                override fun onError(exc: ImageCaptureException) {
                    photoFile.delete()
                    Log.e(TAG, "拍照失败: ${exc.message}", exc)
                }
            }
        )
    }

    /**
     * Decodes [uri] on Dispatchers.IO, classifies it on Dispatchers.Default, then updates the UI.
     * Failures are logged. Coroutine cancellation (Activity destroyed) is re-thrown, not swallowed.
     */
    private fun processImage(uri: Uri) {
        lifecycleScope.launch {
            try {
                val bitmap = withContext(Dispatchers.IO) { decodeSampledBitmap(uri, MODEL_INPUT_SIZE) }
                if (bitmap == null) {
                    Log.e(TAG, "图片处理失败: 无法解码 $uri")
                    return@launch
                }
                val results = withContext(Dispatchers.Default) { classifier.classify(bitmap) }
                ivPreview.setImageBitmap(bitmap)
                showResults(results)
            } catch (e: kotlin.coroutines.cancellation.CancellationException) {
                throw e
            } catch (e: Exception) {
                Log.e(TAG, "图片处理失败", e)
            }
        }
    }

    /**
     * Decodes a content:// or file:// [uri] into an ARGB_8888 bitmap. It downsamples by powers
     * of two while both sides stay >= [minSide].
     *
     * Returns null if the stream cannot be opened or decoded.
     * NOTE(review): EXIF orientation is not applied; rotated photos are classified as stored.
     */
    private fun decodeSampledBitmap(uri: Uri, minSide: Int): android.graphics.Bitmap? {
        val bounds = android.graphics.BitmapFactory.Options().apply { inJustDecodeBounds = true }
        contentResolver.openInputStream(uri)?.use {
            android.graphics.BitmapFactory.decodeStream(it, null, bounds)
        }
        if (bounds.outWidth <= 0 || bounds.outHeight <= 0) return null

        var sampleSize = 1
        while (bounds.outWidth / (sampleSize * 2) >= minSide &&
            bounds.outHeight / (sampleSize * 2) >= minSide
        ) {
            sampleSize *= 2
        }
        val decodeOptions = android.graphics.BitmapFactory.Options().apply {
            inSampleSize = sampleSize
            inPreferredConfig = android.graphics.Bitmap.Config.ARGB_8888
        }
        return contentResolver.openInputStream(uri)?.use {
            android.graphics.BitmapFactory.decodeStream(it, null, decodeOptions)
        }
    }

    /** Renders `(label, confidence)` pairs as percentages into tvResult. */
    private fun showResults(results: List<Pair<String, Float>>) {
        val output = buildString {
            append("识别结果:\n")
            results.forEach { (label, confidence) ->
                append("${label}: ${"%.2f".format(confidence * 100)}%\n")
            }
        }
        tvResult.text = output
    }

    // CameraX setup: back camera preview + still capture bound to this Activity's lifecycle.
    private fun startCamera() {
        val cameraProviderFuture = ProcessCameraProvider.getInstance(this)
        cameraProviderFuture.addListener({
            val cameraProvider = cameraProviderFuture.get()
            val preview = Preview.Builder()
                .build()
                .also { it.setSurfaceProvider(cameraPreview.surfaceProvider) }
            imageCapture = ImageCapture.Builder()
                .setCaptureMode(ImageCapture.CAPTURE_MODE_MINIMIZE_LATENCY)
                .build()
            try {
                cameraProvider.unbindAll()
                cameraProvider.bindToLifecycle(
                    this, CameraSelector.DEFAULT_BACK_CAMERA, preview, imageCapture)
            } catch (exc: Exception) {
                Log.e(TAG, "相机初始化失败", exc)
            }
        }, ContextCompat.getMainExecutor(this))
    }

    // Permission result: start the camera, or finish if the user denied it.
    override fun onRequestPermissionsResult(
        requestCode: Int,
        permissions: Array<String>,
        grantResults: IntArray
    ) {
        super.onRequestPermissionsResult(requestCode, permissions, grantResults)
        if (requestCode == REQUEST_CODE_PERMISSIONS) {
            if (allPermissionsGranted()) {
                startCamera()
            } else {
                Toast.makeText(this, "需要相机权限", Toast.LENGTH_SHORT).show()
                finish()
            }
        }
    }

    // Release native TFLite resources and the executor together with the Activity.
    // NOTE(review): a classify() already running on Dispatchers.Default can overlap close();
    // TFLiteImageClassifier must serialize classify()/close() for this to be safe.
    override fun onDestroy() {
        super.onDestroy()
        classifier.close()
        cameraExecutor.shutdown()
    }

    // Instance method: needs this Activity as the Context (the companion had no `context`).
    private fun allPermissionsGranted() = REQUIRED_PERMISSIONS.all {
        ContextCompat.checkSelfPermission(this, it) == PackageManager.PERMISSION_GRANTED
    }

    companion object {
        private const val TAG = "MLDemo"
        private const val REQUEST_CODE_PERMISSIONS = 10
        // Decode target size, matching the 224x224 input of the default model.
        private const val MODEL_INPUT_SIZE = 224
        private val REQUIRED_PERMISSIONS = arrayOf(Manifest.permission.CAMERA)
    }
}
高级优化技巧
1. 性能监控
/**
 * Measures average inference latency of [classifier].
 *
 * The classifier is injected; the original body referenced a `classifier` that was not defined
 * anywhere in the class.
 */
class BenchmarkHelper(private val classifier: TFLiteImageClassifier) {
    /**
     * Runs [warmupRuns] untimed classifications of [bitmap], then [benchmarkRuns] timed ones.
     * Logs the mean latency.
     *
     * @return mean latency per inference in milliseconds.
     * @throws IllegalArgumentException if [benchmarkRuns] is not positive (would divide by zero).
     */
    fun measureInference(bitmap: Bitmap, warmupRuns: Int = 10, benchmarkRuns: Int = 100): Float {
        require(benchmarkRuns > 0) { "benchmarkRuns must be > 0, was $benchmarkRuns" }
        // Warm-up
        repeat(warmupRuns) {
            classifier.classify(bitmap)
        }
        // Timed runs. Nanosecond clock: per-run latency is often a few ms, so ms ticks are coarse.
        val startNanos = SystemClock.elapsedRealtimeNanos()
        repeat(benchmarkRuns) {
            classifier.classify(bitmap)
        }
        val avgTime = (SystemClock.elapsedRealtimeNanos() - startNanos) / 1_000_000f / benchmarkRuns
        Log.d("Benchmark", "平均推理时间: ${avgTime}ms")
        return avgTime
    }
}
2. 模型动态更新
/**
 * Downloads a replacement model from [modelUrl] and hands it to the classifier.
 *
 * All network and file I/O runs on Dispatchers.IO through a blocking `execute()`. The previous
 * `enqueue()` callback runs on Retrofit's callback executor (the main thread on Android), so
 * copying the body there blocked the UI. Its errors also escaped the surrounding try/catch.
 *
 * NOTE(review): TFLiteImageClassifier as shown has no `updateModel(File)`. It must be added,
 * e.g. by rebuilding the ImageClassifier from the file and closing the old one. Consider
 * verifying the file's hash before loading it.
 */
private fun downloadAndUpdateModel(modelUrl: String) {
    lifecycleScope.launch(Dispatchers.IO) {
        val modelFile = try {
            downloadModelFile(modelUrl)
        } catch (e: kotlin.coroutines.cancellation.CancellationException) {
            throw e
        } catch (e: Exception) {
            Log.e("ModelUpdate", "下载失败", e)
            return@launch
        }
        try {
            classifier.updateModel(modelFile)
        } catch (e: kotlin.coroutines.cancellation.CancellationException) {
            throw e
        } catch (e: Exception) {
            Log.e("ModelUpdate", "更新失败", e)
        }
    }
}

/**
 * Blocking download of [modelUrl] into a temp `.tflite` file.
 * Must be called off the main thread.
 *
 * @throws java.io.IOException on network failure, a non-2xx response, or an empty body.
 *   The partial temp file is deleted before rethrowing.
 */
private fun downloadModelFile(modelUrl: String): File {
    val response = Retrofit.Builder()
        .baseUrl("https://your-model-server/")
        .build()
        .create(ModelService::class.java)
        .downloadModel(modelUrl)
        .execute()
    if (!response.isSuccessful) {
        response.errorBody()?.close()
        throw java.io.IOException("HTTP ${response.code()}")
    }
    val body = response.body() ?: throw java.io.IOException("empty response body")
    val tempFile = File.createTempFile("model", ".tflite")
    try {
        // Closing the ResponseBody releases the underlying connection.
        body.use {
            it.byteStream().use { input ->
                tempFile.outputStream().use { output ->
                    input.copyTo(output)
                }
            }
        }
    } catch (e: java.io.IOException) {
        tempFile.delete()
        throw e
    }
    return tempFile
}
常见问题解决方案
问题1:输入尺寸不匹配
解决方案:
// The Task Library ImageClassifier (and TFLiteImageClassifier) do not expose tensors. Inspect
// them through the raw Interpreter API instead: `interpreter` is an org.tensorflow.lite.Interpreter
// created from the same model file.
val inputTensor = interpreter.getInputTensor(0)
val inputShape = inputTensor.shape() // assumed NHWC image input [1, height, width, channels] — confirm for your model
val dataType = inputTensor.dataType()
// Every input type must be *scaled* to the model size. ResizeWithCropOrPadOp only crops/pads,
// which frames the image wrongly for any other source resolution.
val resizeOp = ResizeOp(inputShape[1], inputShape[2], ResizeOp.ResizeMethod.BILINEAR)
// Preprocessing differs by input type. Quantized UINT8 models typically take raw 0..255 pixels.
// FLOAT32 models need normalization; 127.5/127.5 maps pixels to [-1, 1], so use your model's values.
val imageProcessor = when (dataType) {
    DataType.UINT8 -> ImageProcessor.Builder()
        .add(resizeOp)
        .build()
    DataType.FLOAT32 -> ImageProcessor.Builder()
        .add(resizeOp)
        .add(NormalizeOp(127.5f, 127.5f))
        .build()
    else -> throw IllegalArgumentException("不支持的输入类型")
}
问题2:内存泄漏
预防措施:
// Activity teardown: release what this screen created so it does not outlive the Activity.
override fun onDestroy() {
    super.onDestroy()
    // Frees the classifier's underlying ImageClassifier (TFLiteImageClassifier.close()).
    classifier.close()
    // Initiates an orderly shutdown of the camera executor; no new tasks are accepted afterwards.
    cameraExecutor.shutdown()
}
扩展应用方向
实时视频流处理
/**
 * CameraX analyzer that classifies one frame out of every [skipFrame] and reports the results
 * through [onResults].
 *
 * Classification runs synchronously on the analyzer's executor, which CameraX already runs off
 * the main thread. No outside coroutine scope is needed, and every frame is closed
 * deterministically. [onResults] is invoked on that executor thread, so post to the main thread
 * before touching views.
 *
 * @param classifier classifier applied to each processed frame.
 * @param skipFrame process one frame out of every `skipFrame` (controls the processing rate); must be >= 1.
 * @param onResults receives the `(label, score)` pairs of each processed frame.
 */
class VideoAnalyzer(
    private val classifier: TFLiteImageClassifier,
    private val skipFrame: Int = 3,
    private val onResults: (List<Pair<String, Float>>) -> Unit = {},
) : ImageAnalysis.Analyzer {
    private val frameCounter = AtomicInteger(0)

    init {
        require(skipFrame >= 1) { "skipFrame must be >= 1, was $skipFrame" }
    }

    override fun analyze(imageProxy: ImageProxy) {
        if (frameCounter.getAndIncrement() % skipFrame != 0) {
            imageProxy.close()
            return
        }
        // use{} closes the frame even if conversion throws; CameraX delivers no new frames
        // until the current one is closed.
        val bitmap = imageProxy.use { frame ->
            rotateUpright(frame.toBitmap(), frame.imageInfo.rotationDegrees)
        }
        onResults(classifier.classify(bitmap))
    }

    /** Rotates [bitmap] clockwise by [degrees] so the classifier sees an upright image. */
    private fun rotateUpright(bitmap: android.graphics.Bitmap, degrees: Int): android.graphics.Bitmap {
        if (degrees == 0) return bitmap
        val matrix = android.graphics.Matrix().apply { postRotate(degrees.toFloat()) }
        return android.graphics.Bitmap.createBitmap(bitmap, 0, 0, bitmap.width, bitmap.height, matrix, true)
    }
}
最佳实践建议
- 模型优化:
  - 使用 TensorFlow Model Optimization Toolkit 进行量化(训练后量化或量化感知训练)
  - 使用 TFLite Metadata(Metadata Writer)为模型添加描述、输入归一化参数和标签
- 性能平衡:
  - 根据设备性能动态选择推理后端(CPU/GPU/NNAPI)
  - 针对低端设备使用 CPU + XNNPACK 推理:
// Task Library: there is no XNNPACK delegate enum/ComputeSettings API. Configure CPU inference
// via BaseOptions; recent TFLite runtimes apply XNNPACK to CPU float inference by default.
val cpuOptions = ImageClassifier.ImageClassifierOptions.builder()
    .setBaseOptions(BaseOptions.builder().setNumThreads(4).build())
    .setMaxResults(3)
    .build()
// Interpreter API: XNNPACK can be requested explicitly.
val interpreterOptions = Interpreter.Options()
    .setUseXNNPACK(true)
    .setNumThreads(4)
- 安全防护:
/**
 * Model integrity check: returns true when the SHA-256 digest of [file] equals [expectedHash].
 *
 * The file is hashed in chunks, so large models are never fully loaded into memory.
 *
 * @param file model file to check.
 * @param expectedHash precomputed hex SHA-256. The comparison ignores case and surrounding
 *   whitespace. The default is a placeholder; replace it with your model's real digest.
 * @return true if the digests match.
 * @throws java.io.IOException if [file] cannot be read.
 */
fun verifyModel(file: File, expectedHash: String = "a1b2c3d4..."): Boolean {
    val digest = java.security.MessageDigest.getInstance("SHA-256")
    file.inputStream().use { input ->
        val buffer = ByteArray(DEFAULT_BUFFER_SIZE)
        var read = input.read(buffer)
        while (read >= 0) {
            digest.update(buffer, 0, read)
            read = input.read(buffer)
        }
    }
    // Mask to 0..255 so negative bytes format as two hex digits.
    val actualHash = digest.digest().joinToString(separator = "") { byte ->
        "%02x".format(byte.toInt() and 0xff)
    }
    return actualHash.equals(expectedHash.trim(), ignoreCase = true)
}
建议结合具体业务需求选择合适的模型,并通过性能分析工具持续优化推理流程。