成果图
点击 BUTTON_SELECT_IMAGE
选择完照片后,显示图片和预测结果
注意:
-
本案例是基于tensorflowlite的案例 image classification 进行删减和修改,使其代码更易看懂。
-
所使用的.tflite文件和labels.txt也是官方案例提供的文件,自己得先下载下来。(我不知道怎么附文件,如果需要,自己从官方案例下载,或者留言评论,我发给你)
-
下载的 .tflite文件 和 labels.txt 放在asset目录下(没有的话,自己创建)
-
应用运行在安卓 9.0 的手机上
代码
build.gradle
// Module-level build.gradle dependencies.
// The three TensorFlow Lite artifacts are pinned to the 0.0.0-nightly snapshot;
// `changing = true` tells Gradle to re-check the remote repository for updated
// snapshots instead of caching the artifact forever.
dependencies {
// Core TFLite interpreter runtime (org.tensorflow.lite.Interpreter).
implementation('org.tensorflow:tensorflow-lite:0.0.0-nightly') { changing = true }
// GPU delegate (imported by Classifier.java; optional acceleration).
implementation('org.tensorflow:tensorflow-lite-gpu:0.0.0-nightly') { changing = true }
// Support library: FileUtil, TensorImage, ImageProcessor, TensorLabel, etc.
implementation('org.tensorflow:tensorflow-lite-support:0.0.0-nightly') { changing = true }
implementation 'androidx.appcompat:appcompat:1.1.0'
implementation 'com.google.android.material:material:1.1.0'
implementation 'androidx.constraintlayout:constraintlayout:1.1.3'
testImplementation 'junit:junit:4.+'
androidTestImplementation 'androidx.test.ext:junit:1.1.1'
androidTestImplementation 'androidx.test.espresso:espresso-core:3.2.0'
}
java文件:
Classifier.java
package com.example.anothertfdemo;
import android.app.Activity;
import android.graphics.Bitmap;
import android.graphics.RectF;
import android.os.SystemClock;
import android.os.Trace;
import android.util.Log;
import org.tensorflow.lite.DataType;
import org.tensorflow.lite.Interpreter;
import org.tensorflow.lite.gpu.GpuDelegate;
import org.tensorflow.lite.nnapi.NnApiDelegate;
import org.tensorflow.lite.support.common.FileUtil;
import org.tensorflow.lite.support.common.TensorOperator;
import org.tensorflow.lite.support.common.TensorProcessor;
import org.tensorflow.lite.support.image.ImageProcessor;
import org.tensorflow.lite.support.image.TensorImage;
import org.tensorflow.lite.support.image.ops.ResizeOp;
import org.tensorflow.lite.support.image.ops.ResizeWithCropOrPadOp;
import org.tensorflow.lite.support.image.ops.Rot90Op;
import org.tensorflow.lite.support.label.TensorLabel;
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
import java.io.IOException;
import java.nio.MappedByteBuffer;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import static java.lang.Math.min;
/**
 * Generic image classifier backed by a TensorFlow Lite model stored in assets.
 *
 * <p>Subclasses supply the model/label file names and the pre/post-processing
 * normalization ops; see {@code ClassifierQuantizedMobileNet}.
 */
public abstract class Classifier {

    public static final String TAG = "ClassifierWithSupport";

    /** Number of results to show in the UI. */
    private static final int MAX_RESULTS = 3;

    /** Image size along the x axis, read from the model's input tensor shape. */
    private final int imageSizeX;

    /** Image size along the y axis, read from the model's input tensor shape. */
    private final int imageSizeY;

    /** Driver class to run model inference with TensorFlow Lite; null after {@link #close()}. */
    protected Interpreter tflite;

    /** Options for configuring the Interpreter. */
    private final Interpreter.Options tfliteOptions = new Interpreter.Options();

    /** Labels corresponding to the output of the vision model. */
    private final List<String> labels;

    /** Input image TensorBuffer; reloaded on every {@link #recognizeImage} call. */
    private TensorImage inputImageBuffer;

    /** Output probability TensorBuffer, sized {1, NUM_CLASSES}. */
    private final TensorBuffer outputProbabilityBuffer;

    /** Processor to apply post-processing (dequantization) to the output probability. */
    private final TensorProcessor probabilityProcessor;

    /**
     * Factory method; currently always creates the quantized MobileNet variant.
     *
     * @param activity used to access the app's assets
     * @throws IOException if the model or label file cannot be loaded from assets
     */
    public static Classifier create(Activity activity) throws IOException {
        return new ClassifierQuantizedMobileNet(activity);
    }

    /**
     * Loads the model and labels from assets and prepares the input/output buffers.
     *
     * @param activity used to access the app's assets
     * @throws IOException if the model or label file cannot be loaded
     */
    protected Classifier(Activity activity) throws IOException {
        MappedByteBuffer tfliteModel = FileUtil.loadMappedFile(activity, getModelPath());
        tfliteOptions.setUseXNNPACK(true);
        tfliteOptions.setNumThreads(1);
        tflite = new Interpreter(tfliteModel, tfliteOptions);

        // Loads labels out from the label file.
        labels = FileUtil.loadLabels(activity, getLabelPath());

        // Reads type and shape of input and output tensors, respectively.
        int imageTensorIndex = 0;
        int[] imageShape = tflite.getInputTensor(imageTensorIndex).shape(); // {1, height, width, 3}
        imageSizeY = imageShape[1];
        imageSizeX = imageShape[2];
        // Use the class logger rather than System.out so output is filterable by TAG.
        Log.d(TAG, "imageSizeX: " + imageSizeX + ", imageSizeY: " + imageSizeY);
        DataType imageDataType = tflite.getInputTensor(imageTensorIndex).dataType();

        int probabilityTensorIndex = 0;
        int[] probabilityShape =
                tflite.getOutputTensor(probabilityTensorIndex).shape(); // {1, NUM_CLASSES}
        DataType probabilityDataType = tflite.getOutputTensor(probabilityTensorIndex).dataType();

        // Creates the input tensor.
        inputImageBuffer = new TensorImage(imageDataType);
        // Creates the output tensor and its processor.
        outputProbabilityBuffer = TensorBuffer.createFixedSize(probabilityShape, probabilityDataType);
        // Creates the post processor for the output probability.
        probabilityProcessor = new TensorProcessor.Builder().add(getPostprocessNormalizeOp()).build();
        Log.d(TAG, "Created a Tensorflow Lite Image Classifier.");
    }

    /**
     * Releases the native resources held by the interpreter. The classifier
     * must not be used after calling this. Safe to call more than once.
     */
    public void close() {
        if (tflite != null) {
            tflite.close();
            tflite = null;
        }
    }

    /** Gets the name of the model file stored in Assets. */
    protected abstract String getModelPath();

    /** Gets the name of the label file stored in Assets. */
    protected abstract String getLabelPath();

    /** Gets the TensorOperator to normalize the input image in preprocessing. */
    protected abstract TensorOperator getPreprocessNormalizeOp();

    /**
     * Gets the TensorOperator to dequantize the output probability in post processing.
     *
     * <p>For quantized model, we need de-quantize the prediction with NormalizeOp (as they are all
     * essentially linear transformation). For float model, de-quantize is not required. But to
     * uniform the API, de-quantize is added to float model too. Mean and std are set to 0.0f and
     * 1.0f, respectively.
     */
    protected abstract TensorOperator getPostprocessNormalizeOp();

    /**
     * Loads the input image into the tensor buffer and applies preprocessing:
     * center-crop to a square, resize to the model input size, rotate by the
     * sensor orientation, then normalize.
     */
    private TensorImage loadImage(final Bitmap bitmap, int sensorOrientation) {
        // Loads bitmap into a TensorImage.
        inputImageBuffer.load(bitmap);

        // Creates processor for the TensorImage.
        int cropSize = min(bitmap.getWidth(), bitmap.getHeight());
        int numRotation = sensorOrientation / 90;
        // TODO(b/143564309): Fuse ops inside ImageProcessor.
        ImageProcessor imageProcessor =
                new ImageProcessor.Builder()
                        .add(new ResizeWithCropOrPadOp(cropSize, cropSize))
                        // TODO(b/169379396): investigate the impact of the resize algorithm on accuracy.
                        // To get the same inference results as lib_task_api, which is built on top of the
                        // Task Library, use ResizeMethod.BILINEAR.
                        .add(new ResizeOp(imageSizeX, imageSizeY, ResizeOp.ResizeMethod.NEAREST_NEIGHBOR))
                        .add(new Rot90Op(numRotation))
                        .add(getPreprocessNormalizeOp())
                        .build();
        return imageProcessor.process(inputImageBuffer);
    }

    /** An immutable result returned by a Classifier describing what was recognized. */
    public static class Recognition {
        /**
         * A unique identifier for what has been recognized. Specific to the class, not the instance
         * of the object.
         */
        private final String id;

        /** Display name for the recognition. */
        private final String title;

        /**
         * A sortable score for how good the recognition is relative to others. Higher should be
         * better.
         */
        private final Float confidence;

        public Recognition(final String id, final String title, final Float confidence) {
            this.id = id;
            this.title = title;
            this.confidence = confidence;
        }

        public String getId() {
            return id;
        }

        public String getTitle() {
            return title;
        }

        public Float getConfidence() {
            return confidence;
        }

        @Override
        public String toString() {
            String resultString = "";
            if (id != null) {
                resultString += "[" + id + "] ";
            }
            if (title != null) {
                resultString += title + " ";
            }
            if (confidence != null) {
                resultString += String.format("(%.1f%%) ", confidence * 100.0f);
            }
            return resultString.trim();
        }
    }

    /** Gets the top-k (k = {@link #MAX_RESULTS}) results, sorted by descending confidence. */
    private static List<Recognition> getTopKProbability(Map<String, Float> labelProb) {
        // Find the best classifications.
        PriorityQueue<Recognition> pq =
                new PriorityQueue<>(
                        MAX_RESULTS,
                        new Comparator<Recognition>() {
                            @Override
                            public int compare(Recognition lhs, Recognition rhs) {
                                // Intentionally reversed to put high confidence at the head of the queue.
                                return Float.compare(rhs.getConfidence(), lhs.getConfidence());
                            }
                        });
        for (Map.Entry<String, Float> entry : labelProb.entrySet()) {
            pq.add(new Recognition("" + entry.getKey(), entry.getKey(), entry.getValue()));
        }

        final ArrayList<Recognition> recognitions = new ArrayList<>();
        int recognitionsSize = min(pq.size(), MAX_RESULTS);
        for (int i = 0; i < recognitionsSize; ++i) {
            recognitions.add(pq.poll());
        }
        return recognitions;
    }

    /**
     * Runs inference on the given bitmap and returns the top results.
     *
     * @param bitmap the image to classify
     * @param sensorOrientation clockwise rotation in degrees (multiple of 90) to undo
     * @return up to {@link #MAX_RESULTS} recognitions, best first
     */
    public List<Recognition> recognizeImage(final Bitmap bitmap, int sensorOrientation) {
        // Logs this method so that it can be analyzed with systrace.
        Trace.beginSection("recognizeImage");

        Trace.beginSection("loadImage");
        long startTimeForLoadImage = SystemClock.uptimeMillis();
        inputImageBuffer = loadImage(bitmap, sensorOrientation);
        long endTimeForLoadImage = SystemClock.uptimeMillis();
        Trace.endSection();
        Log.v(TAG, "Timecost to load the image: " + (endTimeForLoadImage - startTimeForLoadImage));

        // Runs the inference call.
        Trace.beginSection("runInference");
        long startTimeForReference = SystemClock.uptimeMillis();
        tflite.run(inputImageBuffer.getBuffer(), outputProbabilityBuffer.getBuffer().rewind());
        long endTimeForReference = SystemClock.uptimeMillis();
        Trace.endSection();
        Log.v(TAG, "Timecost to run model inference: " + (endTimeForReference - startTimeForReference));

        // Gets the map of label and probability.
        Map<String, Float> labeledProbability =
                new TensorLabel(labels, probabilityProcessor.process(outputProbabilityBuffer))
                        .getMapWithFloatValue();
        Trace.endSection();

        // Gets top-k results.
        return getTopKProbability(labeledProbability);
    }
}
ClassifierQuantizedMobileNet.java
package com.example.anothertfdemo;
import android.app.Activity;
import org.tensorflow.lite.support.common.TensorOperator;
import org.tensorflow.lite.support.common.ops.NormalizeOp;
import java.io.IOException;
/**
 * Concrete {@link Classifier} for the quantized MobileNet V1 model
 * (mobilenet_v1_1.0_224_quant.tflite) shipped in the app's assets.
 */
public class ClassifierQuantizedMobileNet extends Classifier {

    /** Model file name in the assets directory. */
    private static final String MODEL_PATH = "mobilenet_v1_1.0_224_quant.tflite";

    /** Label file name in the assets directory. */
    private static final String LABEL_PATH = "labels.txt";

    /**
     * The quantized model does not require normalization; a mean of 0.0f and a
     * std of 1.0f make the preprocessing NormalizeOp an identity transform.
     */
    private static final float IMAGE_MEAN = 0.0f;
    private static final float IMAGE_STD = 1.0f;

    /**
     * Quantized MobileNet emits uint8 scores, so the output is dequantized by
     * dividing by 255 (mean 0.0f, std 255.0f).
     */
    private static final float PROBABILITY_MEAN = 0.0f;
    private static final float PROBABILITY_STD = 255.0f;

    /**
     * Initializes a {@code ClassifierQuantizedMobileNet}.
     *
     * @param activity used to access the app's assets
     * @throws IOException if the model or label file cannot be loaded
     */
    public ClassifierQuantizedMobileNet(Activity activity) throws IOException {
        super(activity);
    }

    @Override
    protected String getModelPath() {
        // Download this file from the official TFLite image-classification
        // example and place it in app/src/main/assets.
        return MODEL_PATH;
    }

    @Override
    protected String getLabelPath() {
        return LABEL_PATH;
    }

    @Override
    protected TensorOperator getPreprocessNormalizeOp() {
        return new NormalizeOp(IMAGE_MEAN, IMAGE_STD);
    }

    @Override
    protected TensorOperator getPostprocessNormalizeOp() {
        return new NormalizeOp(PROBABILITY_MEAN, PROBABILITY_STD);
    }
}
MainActivity.java
package com.example.anothertfdemo;
import androidx.appcompat.app.AppCompatActivity;
import android.content.Context;
import android.content.Intent;
import android.graphics.Bitmap;
import android.graphics.BitmapFactory;
import android.graphics.drawable.BitmapDrawable;
import android.net.Uri;
import android.os.Bundle;
import android.util.Log;
import android.view.View;
import android.widget.ImageView;
import android.widget.TextView;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.util.List;
/**
 * Entry activity: lets the user pick an image with the system picker, shows it,
 * and displays the top prediction from the TensorFlow Lite classifier.
 */
public class MainActivity extends AppCompatActivity {

    private static final String TAG = "MainActivity";

    ImageView selectedImageView;
    TextView answer;
    Classifier classifier; // null if model/label assets failed to load
    final int SELECT_PICTURE = 1;

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);
        selectedImageView = findViewById(R.id.selectedImageView);
        answer = findViewById(R.id.answer);
        try {
            classifier = Classifier.create(this);
        } catch (IOException e) {
            // Model or labels missing from assets — log properly instead of
            // printStackTrace(), and keep classifier null so we can guard later.
            Log.e(TAG, "Failed to create the TFLite classifier", e);
        }
    }

    /** Opens the system image picker; wired to the button via android:onClick. */
    public void onSelectImageButtonClicked(View view) {
        Intent intent = new Intent(Intent.ACTION_GET_CONTENT);
        intent.setType("image/*");
        Intent chooser = Intent.createChooser(intent, "Choose a Picture");
        startActivityForResult(chooser, SELECT_PICTURE);
    }

    /**
     * Receives the result of the image picker: shows the selected image and
     * runs classification, writing the best label into the answer TextView.
     */
    @Override
    public void onActivityResult(int reqCode, int resultCode, Intent data) {
        super.onActivityResult(reqCode, resultCode, data);
        if (resultCode != RESULT_OK || reqCode != SELECT_PICTURE) {
            return;
        }
        // data / getData() can be null if the picker returned without a selection.
        Uri selectedUri = (data == null) ? null : data.getData();
        if (selectedUri == null) {
            return;
        }
        selectedImageView.setImageURI(selectedUri);
        if (classifier == null) {
            // Initialization failed in onCreate(); nothing to predict with.
            Log.w(TAG, "Classifier not available; skipping recognition");
            return;
        }
        // Guard the cast: setImageURI may produce a non-bitmap drawable.
        if (!(selectedImageView.getDrawable() instanceof BitmapDrawable)) {
            return;
        }
        Bitmap bitmap = ((BitmapDrawable) selectedImageView.getDrawable()).getBitmap();
        // Orientation 0: gallery images are assumed already upright.
        List<Classifier.Recognition> result = classifier.recognizeImage(bitmap, 0);
        if (!result.isEmpty()) {
            answer.setText(result.get(0).getTitle());
        }
    }
}
布局文件:
activity_main.xml
(我自己放了一张初始图片)
<?xml version="1.0" encoding="utf-8"?>
<!-- Main screen: selected image on top, pick button in the middle,
     prediction text at the bottom, chained vertically. -->
<androidx.constraintlayout.widget.ConstraintLayout xmlns:android="http://schemas.android.com/apk/res/android"
    xmlns:app="http://schemas.android.com/apk/res-auto"
    xmlns:tools="http://schemas.android.com/tools"
    android:layout_width="match_parent"
    android:layout_height="match_parent"
    tools:context=".MainActivity">

    <!-- Shows the placeholder drawable until the user picks an image. -->
    <ImageView
        android:id="@+id/selectedImageView"
        android:layout_width="match_parent"
        android:layout_height="300dp"
        android:src="@drawable/dog"
        app:layout_constraintEnd_toEndOf="parent"
        app:layout_constraintStart_toStartOf="parent"
        app:layout_constraintTop_toTopOf="parent"
        app:layout_constraintBottom_toTopOf="@+id/selectImageButton"
        />

    <!-- onClick is resolved reflectively to MainActivity.onSelectImageButtonClicked. -->
    <Button
        android:id="@+id/selectImageButton"
        android:layout_width="match_parent"
        android:layout_height="wrap_content"
        android:onClick="onSelectImageButtonClicked"
        android:text="button_select_image"
        app:layout_constraintEnd_toEndOf="parent"
        app:layout_constraintStart_toStartOf="parent"
        app:layout_constraintTop_toBottomOf="@+id/selectedImageView"
        app:layout_constraintBottom_toTopOf="@+id/answer"
        />

    <!-- Displays the top prediction. Text sizes must use sp (not dp) so they
         respect the user's system font-size setting. -->
    <TextView
        android:id="@+id/answer"
        android:layout_width="match_parent"
        android:layout_height="wrap_content"
        android:text="What's this?"
        android:textAlignment="center"
        android:textSize="30sp"
        app:layout_constraintBottom_toBottomOf="parent"
        app:layout_constraintEnd_toEndOf="parent"
        app:layout_constraintStart_toStartOf="parent"
        app:layout_constraintTop_toBottomOf="@+id/selectImageButton" />
</androidx.constraintlayout.widget.ConstraintLayout>