14.6.4 Running the OCR Model
Write the file app\src\main\java\org\tensorflow\lite\examples\ocr\OCRModelExecutor.kt, which runs the OCR model and implements both text detection and text recognition. The implementation of OCRModelExecutor.kt proceeds as follows.
(1) Define the member properties the executor needs:
class OCRModelExecutor(context: Context, private var useGPU: Boolean = false) : AutoCloseable {
private var gpuDelegate: GpuDelegate? = null
private val recognitionResult: ByteBuffer
private val detectionInterpreter: Interpreter
private val recognitionInterpreter: Interpreter
private var ratioHeight = 0.toFloat()
private var ratioWidth = 0.toFloat()
private var indicesMat: MatOfInt
private var boundingBoxesMat: MatOfRotatedRect
private var ocrResults: HashMap<String, Int>
(2) In the init block, verify that OpenCV can be loaded:
init {
try {
if (!OpenCVLoader.initDebug()) throw Exception("Unable to load OpenCV")
else Log.d(TAG, "OpenCV loaded")
} catch (e: Exception) {
val exceptionLog = "something went wrong: ${e.message}"
Log.d(TAG, exceptionLog)
}
(3) Create the detection and recognition interpreters and allocate the buffer that receives the recognition output:
detectionInterpreter = getInterpreter(context, textDetectionModel, useGPU)
//The recognition model requires Flex ops, so the GPU delegate is disabled regardless of the user's choice
recognitionInterpreter = getInterpreter(context, textRecognitionModel, false)
recognitionResult = ByteBuffer.allocateDirect(recognitionModelOutputSize * 8)
recognitionResult.order(ByteOrder.nativeOrder())
indicesMat = MatOfInt()
boundingBoxesMat = MatOfRotatedRect()
ocrResults = HashMap<String, Int>()
}
(4) Write the method execute(data: Bitmap), which runs detection and recognition on the image passed in as data:
fun execute(data: Bitmap): ModelExecutionResult {
try {
ratioHeight = data.height.toFloat() / detectionImageHeight
ratioWidth = data.width.toFloat() / detectionImageWidth
ocrResults.clear()
detectTexts(data)
val bitmapWithBoundingBoxes = recognizeTexts(data, boundingBoxesMat, indicesMat)
return ModelExecutionResult(bitmapWithBoundingBoxes, "OCR result", ocrResults)
} catch (e: Exception) {
val exceptionLog = "something went wrong: ${e.message}"
Log.d(TAG, exceptionLog)
val emptyBitmap = ImageUtils.createEmptyBitmap(displayImageSize, displayImageSize)
return ModelExecutionResult(emptyBitmap, exceptionLog, HashMap<String, Int>())
}
}
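The ratios computed in execute() map coordinates from the detector's fixed input size back to the source image. The arithmetic can be sketched on its own in plain Kotlin; the helper name mapToSource is illustrative and not part of the project code:

```kotlin
// Maps a point from detection-model space back to the source image.
// ratioWidth/ratioHeight are source dimension divided by model input
// dimension, exactly as computed in execute().
fun mapToSource(
    x: Float, y: Float,
    ratioWidth: Float, ratioHeight: Float
): Pair<Float, Float> = Pair(x * ratioWidth, y * ratioHeight)

fun main() {
    // A 640x960 photo fed to a 320x320 detector: ratios are 2.0 and 3.0,
    // so model point (100, 50) lands at (200, 150) in the photo.
    val rw = 640f / 320f
    val rh = 960f / 320f
    println(mapToSource(100f, 50f, rw, rh)) // (200.0, 150.0)
}
```

recognizeTexts() applies exactly these ratios when it draws the bounding-box lines on the display bitmap.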
(5) Write the method detectTexts(), which detects the text regions in the image specified by data:
private fun detectTexts(data: Bitmap) {
val detectionTensorImage =
ImageUtils.bitmapToTensorImageForDetection(
data,
detectionImageWidth,
detectionImageHeight,
detectionImageMeans,
detectionImageStds
)
val detectionInputs = arrayOf(detectionTensorImage.buffer.rewind())
val detectionOutputs: HashMap<Int, Any> = HashMap<Int, Any>()
val detectionScores =
Array(1) { Array(detectionOutputNumRows) { Array(detectionOutputNumCols) { FloatArray(1) } } }
val detectionGeometries =
Array(1) { Array(detectionOutputNumRows) { Array(detectionOutputNumCols) { FloatArray(5) } } }
detectionOutputs.put(0, detectionScores)
detectionOutputs.put(1, detectionGeometries)
detectionInterpreter.runForMultipleInputsOutputs(detectionInputs, detectionOutputs)
val transposeddetectionScores =
Array(1) { Array(1) { Array(detectionOutputNumRows) { FloatArray(detectionOutputNumCols) } } }
val transposedDetectionGeometries =
Array(1) { Array(5) { Array(detectionOutputNumRows) { FloatArray(detectionOutputNumCols) } } }
//Transpose the detection output tensors
for (i in 0 until transposeddetectionScores[0][0].size) {
for (j in 0 until transposeddetectionScores[0][0][0].size) {
for (k in 0 until 1) {
transposeddetectionScores[0][k][i][j] = detectionScores[0][i][j][k]
}
for (k in 0 until 5) {
transposedDetectionGeometries[0][k][i][j] = detectionGeometries[0][i][j][k]
}
}
}
val detectedRotatedRects = ArrayList<RotatedRect>()
val detectedConfidences = ArrayList<Float>()
for (y in 0 until transposeddetectionScores[0][0].size) {
val detectionScoreData = transposeddetectionScores[0][0][y]
val detectionGeometryX0Data = transposedDetectionGeometries[0][0][y]
val detectionGeometryX1Data = transposedDetectionGeometries[0][1][y]
val detectionGeometryX2Data = transposedDetectionGeometries[0][2][y]
val detectionGeometryX3Data = transposedDetectionGeometries[0][3][y]
val detectionRotationAngleData = transposedDetectionGeometries[0][4][y]
for (x in 0 until transposeddetectionScores[0][0][0].size) {
if (detectionScoreData[x] < 0.5) {
continue
}
//Compute the rotated bounding boxes and constraints (largely based on the OpenCV sample):
// https://github.com/opencv/opencv/blob/master/samples/dnn/text_detection.py
val offsetX = x * 4.0
val offsetY = y * 4.0
val h = detectionGeometryX0Data[x] + detectionGeometryX2Data[x]
val w = detectionGeometryX1Data[x] + detectionGeometryX3Data[x]
val angle = detectionRotationAngleData[x]
val cos = Math.cos(angle.toDouble())
val sin = Math.sin(angle.toDouble())
val offset =
Point(
offsetX + cos * detectionGeometryX1Data[x] + sin * detectionGeometryX2Data[x],
offsetY - sin * detectionGeometryX1Data[x] + cos * detectionGeometryX2Data[x]
)
val p1 = Point(-sin * h + offset.x, -cos * h + offset.y)
val p3 = Point(-cos * w + offset.x, sin * w + offset.y)
val center = Point(0.5 * (p1.x + p3.x), 0.5 * (p1.y + p3.y))
val textDetection =
RotatedRect(
center,
Size(w.toDouble(), h.toDouble()),
(-1 * angle * 180.0 / Math.PI)
)
detectedRotatedRects.add(textDetection)
detectedConfidences.add(detectionScoreData[x])
}
}
val detectedConfidencesMat = MatOfFloat(vector_float_to_Mat(detectedConfidences))
boundingBoxesMat = MatOfRotatedRect(vector_RotatedRect_to_Mat(detectedRotatedRects))
NMSBoxesRotated(
boundingBoxesMat,
detectedConfidencesMat,
detectionConfidenceThreshold.toFloat(),
detectionNMSThreshold.toFloat(),
indicesMat
)
}
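The box geometry above follows the EAST decoding scheme: each output cell sits on a stride-4 grid and carries four edge distances plus an angle, from which a rotated rectangle is reconstructed. The arithmetic can be sketched standalone, without OpenCV types; decodeEastCell is an illustrative name, and the loop body in detectTexts() is the authoritative version:

```kotlin
import kotlin.math.cos
import kotlin.math.sin

// Decodes one EAST geometry cell into (centerX, centerY, width, height, angleDeg).
// d0..d3 are distances from the cell to the box edges; the arithmetic mirrors
// the inner loop of detectTexts().
fun decodeEastCell(
    x: Int, y: Int,
    d0: Float, d1: Float, d2: Float, d3: Float,
    angle: Float
): DoubleArray {
    val offsetX = x * 4.0           // stride-4 grid position
    val offsetY = y * 4.0
    val h = (d0 + d2).toDouble()    // box height = top + bottom distances
    val w = (d1 + d3).toDouble()    // box width  = right + left distances
    val c = cos(angle.toDouble())
    val s = sin(angle.toDouble())
    val ox = offsetX + c * d1 + s * d2
    val oy = offsetY - s * d1 + c * d2
    val p1x = -s * h + ox; val p1y = -c * h + oy
    val p3x = -c * w + ox; val p3y = s * w + oy
    return doubleArrayOf(
        0.5 * (p1x + p3x), 0.5 * (p1y + p3y), // center
        w, h,
        -angle * 180.0 / Math.PI              // OpenCV-style degrees
    )
}
```

For a cell at the grid origin with equal distances of 10 and angle 0, this yields a 20x20 axis-aligned box centered on the cell.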
(6) Write the method recognizeTexts(), which calls the recognition model on each detected region. copy() returns a new Bitmap whose pixel format is ARGB_8888. On Android, the memory an image occupies when displayed is determined not by its file size but by the number of pixels times the bytes per pixel: a 400*800 image displayed as ARGB_8888 takes (400*800*4)/1024 = 1250 KB. When marking detections, a for loop walks the four corner points of each bounding box and drawLine() draws the quadrilateral outlining the text region. Finally, the warped bitmap of each region is converted into a tensor image and fed to the recognition model. The implementation of recognizeTexts() is as follows:
private fun recognizeTexts(
data: Bitmap,
boundingBoxesMat: MatOfRotatedRect,
indicesMat: MatOfInt
): Bitmap {
val bitmapWithBoundingBoxes = data.copy(Bitmap.Config.ARGB_8888, true)
val canvas = Canvas(bitmapWithBoundingBoxes)
val paint = Paint()
paint.style = Paint.Style.STROKE
paint.strokeWidth = 10.toFloat()
paint.setColor(Color.GREEN)
for (i in indicesMat.toArray()) {
val boundingBox = boundingBoxesMat.toArray()[i]
val targetVertices = ArrayList<Point>()
targetVertices.add(Point(0.toDouble(), (recognitionImageHeight - 1).toDouble()))
targetVertices.add(Point(0.toDouble(), 0.toDouble()))
targetVertices.add(Point((recognitionImageWidth - 1).toDouble(), 0.toDouble()))
targetVertices.add(
Point((recognitionImageWidth - 1).toDouble(), (recognitionImageHeight - 1).toDouble())
)
val srcVertices = ArrayList<Point>()
val boundingBoxPointsMat = Mat()
boxPoints(boundingBox, boundingBoxPointsMat)
for (j in 0 until 4) {
srcVertices.add(
Point(
boundingBoxPointsMat.get(j, 0)[0] * ratioWidth,
boundingBoxPointsMat.get(j, 1)[0] * ratioHeight
)
)
if (j != 0) {
canvas.drawLine(
(boundingBoxPointsMat.get(j, 0)[0] * ratioWidth).toFloat(),
(boundingBoxPointsMat.get(j, 1)[0] * ratioHeight).toFloat(),
(boundingBoxPointsMat.get(j - 1, 0)[0] * ratioWidth).toFloat(),
(boundingBoxPointsMat.get(j - 1, 1)[0] * ratioHeight).toFloat(),
paint
)
}
}
canvas.drawLine(
(boundingBoxPointsMat.get(0, 0)[0] * ratioWidth).toFloat(),
(boundingBoxPointsMat.get(0, 1)[0] * ratioHeight).toFloat(),
(boundingBoxPointsMat.get(3, 0)[0] * ratioWidth).toFloat(),
(boundingBoxPointsMat.get(3, 1)[0] * ratioHeight).toFloat(),
paint
)
val srcVerticesMat =
MatOfPoint2f(srcVertices[0], srcVertices[1], srcVertices[2], srcVertices[3])
val targetVerticesMat =
MatOfPoint2f(targetVertices[0], targetVertices[1], targetVertices[2], targetVertices[3])
val rotationMatrix = getPerspectiveTransform(srcVerticesMat, targetVerticesMat)
val recognitionBitmapMat = Mat()
val srcBitmapMat = Mat()
bitmapToMat(data, srcBitmapMat)
warpPerspective(
srcBitmapMat,
recognitionBitmapMat,
rotationMatrix,
Size(recognitionImageWidth.toDouble(), recognitionImageHeight.toDouble())
)
val recognitionBitmap =
ImageUtils.createEmptyBitmap(
recognitionImageWidth,
recognitionImageHeight,
0,
Bitmap.Config.ARGB_8888
)
matToBitmap(recognitionBitmapMat, recognitionBitmap)
val recognitionTensorImage =
ImageUtils.bitmapToTensorImageForRecognition(
recognitionBitmap,
recognitionImageWidth,
recognitionImageHeight,
recognitionImageMean,
recognitionImageStd
)
recognitionResult.rewind()
recognitionInterpreter.run(recognitionTensorImage.buffer, recognitionResult)
var recognizedText = ""
for (k in 0 until recognitionModelOutputSize) {
var alphabetIndex = recognitionResult.getInt(k * 8)
if (alphabetIndex in 0..alphabets.length - 1)
recognizedText = recognizedText + alphabets[alphabetIndex]
}
Log.d("Recognition result:", recognizedText)
if (recognizedText != "") {
ocrResults.put(recognizedText, getRandomColor())
}
}
return bitmapWithBoundingBoxes
}
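The ARGB_8888 memory cost cited above is easy to verify: the format stores 4 bytes per pixel, so the footprint is simply width * height * 4. A minimal sketch (bitmapBytes is an illustrative name):

```kotlin
// Bytes needed to hold a bitmap in memory: width * height * bytesPerPixel.
// ARGB_8888 uses 4 bytes per pixel (one each for alpha, red, green, blue).
fun bitmapBytes(width: Int, height: Int, bytesPerPixel: Int = 4): Int =
    width * height * bytesPerPixel

fun main() {
    // The 400x800 example from the text: 1,280,000 bytes = 1250 KB.
    println(bitmapBytes(400, 800))        // 1280000
    println(bitmapBytes(400, 800) / 1024) // 1250
}
```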
(7) Write the method loadModelFile(), which memory-maps the specified model file from assets, and the method getInterpreter(), which builds a TensorFlow Lite Interpreter from it:
private fun loadModelFile(context: Context, modelFile: String): MappedByteBuffer {
val fileDescriptor = context.assets.openFd(modelFile)
val inputStream = FileInputStream(fileDescriptor.fileDescriptor)
val fileChannel = inputStream.channel
val startOffset = fileDescriptor.startOffset
val declaredLength = fileDescriptor.declaredLength
val retFile = fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength)
fileDescriptor.close()
return retFile
}
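loadModelFile() memory-maps the model instead of copying it onto the heap, so the same bytes can back both interpreters cheaply. The same FileChannel.map() call can be exercised outside Android with an ordinary file; this sketch assumes a plain java.io.File rather than an asset file descriptor:

```kotlin
import java.io.File
import java.io.FileInputStream
import java.nio.MappedByteBuffer
import java.nio.channels.FileChannel

// Memory-maps a file read-only, mirroring what loadModelFile() does with
// the asset's file descriptor, start offset, and declared length.
// The mapping stays valid after the channel is closed.
fun mapFile(file: File): MappedByteBuffer =
    FileInputStream(file).channel.use { channel ->
        channel.map(FileChannel.MapMode.READ_ONLY, 0, file.length())
    }

fun main() {
    val tmp = File.createTempFile("model", ".bin")
    tmp.writeBytes(byteArrayOf(1, 2, 3, 4))
    val buffer = mapFile(tmp)
    println(buffer.capacity())     // 4
    println(buffer.get(2).toInt()) // 3
    tmp.delete()
}
```

In the real method, startOffset and declaredLength matter because an asset is a slice of the packed APK file, not a standalone file.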
@Throws(IOException::class)
private fun getInterpreter(
context: Context,
modelName: String,
useGpu: Boolean = false
): Interpreter {
val tfliteOptions = Interpreter.Options()
tfliteOptions.setNumThreads(numberThreads)
gpuDelegate = null
if (useGpu) {
gpuDelegate = GpuDelegate()
tfliteOptions.addDelegate(gpuDelegate)
}
return Interpreter(loadModelFile(context, modelName), tfliteOptions)
}
(8) Write the method close(), which releases the interpreters and the GPU delegate:
override fun close() {
detectionInterpreter.close()
recognitionInterpreter.close()
if (gpuDelegate != null) {
gpuDelegate!!.close()
}
}
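Because OCRModelExecutor implements AutoCloseable, callers can rely on Kotlin's use {} to guarantee close() runs even when an exception is thrown. A sketch with a stand-in resource (FakeExecutor is illustrative, standing in for OCRModelExecutor, whose constructor needs an Android Context):

```kotlin
// A stand-in for OCRModelExecutor: any AutoCloseable works with use {}.
class FakeExecutor : AutoCloseable {
    var closed = false
    fun execute(input: String) = "result for $input"
    override fun close() { closed = true }
}

fun main() {
    val executor = FakeExecutor()
    // use {} calls close() when the block exits, like Java's try-with-resources.
    val result = executor.use { it.execute("frame") }
    println(result)          // result for frame
    println(executor.closed) // true
}
```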
(9) Write the method getRandomColor(), which generates a random color that is then applied to each recognized text entry, highlighting the recognition results. The implementation of getRandomColor() is as follows:
fun getRandomColor(): Int {
val random = Random()
return Color.argb(
(128),
(255 * random.nextFloat()).toInt(),
(255 * random.nextFloat()).toInt(),
(255 * random.nextFloat()).toInt()
)
}
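Color.argb() simply packs the four channels into a single Int. A framework-free sketch of that packing (packArgb is an illustrative name, not an Android API):

```kotlin
// Packs alpha/red/green/blue (each 0..255) into a single ARGB Int,
// the same bit layout Android's Color.argb() produces.
fun packArgb(a: Int, r: Int, g: Int, b: Int): Int =
    (a shl 24) or (r shl 16) or (g shl 8) or b

fun main() {
    // Half-transparent pure red, as getRandomColor() might return.
    val color = packArgb(128, 255, 0, 0)
    println(Integer.toHexString(color)) // 80ff0000
}
```

The fixed alpha of 128 in getRandomColor() keeps the highlight translucent so the underlying image stays visible.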
companion object {
public const val TAG = "TfLiteOCRDemo"
private const val textDetectionModel = "text_detection.tflite"
private const val textRecognitionModel = "text_recognition.tflite"
private const val numberThreads = 4
private const val alphabets = "0123456789abcdefghijklmnopqrstuvwxyz"
private const val displayImageSize = 257
private const val detectionImageHeight = 320
private const val detectionImageWidth = 320
private val detectionImageMeans = floatArrayOf(103.94f, 116.78f, 123.68f)
private val detectionImageStds = floatArrayOf(1f, 1f, 1f)
private val detectionOutputNumRows = 80
private val detectionOutputNumCols = 80
private val detectionConfidenceThreshold = 0.5
private val detectionNMSThreshold = 0.4
private const val recognitionImageHeight = 31
private const val recognitionImageWidth = 200
private const val recognitionImageMean = 0.toFloat()
private const val recognitionImageStd = 255.toFloat()
private const val recognitionModelOutputSize = 48
}
14.7 Debugging and Running
Click the Run button at the top of Android Studio to run the project; the result is shown on an Android device. The image to be recognized appears at the top of the screen, and the recognition results appear in the floating panel below it. If the image contains no text, the result looks like Figure 14-3; if it does contain text, the recognized results are displayed, as shown in Figure 14-4.