Android 集成 tensorflow 训练结果 记录

引言

随着人工智能应用越来越普及,目前app 项目正在集成 以 tensorflow 为基础 训练出的结果,主要为图像的场景识别相关。
由于项目关键代码不方便公布,并且是参照google 提供集成android demo 来进行书写代码的,所以同时可以参考代码:
tensorflow android demo

环境

Android 9.0 平台
AndroidStudio 开发工具

大体思路

  1. 集成tensorflow jar包
  2. 添加tensorflow 训练的pb文件以及label.txt 文件
  3. 获取图像
  4. 初始化tensorflow接口
  5. 进行图像识别,返回结果
  6. 对结果进行处理

流程详解

1. 下载tensorflow jar 依赖包

     //build.gradle 文件中添加,并同步
     compile 'org.tensorflow:tensorflow-android:1.9.0'

2. assets

需要.pb 以及.txt文件, 一般将2文件存放到 assets 目录中
.pb: 训练后的结果
.txt: 输出的标签

3. 传入一个bitmap

  1. camera实时预览,可以从onPreviewFrame(camera api 1)中获取data 并将其从YUV 转成 RGB -> bitmap
  2. 不是实时,可以拍照或者直接图片,将其转换成bitmap

4. 获取图像数据

注册previewcallback 用来获取每一帧数据,此数据是 yuv 需要转换 RGB
1. 注册camera 实时预览 监听
2. 初始化tensorflow 使用的类
3. 将回调data 转换成RGB,并进一步转成Bitmap
4. 传入tensorflow bitmap
5. 获得识别结果

(1) 添加preview callback 回调监听

  private void xxxx() {
      if (previewBuffer == null) {
          int bufferSize = (int) (previewHeight * previewWidth * 1.5);
          previewBuffer = new byte[bufferSize];
      }
      mCamera.setPreviewCallbackWithBuffer(mOnPreviewCallback);
      mCamera.addCallbackBuffer(previewBuffer);
  }

注意:需要在camera startpreview之前完成此操作

(2) 数据处理

private Camera.PreviewCallback mOnPreviewCallback=new Camera.PreviewCallback() {
    @Override
    public void onPreviewFrame(final byte[] data, final Camera camera) {
    //获得当前preview size 的宽、高
        if (size == null) {
            size = camera.getParameters().getPreviewSize();
            previewWidth = size.width;
            previewHeight = size.height;
        }
        try {
            if (rgbBytes == null) {
            //初始化RGB 数据大小
                rgbBytes = new int[previewWidth * previewHeight];
            //初始化 bitmap 数据
                initBitmap();
            }

           //创建TensorFlow 模式,传入变量
            if (detector==null) {
                detector = TensorFlowXXimpl.create(
                                mActivity.getAssets(), 
                                "file:///android_asset/xxxx.pb",
                                "file:///android_asset/xxxx.txt", 
                                300);
            }

            //通过onOrientationChanged 获取 手机方向
            if (degree==0||degree==180) {
                sensorOrientation = 90 - getScreenOrientation();
            }else {
                sensorOrientation=getScreenOrientation();
            }
        } catch (final Exception e) {
            e.printStackTrace();
            return;
        }
        imageConverter =
                new Runnable() {
                    @Override
                    public void run() {
                    //将data 转换成 rgbBytes
                        ImageUtils.convertYUV420SPToARGB8888(data, 
                               previewWidth, previewHeight, rgbBytes);
                    }
                };
         //flag 用来判断上一次是否处理完,否则不进行下一次的运算
         if (flag) {
         //将照相机获取的原始图片,转换为300*300的图片,用来作为模型预测的输入
            frameToCropTransform =
                    ImageUtils.getTransformationMatrix(
                            previewWidth, previewHeight,
                            cropSize, cropSize,
                            sensorOrientation, false);
        //矩阵倒置
            frameToCropTransform.invert(cropToFrameTransform);
            singleThreadExecutor.execute(workRunnable);
        }

        //必须添加,否则onPreviewFrame 没有下一次的回调
        if (mCameraDevice != null) {
            mCameraDevice.getCamera().addCallbackBuffer(data);
        }
    }
};

注意
1. 结束后,必须添加addCallbackBuffer,否则没有下一次回调
2. 需要对图像进行旋转处理,否则不好识别
3. ImageUtils 的convertYUV420SPToARGB8888 方法可以见demo

(3) 初始化Bitmap

private void  initBitmap(){
    //真实宽高的bitmap
    rgbBitmap = Bitmap.createBitmap(previewWidth, previewHeight,
                                            Bitmap.Config.ARGB_8888);

    //输入 宽高 为300的 bitmap
    inputBitmap = Bitmap.createBitmap(cropSize, cropSize, 
                                             Bitmap.Config.ARGB_8888);

    //用来矩阵转换
    cropToFrameTransform = new Matrix();
 }

(4) 获得屏幕的方向

 protected int getScreenOrientation() {
    switch 
    (mActivity.getWindowManager().
                            getDefaultDisplay().getRotation()) {
        case ROTATION_270:
            return 270;
        case ROTATION_180:
            return 180;
        case ROTATION_90:
            return 90;
        default:
            return 0;
        }
    }

监听方向

@Override
public void onOrientationChanged(OrientationManager orientationM,
                OrientationManager.DeviceOrientation orientation) {
        super.onOrientationChanged(orientationManager, orientation);
        degree=orientation.getDegrees();
}

通过实现 OrientationManager.OnOrientationChangeListener

(5) 创建线程池

private ExecutorService singleThreadExecutor = 
                         Executors.newSingleThreadExecutor();

//在线程中进行图像的识别操作
Runnable workRunnable=new Runnable() {
    @Override
    public void run() {
        flag = false;
        processImage();
        flag = true;
    }
};

(6) 开始识别

private  void processImage() {
    //获得RGB 数据,设置到bitmap 中
     rgbBitmap.setPixels(getRgbBytes(), 0, previewWidth, 0, 0, previewWidth, previewHeight);

    // 通过canvas 将rgbBitmap 通过frameToCropTransform  绘制到inputBitmap 中,类似裁剪成300 * 300 
     final Canvas canvas = new Canvas(croppedBitmap);
     canvas.drawBitmap(rgbFrameBitmap, frameToCropTransform, null);

    // 传入需要识别的 bitmap 进行识别,并返回结果
     if (detector!=null)
         results = detector.recognizeImage(croppedBitmap);

    //获得label 以及 概率
     if (results != null && results.size() > 0) {
         type = results.get(0).getTitle();
         score = results.get(0).getConfidence();

         //此处获取的识别后概率最大的场景
         //dosomething
     }
 }

tensorflow 流程

1. create 初始化

public static Classifier create(
          final AssetManager assetManager,
          final String modelFilename,
          final String labelFilename,
          final int inputSize) throws IOException {
    //创建自己对象,用来存放变量值
    final TensorFlowxxx d = new TensorFlowxxx ();

    //1.读取 camera_classify_map.txt 标签,将文本着行读取存到labels集合中
    InputStream labelsInput = null;
    String actualFilename =
                    labelFilename.split("file:///android_asset/")[1];
    labelsInput = assetManager.open(actualFilename);
    BufferedReader br = null;
    br = new BufferedReader(new InputStreamReader(labelsInput));
    String line;
    while ((line = br.readLine()) != null) {
      d.labels.add(line);
    }
    br.close();

    //2. 创建TensorFlow 接口
    //传入AssetManager 以及 frozen_inference_graph 训练结果文件
    d.inferenceInterface = new
            TensorFlowInferenceInterface(assetManager, modelFilename);

    //获取模型的图
    final Graph g = d.inferenceInterface.graph();

    //输入节点名字
    d.inputName = "inputName";

    d.inputSize = inputSize;

    //输出节点名
    d.outputNames = new String[] {"outputName1","outputName2"};

    //传入bitmap的像素
    d.intValues = new int[d.inputSize * d.inputSize];

    //输入数据,像素转换成RGB
    d.byteValues = new byte[d.inputSize * d.inputSize * 3];

    //初始化4个输出节点结果
    d.outputNameResult1 = new float[MAX_RESULTS];
    d.outputNameResult2 = new float[MAX_RESULTS];
    return d;
  }

create 流程
1. onCreate 一般只调用一次
传入参数分别为:
 AssetManager: 用来读取assets目录下文件
 modelFilename: .pb 文件的目录路径 “file:///android_asset/xxxx.pb”
 labelFilename: .txt 文件的目录路径 “file:///android_asset/xxxx.txt”;
 inputSize: 输入大小
2. 读取label 中的值
3. 创建TensorFlowInferenceInterface 对象
4. 初始化inputname 以及 outputname

注意:inputname outputname 都需要和训练中的名称一致

2. 识别

  1. 预处理bitmap
  2. feed 传入数据
  3. run 处理数据
  4. fetch 取得结果
  5. 概率排序
  6. 返回结果
  @Override
  public List<Recognition> recognizeImage(final Bitmap bitmap) {
    // Log this method so that it can be analyzed with systrace.
    Trace.beginSection("recognizeImage");

    Trace.beginSection("preprocessBitmap");
    // Preprocess the image data from 0-255 int to normalized float based
    // on the provided parameters.
      //预处理 bitmap 的像素
    bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());

    for (int i = 0; i < intValues.length; ++i) {
      byteValues[i * 3 + 2] = (byte) (intValues[i] & 0xFF);    //取低两位         B
      byteValues[i * 3 + 1] = (byte) ((intValues[i] >> 8) & 0xFF);  //取中两位    G
      byteValues[i * 3 + 0] = (byte) ((intValues[i] >> 16) & 0xFF);  // 取高两位  R
    }
    Trace.endSection(); // preprocessBitmap

    //输入数据
    // Copy the input data into TensorFlow.
    Trace.beginSection("feed");
    //节点名字, 节点数据
    //后面参数为 inputName 节点的 shape, [batch_size, height, width, in_channel]
    inferenceInterface.feed(inputName, byteValues, 1, inputSize, inputSize, 3);
    Trace.endSection();

    //进行模型的推理
    //传入输出字符串名称
    // Run the inference call.
    Trace.beginSection("run");
    inferenceInterface.run(outputNames, logStats);
    Trace.endSection();

    //获取输出数据
    // Copy the output Tensor back into the output array.
    Trace.beginSection("fetch");
    //outputLocations = new float[MAX_RESULTS * 4];
    outputNameResult1 = new float[MAX_RESULTS];
    outputNameResult2 = new float[MAX_RESULTS];

    //返回概率 detection_scores  以及 对应 label  detection_classes  一一对应的
    inferenceInterface.fetch(outputNames[1], outputNameResult1 );
    inferenceInterface.fetch(outputNames[2], outputNameResult2 );
    Trace.endSection();

    //将结果放入队列中进行比较排列
    // Find the best detections.
    final PriorityQueue<Recognition> pq =
            new PriorityQueue<Recognition>(
                    1,
                    new Comparator<Recognition>() {
                      @Override
                      public int compare(final Recognition lhs, final Recognition rhs) {
                        // Intentionally reversed to put high confidence at the head of the queue.
                        //使用传入的Recognition 的 confidence值 排序
                        return Float.compare(rhs.getConfidence(), lhs.getConfidence());
                      }
                    });


    // Scale them back to the input size.
    for (int i = 0; i < outputNameResult1.length; ++i) {
        //将数据存入Recognition 类中,并加入队列排序
        pq.add( new Recognition("" + i, labels.get((int) outputNameResult2[i]), outputNameResult1[i], null));
    }

    //获得顶部值存入集合,取第一个数索引为0 为概率最大;
    //由于用了poll方法,此时相当于加入了一半的信息
    final ArrayList<Recognition> recognitions = new ArrayList<Recognition>();
    for (int i = 0; i < Math.min(pq.size(), MAX_RESULTS); ++i) {
        recognitions.add(pq.poll());
    }
    Trace.endSection(); // "recognizeImage"
    return recognitions;
  }

ps:此博客写的是集成的是tensorflow结果,训练过程 以及 android 使用摄像头的使用不在其中
看到新问题会继续补充~~

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值