import android.content.Context
import android.graphics.Bitmap
import org.tensorflow.lite.Interpreter
import org.tensorflow.lite.gpu.CompatibilityList
import org.tensorflow.lite.gpu.GpuDelegate
import org.tensorflow.lite.nnapi.NnApiDelegate
import java.nio.ByteBuffer
import java.nio.ByteOrder
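/**
 * Wraps a TensorFlow Lite [Interpreter] loaded from an asset, selecting the best available
 * backend (NNAPI delegate, then GPU delegate, then multi-threaded CPU) based on [ModelConfig].
 */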
class EdgeModelManager(
private val context: Context,
private val modelPath: String,
private val config: ModelConfig
) {
private var interpreter: Interpreter? = null
private var nnApiDelegate: NnApiDelegate? = null
private var gpuDelegate: GpuDelegate? = null
init {
initializeInterpreter()
}
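// Builds the interpreter options, preferring NNAPI, then GPU, then multi-threaded CPU execution.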
private fun initializeInterpreter() {
val options = Interpreter.Options().apply {
when {
config.enableNNAPI && isNnApiSupported() -> {
nnApiDelegate = NnApiDelegate(NnApiDelegate.Options().apply {
// Allow NNAPI's reference CPU implementation so the whole graph can stay delegated.
setUseNnapiCpu(true)
// Prefer sustained throughput over single-shot latency or low power.
setExecutionPreference(NnApiDelegate.Options.EXECUTION_PREFERENCE_SUSTAINED_SPEED)
})
addDelegate(nnApiDelegate)
}
config.enableGPU && isGpuSupported() -> {
gpuDelegate = GpuDelegate(CompatibilityList().bestOptionsForThisDevice)
addDelegate(gpuDelegate)
}
else -> {
setNumThreads(config.cpuThreads)
}
}
setAllowFp16PrecisionForFp32(config.allowHalfPrecision)
setUseXNNPACK(config.enableXNNPACK)
}
// Copy the asset into a direct, native-ordered buffer: the Interpreter requires a MappedByteBuffer
// or a direct ByteBuffer, and a direct buffer has no backing array to read into.
val modelBuffer = context.assets.open(modelPath).use { inputStream ->
val modelBytes = inputStream.readBytes()
ByteBuffer.allocateDirect(modelBytes.size)
.order(ByteOrder.nativeOrder())
.apply {
put(modelBytes)
rewind()
}
}
interpreter = Interpreter(modelBuffer, options)
}
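/**
 * Runs a single inference. Inputs may be [Bitmap]s, direct [ByteBuffer]s, or [FloatArray]s;
 * output containers are allocated from [ModelConfig.outputShapes] and [ModelConfig.outputTypes].
 */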
fun runInference(inputs: List<Any>): List<Any> {
check(interpreter != null) { "Interpreter has not been initialized" }
val inputBuffers = inputs.map { input ->
when (input) {
is Bitmap -> bitmapToByteBuffer(input, config.inputShape)
is ByteBuffer -> input
is FloatArray -> floatArrayToByteBuffer(input)
else -> throw IllegalArgumentException("不支持的输入类型: ${input::class.java}")
}
}
// Allocate one output container per output tensor, matching its declared shape and type.
val outputs = config.outputShapes.mapIndexed { index, shape ->
when (config.outputTypes[index]) {
DataType.FLOAT32 -> Array(shape[0]) { FloatArray(shape[1]) }
DataType.INT8 -> Array(shape[0]) { ByteArray(shape[1]) }
else -> throw IllegalArgumentException("Unsupported output type: ${config.outputTypes[index]}")
}
}
// runForMultipleInputsOutputs expects the output containers as a map keyed by output index.
val outputMap = outputs.mapIndexed { index, output -> index to (output as Any) }.toMap()
interpreter?.runForMultipleInputsOutputs(inputBuffers.toTypedArray(), outputMap)
return outputs
}
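// Releases the interpreter and any hardware delegates; call when the model is no longer needed.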
fun release() {
interpreter?.close()
nnApiDelegate?.close()
gpuDelegate?.close()
}
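// Converts an ARGB_8888 bitmap (already scaled to the model's input size) into a normalized
// float32 NHWC buffer, dropping the alpha channel.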
private fun bitmapToByteBuffer(bitmap: Bitmap, inputShape: IntArray): ByteBuffer {
val buffer = ByteBuffer.allocateDirect(4 * inputShape[1] * inputShape[2] * inputShape[3])
.order(ByteOrder.nativeOrder())
buffer.rewind()
val pixels = IntArray(bitmap.width * bitmap.height)
bitmap.getPixels(pixels, 0, bitmap.width, 0, 0, bitmap.width, bitmap.height)
for (pixel in pixels) {
buffer.putFloat((pixel shr 16 and 0xFF) / 255.0f)
buffer.putFloat((pixel shr 8 and 0xFF) / 255.0f)
buffer.putFloat((pixel and 0xFF) / 255.0f)
}
return buffer
}
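// Wraps a float array in a direct, native-ordered buffer suitable as an interpreter input.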
private fun floatArrayToByteBuffer(array: FloatArray): ByteBuffer {
val buffer = ByteBuffer.allocateDirect(4 * array.size)
.order(ByteOrder.nativeOrder())
// ByteBuffer has no bulk putFloat for arrays, so write through a FloatBuffer view.
buffer.asFloatBuffer().put(array)
return buffer
}
private fun isNnApiSupported() = android.os.Build.VERSION.SDK_INT >= android.os.Build.VERSION_CODES.P
private fun isGpuSupported() = CompatibilityList().isDelegateSupportedOnThisDevice
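// inputShape is in NHWC order (e.g. [1, 224, 224, 3]); each output shape is [batch, elements].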
data class ModelConfig(
val inputShape: IntArray,
val outputShapes: List<IntArray>,
val outputTypes: List<DataType>,
val enableNNAPI: Boolean = true,
val enableGPU: Boolean = true,
val cpuThreads: Int = 4,
val allowHalfPrecision: Boolean = true,
val enableXNNPACK: Boolean = true
)
enum class DataType { FLOAT32, INT8, UINT8 }
}
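// Usage sketch (illustrative only): the asset name, input size, and output shape below are
// placeholder assumptions for a typical image classifier, not values defined by EdgeModelManager.
fun classifyBitmap(context: Context, bitmap: Bitmap): FloatArray {
    val manager = EdgeModelManager(
        context = context,
        modelPath = "mobilenet_v2.tflite", // hypothetical asset name
        config = EdgeModelManager.ModelConfig(
            inputShape = intArrayOf(1, 224, 224, 3),
            outputShapes = listOf(intArrayOf(1, 1001)),
            outputTypes = listOf(EdgeModelManager.DataType.FLOAT32)
        )
    )
    return try {
        // The bitmap must already be scaled to 224x224 to match inputShape.
        val output = manager.runInference(listOf(bitmap))[0] as Array<FloatArray>
        output[0]
    } finally {
        manager.release()
    }
}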