Android端侧模型部署核心框架管理类

import android.content.Context
import android.graphics.Bitmap
import org.tensorflow.lite.Interpreter
import org.tensorflow.lite.gpu.CompatibilityList
import org.tensorflow.lite.gpu.GpuDelegate
import org.tensorflow.lite.nnapi.NnApiDelegate
import java.nio.ByteBuffer
import java.nio.ByteOrder

/**
 * Android端侧模型部署框架核心管理类
 * 支持TFLite模型的自动加速、内存优化和业务友好调用
 * @param context 应用上下文
 * @param modelPath assets中的模型路径(如"models/classifier.tflite")
 * @param config 模型配置参数
 */
class EdgeModelManager(
    private val context: Context,
    private val modelPath: String,
    private val config: ModelConfig
) {
    private var interpreter: Interpreter? = null
    private var nnApiDelegate: NnApiDelegate? = null
    private var gpuDelegate: GpuDelegate? = null

    init {
        initializeInterpreter()
    }

    /**
     * 初始化模型解释器(含自动加速配置)
     */
    private fun initializeInterpreter() {
        val options = Interpreter.Options().apply {
            // 自动选择最优加速方案
            when {
                config.enableNNAPI && isNnApiSupported() -> {
                    nnApiDelegate = NnApiDelegate(NnApiDelegate.Options().apply {
                        setUseNnApiCpu(true)  // 允许CPU作为备选
                        setExecutionPriority(NnApiDelegate.Options.PRIORITY_HIGH)
                    })
                    addDelegate(nnApiDelegate)
                }
                config.enableGPU && isGpuSupported() -> {
                    gpuDelegate = GpuDelegate(CompatibilityList().bestOptionsForThisDevice)
                    addDelegate(gpuDelegate)
                }
                else -> {
                    setNumThreads(config.cpuThreads)  // 默认CPU多线程
                }
            }
            setAllowFp16PrecisionForFp32(config.allowHalfPrecision)  // 混合精度支持
            setUseXNNPACK(config.enableXNNPACK)  // 启用XNNPACK加速
        }

        // 从assets加载模型
        val modelBuffer = context.assets.open(modelPath).use { inputStream ->
            val buffer = ByteBuffer.allocateDirect(inputStream.available())
                .order(ByteOrder.nativeOrder())
            inputStream.read(buffer.array())
            buffer
        }

        interpreter = Interpreter(modelBuffer, options)
    }

    /**
     * 执行推理(通用输入输出)
     * @param inputs 输入数据列表(支持Bitmap/ByteBuffer/基本类型数组)
     * @return 推理结果列表(根据模型输出类型自动转换)
     */
    fun runInference(inputs: List<Any>): List<Any> {
        check(interpreter != null) { "模型未初始化" }
        
        // 输入预处理(自动转换为模型需要的ByteBuffer)
        val inputBuffers = inputs.map { input ->
            when (input) {
                is Bitmap -> bitmapToByteBuffer(input, config.inputShape)
                is ByteBuffer -> input
                is FloatArray -> floatArrayToByteBuffer(input)
                else -> throw IllegalArgumentException("不支持的输入类型: ${input::class.java}")
            }
        }

        // 初始化输出缓冲区
        val outputShapes = config.outputShapes
        val outputs = outputShapes.map { shape ->
            when (config.outputTypes[outputShapes.indexOf(shape)]) {
                DataType.FLOAT32 -> Array(shape[0]) { FloatArray(shape[1]) }
                DataType.INT8 -> Array(shape[0]) { ByteArray(shape[1]) }
                else -> throw IllegalArgumentException("不支持的输出类型")
            }
        }

        // 执行推理
        interpreter?.runForMultipleInputsOutputs(inputBuffers.toTypedArray(), outputs.toTypedArray())

        return outputs
    }

    /**
     * 释放资源(建议在Activity onDestroy时调用)
     */
    fun release() {
        interpreter?.close()
        nnApiDelegate?.close()
        gpuDelegate?.close()
    }

    // ------------ 私有工具方法 ------------
    private fun bitmapToByteBuffer(bitmap: Bitmap, inputShape: IntArray): ByteBuffer {
        val buffer = ByteBuffer.allocateDirect(4 * inputShape[1] * inputShape[2] * inputShape[3])
            .order(ByteOrder.nativeOrder())
        buffer.rewind()
        
        val pixels = IntArray(bitmap.width * bitmap.height)
        bitmap.getPixels(pixels, 0, bitmap.width, 0, 0, bitmap.width, bitmap.height)
        
        for (pixel in pixels) {
            // 归一化处理(假设模型输入范围0-1)
            buffer.putFloat((pixel shr 16 and 0xFF) / 255.0f)  // R
            buffer.putFloat((pixel shr 8 and 0xFF) / 255.0f)   // G
            buffer.putFloat((pixel and 0xFF) / 255.0f)        // B
        }
        return buffer
    }

    private fun floatArrayToByteBuffer(array: FloatArray): ByteBuffer {
        return ByteBuffer.allocateDirect(4 * array.size)
            .order(ByteOrder.nativeOrder())
            .putFloatArray(array)
    }

    private fun isNnApiSupported() = android.os.Build.VERSION.SDK_INT >= android.os.Build.VERSION_CODES.P
    private fun isGpuSupported() = CompatibilityList().isDelegateSupportedOnThisDevice

    // ------------ 数据类定义 ------------
    data class ModelConfig(
        val inputShape: IntArray,  // [batch, height, width, channel]
        val outputShapes: List<IntArray>,  // 支持多输出
        val outputTypes: List<DataType>,  // 输出数据类型列表
        val enableNNAPI: Boolean = true,  // 启用NNAPI加速
        val enableGPU: Boolean = true,    // 启用GPU加速
        val cpuThreads: Int = 4,          // CPU线程数
        val allowHalfPrecision: Boolean = true,  // 允许FP16精度
        val enableXNNPACK: Boolean = true  // 启用XNNPACK
    )

    enum class DataType { FLOAT32, INT8, UINT8 }
}
    
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值