引言
随着人工智能应用越来越普及,目前app 项目正在集成 以 tensorflow 为基础 训练出的结果,主要为图像的场景识别相关。
由于项目关键代码不方便公布,并且是参照google 提供集成android demo 来进行书写代码的,所以同时可以参考代码:
tensorflow android demo
环境
Android 9.0 平台
AndroidStudio 开发工具
大体思路
- 集成tensorflow jar包
- 添加tensorflow 训练的pb文件以及label.txt 文件
- 获取图像
- 初始化tensorflow接口
- 进行图像识别,返回结果
- 对结果进行处理
流程详解
1. 下载tensorflow jar 依赖包
//build.gradle 文件中添加,并同步
compile 'org.tensorflow:tensorflow-android:1.9.0'
2. assets
需要.pb 以及.txt文件, 一般将2文件存放到 assets 目录中
.pb: 训练后的结果
.txt: 输出的标签
3. 传入一个bitmap
- camera实时预览,可以从onPreviewFrame(camera api 1)中获取data 并将其从YUV 转成 RGB -> bitmap
- 不是实时,可以拍照或者直接图片,将其转换成bitmap
4. 获取图像数据
注册previewcallback 用来获取每一帧数据,此数据是 yuv 需要转换 RGB
1. 注册camera 实时预览 监听
2. 初始化tensorflow 使用的类
3. 将回调data 转换成RGB,并进一步转成Bitmap
4. 传入tensorflow bitmap
5. 获得识别结果
(1) 添加preview callback 回调监听
private void xxxx() {
if (previewBuffer == null) {
int bufferSize = (int) (previewHeight * previewWidth * 1.5);
previewBuffer = new byte[bufferSize];
}
mCamera.setPreviewCallbackWithBuffer(mOnPreviewCallback);
mCamera.addCallbackBuffer(previewBuffer);
}
注意:需要在camera startpreview之前完成此操作
(2) 数据处理
private Camera.PreviewCallback mOnPreviewCallback=new Camera.PreviewCallback() {
@Override
public void onPreviewFrame(final byte[] data, final Camera camera) {
//获得当前preview size 的宽、高
if (size == null) {
size = camera.getParameters().getPreviewSize();
previewWidth = size.width;
previewHeight = size.height;
}
try {
if (rgbBytes == null) {
//初始化RGB 数据大小
rgbBytes = new int[previewWidth * previewHeight];
//初始化 bitmap 数据
initBitmap();
}
//创建TensorFlow 模式,传入变量
if (detector==null) {
detector = TensorFlowXXimpl.create(
mActivity.getAssets(),
"file:///android_asset/xxxx.pb",
"file:///android_asset/xxxx.txt",
300);
}
//通过onOrientationChanged 获取 手机方向
if (degree==0||degree==180) {
sensorOrientation = 90 - getScreenOrientation();
}else {
sensorOrientation=getScreenOrientation();
}
} catch (final Exception e) {
e.printStackTrace();
return;
}
imageConverter =
new Runnable() {
@Override
public void run() {
//将data 转换成 rgbBytes
ImageUtils.convertYUV420SPToARGB8888(data,
previewWidth, previewHeight, rgbBytes);
}
};
//flag 用来判断上一次是否处理完,否则不进行下一次的运算
if (flag) {
//将照相机获取的原始图片,转换为300*300的图片,用来作为模型预测的输入
frameToCropTransform =
ImageUtils.getTransformationMatrix(
previewWidth, previewHeight,
cropSize, cropSize,
sensorOrientation, false);
//矩阵倒置
frameToCropTransform.invert(cropToFrameTransform);
singleThreadExecutor.execute(workRunnable);
}
//必须添加,否则onPreviewFrame 没有下一次的回调
if (mCameraDevice != null) {
mCameraDevice.getCamera().addCallbackBuffer(data);
}
}
};
注意
1. 结束后,必须添加addCallbackBuffer,否则没有下一次回调
2. 需要对图像进行旋转处理,否则不好识别
3. ImageUtils 的convertYUV420SPToARGB8888 方法可以见demo
(3) 初始化Bitmap
private void initBitmap(){
//真实宽高的bitmap
rgbBitmap = Bitmap.createBitmap(previewWidth, previewHeight,
Bitmap.Config.ARGB_8888);
//输入 宽高 为300的 bitmap
inputBitmap = Bitmap.createBitmap(cropSize, cropSize,
Bitmap.Config.ARGB_8888);
//用来矩阵转换
cropToFrameTransform = new Matrix();
}
(4) 获得屏幕的方向
protected int getScreenOrientation() {
switch
(mActivity.getWindowManager().
getDefaultDisplay().getRotation()) {
case ROTATION_270:
return 270;
case ROTATION_180:
return 180;
case ROTATION_90:
return 90;
default:
return 0;
}
}
监听方向
@Override
public void onOrientationChanged(OrientationManager orientationM,
OrientationManager.DeviceOrientation orientation) {
super.onOrientationChanged(orientationManager, orientation);
degree=orientation.getDegrees();
}
通过实现 OrientationManager.OnOrientationChangeListener
(5) 创建线程池
private ExecutorService singleThreadExecutor =
Executors.newSingleThreadExecutor();
//在线程中进行图像的识别操作
Runnable workRunnable=new Runnable() {
@Override
public void run() {
flag = false;
processImage();
flag = true;
}
};
(6) 开始识别
private void processImage() {
//获得RGB 数据,设置到bitmap 中
rgbBitmap.setPixels(getRgbBytes(), 0, previewWidth, 0, 0, previewWidth, previewHeight);
// 通过canvas 将rgbBitmap 通过frameToCropTransform 绘制到inputBitmap 中,类似裁剪成300 * 300
final Canvas canvas = new Canvas(croppedBitmap);
canvas.drawBitmap(rgbFrameBitmap, frameToCropTransform, null);
// 传入需要识别的 bitmap 进行识别,并返回结果
if (detector!=null)
results = detector.recognizeImage(croppedBitmap);
//获得label 以及 概率
if (results != null && results.size() > 0) {
type = results.get(0).getTitle();
score = results.get(0).getConfidence();
//此处获取的识别后概率最大的场景
//dosomething
}
}
tensorflow 流程
1. create 初始化
public static Classifier create(
final AssetManager assetManager,
final String modelFilename,
final String labelFilename,
final int inputSize) throws IOException {
//创建自己对象,用来存放变量值
final TensorFlowxxx d = new TensorFlowxxx ();
//1.读取 camera_classify_map.txt 标签,将文本着行读取存到labels集合中
InputStream labelsInput = null;
String actualFilename =
labelFilename.split("file:///android_asset/")[1];
labelsInput = assetManager.open(actualFilename);
BufferedReader br = null;
br = new BufferedReader(new InputStreamReader(labelsInput));
String line;
while ((line = br.readLine()) != null) {
d.labels.add(line);
}
br.close();
//2. 创建TensorFlow 接口
//传入AssetManager 以及 frozen_inference_graph 训练结果文件
d.inferenceInterface = new
TensorFlowInferenceInterface(assetManager, modelFilename);
//获取模型的图
final Graph g = d.inferenceInterface.graph();
//输入节点名字
d.inputName = "inputName";
d.inputSize = inputSize;
//输出节点名
d.outputNames = new String[] {"outputName1","outputName2"};
//传入bitmap的像素
d.intValues = new int[d.inputSize * d.inputSize];
//输入数据,像素转换成RGB
d.byteValues = new byte[d.inputSize * d.inputSize * 3];
//初始化4个输出节点结果
d.outputNameResult1 = new float[MAX_RESULTS];
d.outputNameResult2 = new float[MAX_RESULTS];
return d;
}
create 流程
1. onCreate 一般只调用一次
传入参数分别为:
AssetManager: 用来读取assets目录下文件
modelFilename: .pb 文件的目录路径 “file:///android_asset/xxxx.pb”
labelFilename: .txt 文件的目录路径 “file:///android_asset/xxxx.txt”;
inputSize: 输入大小
2. 读取label 中的值
3. 创建TensorFlowInferenceInterface 对象
4. 初始化inputname 以及 outputname
注意:inputname outputname 都需要和训练中的名称一致
2. 识别
- 预处理bitmap
- feed 传入数据
- run 处理数据
- fetch 取得结果
- 概率排序
- 返回结果
@Override
public List<Recognition> recognizeImage(final Bitmap bitmap) {
// Log this method so that it can be analyzed with systrace.
Trace.beginSection("recognizeImage");
Trace.beginSection("preprocessBitmap");
// Preprocess the image data from 0-255 int to normalized float based
// on the provided parameters.
//预处理 bitmap 的像素
bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
for (int i = 0; i < intValues.length; ++i) {
byteValues[i * 3 + 2] = (byte) (intValues[i] & 0xFF); //取低两位 B
byteValues[i * 3 + 1] = (byte) ((intValues[i] >> 8) & 0xFF); //取中两位 G
byteValues[i * 3 + 0] = (byte) ((intValues[i] >> 16) & 0xFF); // 取高两位 R
}
Trace.endSection(); // preprocessBitmap
//输入数据
// Copy the input data into TensorFlow.
Trace.beginSection("feed");
//节点名字, 节点数据
//后面参数为 inputName 节点的 shape, [batch_size, height, width, in_channel]
inferenceInterface.feed(inputName, byteValues, 1, inputSize, inputSize, 3);
Trace.endSection();
//进行模型的推理
//传入输出字符串名称
// Run the inference call.
Trace.beginSection("run");
inferenceInterface.run(outputNames, logStats);
Trace.endSection();
//获取输出数据
// Copy the output Tensor back into the output array.
Trace.beginSection("fetch");
//outputLocations = new float[MAX_RESULTS * 4];
outputNameResult1 = new float[MAX_RESULTS];
outputNameResult2 = new float[MAX_RESULTS];
//返回概率 detection_scores 以及 对应 label detection_classes 一一对应的
inferenceInterface.fetch(outputNames[1], outputNameResult1 );
inferenceInterface.fetch(outputNames[2], outputNameResult2 );
Trace.endSection();
//将结果放入队列中进行比较排列
// Find the best detections.
final PriorityQueue<Recognition> pq =
new PriorityQueue<Recognition>(
1,
new Comparator<Recognition>() {
@Override
public int compare(final Recognition lhs, final Recognition rhs) {
// Intentionally reversed to put high confidence at the head of the queue.
//使用传入的Recognition 的 confidence值 排序
return Float.compare(rhs.getConfidence(), lhs.getConfidence());
}
});
// Scale them back to the input size.
for (int i = 0; i < outputNameResult1.length; ++i) {
//将数据存入Recognition 类中,并加入队列排序
pq.add( new Recognition("" + i, labels.get((int) outputNameResult2[i]), outputNameResult1[i], null));
}
//获得顶部值存入集合,取第一个数索引为0 为概率最大;
//由于用了poll方法,此时相当于加入了一半的信息
final ArrayList<Recognition> recognitions = new ArrayList<Recognition>();
for (int i = 0; i < Math.min(pq.size(), MAX_RESULTS); ++i) {
recognitions.add(pq.poll());
}
Trace.endSection(); // "recognizeImage"
return recognitions;
}
ps:此博客写的是集成的是tensorflow结果,训练过程 以及 android 使用摄像头的使用不在其中
看到新问题会继续补充~~