TFLite Objec Detection IOS 检测核心代码说明

TFLite Objec Detection IOS 检测核心代码说明

简要说明

本文档面向 IOS 移动端开发人员,主要讲解如何使用 tensorflow lite 在 IOS 设备上实现目标检测,以及在更换检测模型之后如何调整项目中的参数,使其成功运行。

Object Detection

首先需要说明一下目标检测的大致情况,目标检测与图像分类 Image classification 有很大不同。图像分类更着重于输入图片内容是否属于某个特定的类别。以关注“猫”的信息为例,图像分类仅关注图片中是否是一只猫。

目标检测是指判断输入图像中是否包含某个/些指定类别的目标,一张输入图片的检测结果可能包含一个或多个,如模型可以检测”猫“和”狗“两类,那么输入一张包含 2 只猫 2 只狗的图片,预期结果中会包含 4 条检测结果;如果有,则还要返回其类别、位置、以及对应概率信息。同样以关注“猫”的信息为例,若展示一副阳台上有花草(或其他物品)和猫咪在晒太阳的图片,应该返回猫咪在原图中的位置,即同时携带了“猫”这个类别的信息及“猫”的位置信息。

执行 object detection 所使用的模型

了解到上面目标检测的大致情况后,再说明一下模型,模型即.tflite文件。目标检测涉及到“类别”及“位置”两部分信息,模型的输出自然也包含该两部分信息(还包括一个概率信息,输出的类别位置有多大的概率是准确的)。

关于类别,若任务中涉及特定 5 类,比如手、人脸、水杯等,那么模型不会直接直接输出概念字符如“水杯”,而是用某个数值代表各类,如 0 代表手,1 代表人脸等。这个映射信息在整个训练过程中,一经定义不可随意更改,否则可能出现报错情况。在本项目中,即 labelmap.txt 文件所包含的信息,因此请勿自行对本文件进行任何修改。

关于位置,一般用 4 个数值表示,4 个数值直接或间接地携带了目标中心点在原图中的坐标及目标的宽高信息。

针对特定任务所训练得到的模型,一经定义和完成之后,其输入输出的尺寸、数据类型等均已定好,且**不能改变**,否则报错。如本项目中该检测模型接受输入图像尺寸为 300 × 300,那么就应该保证待检测图像被送入模型识别前被处理为 300 × 300,其他数据类型等情况同理。

代码分析 ModelDataHandler.swift

可以在线或者克隆 github上 tensorflow examples 项目后阅读本文件完整代码。
后文按照先代码段,后说明的顺序。

Line 26
/// Stores one formatted inference.
struct Inference {
  let confidence: Float
  let className: String
  let rect: CGRect
  let displayColor: UIColor
}

inference: 模型的训练过程被称为 Training,此处的 inference 是指模型在做推断,即模型已完成被应用的,相近词还有 predict,本项目就是在用训练好的模型做 inference。

那么上述代码段即是在说明模型在预测和推断时的数据结构信息。

confidence: 即前面说明中提及的概率,值在 0~1之间,float,越高代表模型“认为”预测得越好。

className: 即类别信息,代表某数字(或者已和 labelmap 做映射得到具有真实含义的类别名如‘cat’)。

rect: 即位置信息,具体请参考代码中所示 CGRect。

displayColor: 为方便开发人员调试,本项目提供了绘制功能,即将检测到的类别和位置信息用实时画框的方式展示在前端界面。

Line 34
/// Information about a model file or labels file.
typealias FileInfo = (name: String, extension: String)

/// Information about the MobileNet SSD model.
enum MobileNetSSD {
  static let modelInfo: FileInfo = (name: "detect", extension: "tflite")
  static let labelsInfo: FileInfo = (name: "labelmap", extension: "txt")
}

本段代码即在给定实现目标检测所使用的模型及对应的 labelmap 信息,模型文件为 detect.tflite,labelmap文件为 lablemap.txt。

Line 43
/// This class handles all data preprocessing and makes calls to run inference on a given frame
/// by invoking the `Interpreter`. It then formats the inferences obtained and returns the top N
/// results for a successful inference.

定义类,这个类即核心。通过调用 TensorFlow Lite 提供的解释器 Interpreter,实现数据预处理、调用模型实现 inference,对 inference 的结果进行格式化,返回模型”认为“识别的最好的(概率高) N 个inference 结果。

  Line 48
  // MARK: - Internal Properties
  /// The current thread count used by the TensorFlow Lite Interpreter.
  let threadCount: Int
  let threadCountLimit = 10

线程问题,略过。

  Line 53
  let threshold: Float = 0.5

threshold:阈值,用来滤除概率比较小的检测结果,0~1 之间可以更改,阈值越高,意味着对识别结果的准确率要求越高。

  Line 55
  // MARK: Model parameters
  let batchSize = 1
  let inputChannels = 3
  let inputWidth = 300
  let inputHeight = 300

本段为项目中自带的模型相关的固有参数,若针对其他任务更换了检测模型,应该再次核查此处参数。

batchSize: 是指模型一次调用能够推测出来多少frame,移动端资源有限,此处即一次调用推测一帧。

inputChannels:是指输入模型的图片通道数,灰度图为 1 通道,常见彩图 3 通道,此处即 3 通道彩图。

inputWidth 和 inputHeight:及输入模型的图片尺寸,此处即为 300 × 300。若待处理的图片尺寸与此不符,应先进行 resize 等操作,保证与此相符。

  Line 61
  // image mean and std for floating model, should be consistent with parameters used in model training
  let imageMean: Float = 127.5
  let imageStd:  Float = 127.5
  
  // MARK: Private properties
  private var labels: [String] = []

  /// TensorFlow Lite `Interpreter` object for performing inference on a given model.
  private var interpreter: Interpreter

imageMean 和 imageStd:图像送入模型前,对像素值的归一化预处理,包括均值和标准差,应该与训练过程中设置保持一致。

labels:识别结果中的类别信息。

Interpreter:解释器。

  Line 71
  private let bgraPixel = (channels: 4, alphaComponent: 3, lastBgrComponent: 2)
  private let rgbPixelChannels = 3
  private let colorStrideValue = 10
  private let colors = [
    UIColor.red,
    UIColor(displayP3Red: 90.0/255.0, green: 200.0/255.0, blue: 250.0/255.0, alpha: 1.0),
    UIColor.green,
    UIColor.orange,
    UIColor.blue,
    UIColor.purple,
    UIColor.magenta,
    UIColor.yellow,
    UIColor.cyan,
    UIColor.brown
  ]

前端显示颜色相关信息设置,略过。

  Line 87
  // MARK: - Initialization

  /// A failable initializer for `ModelDataHandler`. A new instance is created if the model and
  /// labels files are successfully loaded from the app's main bundle. Default `threadCount` is 1.
  init?(modelFileInfo: FileInfo, labelsFileInfo: FileInfo, threadCount: Int = 1) {
    let modelFilename = modelFileInfo.name

    // Construct the path to the model file.
    guard let modelPath = Bundle.main.path(
      forResource: modelFilename,
      ofType: modelFileInfo.extension
    ) else {
      print("Failed to load the model file with name: \(modelFilename).")
      return nil
    }

    // Specify the options for the `Interpreter`.
    self.threadCount = threadCount
    var options = Interpreter.Options()
    options.threadCount = threadCount
    do {
      // Create the `Interpreter`.
      interpreter = try Interpreter(modelPath: modelPath, options: options)
      // Allocate memory for the model's input `Tensor`s.
      try interpreter.allocateTensors()
    } catch let error {
      print("Failed to create the interpreter with error: \(error.localizedDescription)")
      return nil
    }

    super.init()

    // Load the classes listed in the labels file.
    loadLabels(fileInfo: labelsFileInfo)
  }

本段代码在初始化类,如果可以成功调用到模型和 labelmap 信息,就会成功创建一个 interpreter 实例(Line108)。如果创建失败,请排查给定的模型及 labelmap 信息。

接着会载入 labelmap 中的信息(Line 119调用),定义在 Line 243 ,后面不再介绍。

  Line 123
  /// This class handles all data preprocessing and makes calls to run inference on a given frame
  /// through the `Interpreter`. It then formats the inferences obtained and returns the top N
  /// results for a successful inference.

本段为 Class 中核心部分,调用模型返回检测结果。

Line 126
func runModel(onFrame pixelBuffer: CVPixelBuffer) -> Result? {

请注意输入数据格式问题。

    Line 127
    let imageWidth = CVPixelBufferGetWidth(pixelBuffer)
    let imageHeight = CVPixelBufferGetHeight(pixelBuffer)

获取实际图片的尺寸的宽高信息(请注意此处还没有固定 300 × 300)。

    Line 129
    let sourcePixelFormat = CVPixelBufferGetPixelFormatType(pixelBuffer)
    assert(sourcePixelFormat == kCVPixelFormatType_32ARGB ||
             sourcePixelFormat == kCVPixelFormatType_32BGRA ||
               sourcePixelFormat == kCVPixelFormatType_32RGBA)

图片格式问题,不需要改。

    Line 135
    let imageChannels = 4
    assert(imageChannels >= inputChannels)

Line 135 ~ Line 136,图片通道问题,不需要改。

    Line 138
    // Crops the image to the biggest square in the center and scales it down to model dimensions.
    let scaledSize = CGSize(width: inputWidth, height: inputHeight)
    guard let scaledPixelBuffer = pixelBuffer.resized(to: scaledSize) else {
      return nil
    }

本段在做尺寸预处理,把待处理的图片尺寸 resize 到模型接受的尺寸大小。

此处我暂时对注释中的 crop 存疑,在目标靠近图像边界时,是否会对检测造成影响。

但是按照代码行实际所写,应该并没有 crop。

    Line 144
    let interval: TimeInterval
    let outputBoundingBox: Tensor
    let outputClasses: Tensor
    let outputScores: Tensor
    let outputCount: Tensor

变量类型问题。

interval:记录 inference 一帧所需要的时间,应该是 ms 级别。

outputBoundingBox:目标的位置信息。

outputClasses:目标的类别信息。

outputScores:目标的概率信息,越高代表这个结果越好。

outputCount: 检测到的目标数量信息。

    Line 149
    do {
      let inputTensor = try interpreter.input(at: 0)

      // Remove the alpha component from the image buffer to get the RGB data.
      guard let rgbData = rgbDataFromBuffer(
        scaledPixelBuffer,
        byteCount: batchSize * inputWidth * inputHeight * inputChannels,
        isModelQuantized: inputTensor.dataType == .uInt8
      ) else {
        print("Failed to convert the image buffer to RGB data.")
        return nil
      }

      // Copy the RGB data to the input `Tensor`.
      try interpreter.copy(rgbData, toInputAt: 0)

      // Run inference by invoking the `Interpreter`.
      let startDate = Date()
      try interpreter.invoke()
      interval = Date().timeIntervalSince(startDate) * 1000

      outputBoundingBox = try interpreter.output(at: 0)
      outputClasses = try interpreter.output(at: 1)
      outputScores = try interpreter.output(at: 2)
      outputCount = try interpreter.output(at: 3)
    } catch let error {
      print("Failed to invoke the interpreter with error: \(error.localizedDescription)")
      return nil
    }

本段代码可以分为三步:

(1)Line 152 开始,在转换图像信息,去除 alpha 通道获取 RGB 通道信息。

(2)Line 162 开始,把上一步拿到的 RGB 通道信息 rgbData 复制到 interpreter 的输入 tensor 中,准备 inference

(3)调用之前实例化的 Interpreter 对刚刚放好的图片信息进行 inference。Line 166 ~ Line 168 记录了本次 inference 的耗时;Line 167 为真实的 inference 调用;Line 170 ~ Line 173 在获取 inference 的结果,依次是位置信息、类别信息、概率信息、数量信息。

    Line 179
    // Formats the results
    let resultArray = formatResults(
      boundingBox: [Float](unsafeData: outputBoundingBox.data) ?? [],
      outputClasses: [Float](unsafeData: outputClasses.data) ?? [],
      outputScores: [Float](unsafeData: outputScores.data) ?? [],
      outputCount: Int(([Float](unsafeData: outputCount.data) ?? [0])[0]),
      width: CGFloat(imageWidth),
      height: CGFloat(imageHeight)
    )

本段对 inference 的结果做格式化。

    Line 189
    // Returns the inference time and inferences
    let result = Result(inferenceTime: interval, inferences: resultArray)
    return result

返回本次 inference 的耗时及格式化后的结果信息。

Line 194
/// Filters out all the results with confidence score < threshold and returns the top N results
  /// sorted in descending order.
  func formatResults(boundingBox: [Float], outputClasses: [Float], outputScores: [Float], outputCount: Int, width: CGFloat, height: CGFloat) -> [Inference]{
    

本段代码对 inference 的结果信息进行格式化,滤除概率小于概率阈值的结果,返回当前检测结果中概率排名前 N 个结果。N 即输入中的 outputCount 变量。

  	Line 197
    var resultsArray: [Inference] = []
    if (outputCount == 0) {
      return resultsArray
    }

定义本函数的输出,如果 N 为0,即不返回实际检测信息,直接返回空的。

  for i in 0...outputCount - 1 {
      let score = outputScores[i]
      // Filters results with confidence < threshold.
      guard score >= threshold else {
        continue
      }

      // Gets the output class names for detected classes from labels list.
      let outputClassIndex = Int(outputClasses[i])
      let outputClass = labels[outputClassIndex + 1]

      var rect: CGRect = CGRect.zero

      // Translates the detected bounding box to CGRect.
      rect.origin.y = CGFloat(boundingBox[4*i])
      rect.origin.x = CGFloat(boundingBox[4*i+1])
      rect.size.height = CGFloat(boundingBox[4*i+2]) - rect.origin.y
      rect.size.width = CGFloat(boundingBox[4*i+3]) - rect.origin.x

      // The detected corners are for model dimensions. So we scale the rect with respect to the
      // actual image dimensions.
      let newRect = rect.applying(CGAffineTransform(scaleX: width, y: height))

      // Gets the color assigned for the class
      let colorToAssign = colorForClass(withIndex: outputClassIndex + 1)
      let inference = Inference(confidence: score,
                                className: outputClass,
                                rect: newRect,
                                displayColor: colorToAssign)
      resultsArray.append(inference)
    }

for 循环,获取前 N 个结果。

Line 205:滤除概率过低的结果。

Line 210:处理检测结果中的类别信息,联动 labelmap 的内容,根据模型返回的类别下标 outputClassIndex 获取 labelmap 中的实际类别名称如”猫“。

Line 214:处理检测结果中的位置信息,按照顺序获取当前第 i 个检测结果位置信息的 y,x,height,width。

Line 222:还记得当初送进模型的图片事先经过了尺寸的处理吗,若想将 inference 的结果绘制/映射到最初的图片中,应该做一定的变换。

Line 226:从之前定义好的 color 集合里面分配一个颜色给当前这个结果。

Line 228:包好准备返回最后的 inference 结果。

    Line 235
    // Sort results in descending order of confidence.
    resultsArray.sort { (first, second) -> Bool in
      return first.confidence  > second.confidence
    }

    return resultsArray
}

函数最后,按照概率降序排序,输出。

  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值