1.首先把你的YOLO模型转换为TorchScript格式。
2.然后把模型文件放到你的安卓项目的资源目录下【同时需要添加标签文件(如下图)】
3.添加依赖库【Gradle app文件下】
implementation 'org.pytorch:pytorch_android_lite:1.9.0'
implementation 'org.pytorch:pytorch_android_torchvision_lite:1.9.0'
4.编写加载模型的代码类【不详细解释-直接放代码】
4.1【推理】
// Pre/post-processing helpers for the detection model: static configuration
// plus the IOU / NMS / output-decoding routines that follow.
public class PrePostProcessor {
// Per-channel mean passed to the bitmap-to-tensor conversion (all zeros, i.e. no mean shift).
public static float[] NO_MEAN_RGB = new float[] {0.0f, 0.0f, 0.0f};
// Per-channel std passed to the bitmap-to-tensor conversion (all ones, i.e. no scaling).
public static float[] NO_STD_RGB = new float[] {1.0f, 1.0f, 1.0f};
// Width and height the input bitmap is resized to before inference.
public static int mInputWidth = 320;
public static int mInputHeight = 320;
// Number of prediction rows read from the flat model output.
// NOTE(review): presumably (40*40 + 20*20 + 10*10) * 3 for a 320x320 YOLOv5 export - confirm against the model.
private static final int mOutputRow = 6300;
// Values per prediction row: x, y, w, h, score, followed by the per-class scores.
private static final int mOutputColumn = 9;
// Used both as the minimum score for a row to be kept and as the IOU threshold handed to NMS.
private static final float mThreshold = 0.35f;
// Maximum number of boxes returned by NMS.
private static final int mNmsLimit = 5;
// Class labels, indexed by the predicted class index; must be assigned by the caller before use.
public static String[] mClasses;
/**
 * Computes the intersection-over-union of two rectangles.
 *
 * @param a first rectangle
 * @param b second rectangle
 * @return overlap area divided by union area, or 0 when either rectangle
 *         has a non-positive area
 */
static float IOU(Rect a, Rect b){
    // A rectangle without positive area cannot overlap anything.
    float firstArea = (a.right - a.left) * (a.bottom - a.top);
    if (firstArea <= 0.0) {
        return 0.0f;
    }
    float secondArea = (b.right - b.left) * (b.bottom - b.top);
    if (secondArea <= 0.0) {
        return 0.0f;
    }

    // Edges of the overlapping region (may be inverted when the boxes are disjoint).
    float overlapLeft = Math.max(a.left, b.left);
    float overlapRight = Math.min(a.right, b.right);
    float overlapTop = Math.max(a.top, b.top);
    float overlapBottom = Math.min(a.bottom, b.bottom);

    // Clamp each extent at zero so disjoint boxes yield an empty intersection.
    float overlapHeight = Math.max(overlapBottom - overlapTop, 0);
    float overlapWidth = Math.max(overlapRight - overlapLeft, 0);
    float overlapArea = overlapHeight * overlapWidth;

    float unionArea = firstArea + secondArea - overlapArea;
    return overlapArea / unionArea;
}
/**
 * Greedy non-maximum suppression.
 *
 * <p>Boxes are visited from the highest score to the lowest. Each box that has
 * not been suppressed is kept, and every later box whose IOU with it (computed
 * on {@code raw_rect}) exceeds {@code threshold} is suppressed.
 *
 * <p>Note: {@code boxes} is sorted in place.
 *
 * @param boxes     candidate detections; reordered by descending score
 * @param limit     maximum number of boxes to return; values {@code <= 0} yield an empty list
 * @param threshold IOU above which a lower-scored box is discarded
 * @return the kept boxes, highest score first
 */
static ArrayList<ResultCAR> nonMaxSuppression(ArrayList<ResultCAR> boxes, int limit, float threshold){
    // Descending by score: the most confident box of each overlapping cluster
    // must be visited first so that it is the one that survives.
    Collections.sort(boxes,
            new Comparator<ResultCAR>(){
                @Override
                public int compare(ResultCAR o1, ResultCAR o2){
                    return o2.score.compareTo(o1.score);
                }
            });

    ArrayList<ResultCAR> selected = new ArrayList<>();
    if (limit <= 0) {
        return selected;
    }

    boolean[] active = new boolean[boxes.size()];
    Arrays.fill(active, true);

    for (int i = 0; i < boxes.size(); i++){
        if (!active[i]) {
            continue;
        }
        ResultCAR boxA = boxes.get(i);
        selected.add(boxA);
        if (selected.size() >= limit) {
            break;
        }
        // Suppress every remaining lower-scored box that overlaps boxA too much.
        for (int j = i + 1; j < boxes.size(); j++){
            if (active[j] && IOU(boxA.raw_rect, boxes.get(j).raw_rect) > threshold){
                active[j] = false;
            }
        }
    }
    return selected;
}
/**
 * Decodes the flat model output into detections and runs non-maximum suppression.
 *
 * <p>Each prediction row holds {@code mOutputColumn} floats:
 * centre x, centre y, width, height, score, then one score per class.
 * Rows whose score does not exceed {@code mThreshold} are dropped.
 *
 * @param outputs   flat model output, row-major
 * @param imgScaleX factor mapping model-input x coordinates to the original image
 * @param imgScaleY factor mapping model-input y coordinates to the original image
 * @param ivScaleX  factor mapping original-image x coordinates to the display area
 * @param ivScaleY  factor mapping original-image y coordinates to the display area
 * @param startX    x offset added to the display-area rectangle
 * @param startY    y offset added to the display-area rectangle
 * @return detections kept by NMS (at most {@code mNmsLimit})
 */
public static ArrayList<ResultCAR> outputsToNMSPredictions(float[] outputs, float imgScaleX, float imgScaleY, float ivScaleX, float ivScaleY,float startX, float startY){
    ArrayList<ResultCAR> results = new ArrayList<>();
    // Never read past the end of the buffer if the model emits fewer rows than expected.
    int rows = Math.min(mOutputRow, outputs.length / mOutputColumn);
    for (int i = 0; i < rows; i++){
        int base = i * mOutputColumn;
        float score = outputs[base + 4];
        if (score > mThreshold){
            float x = outputs[base];
            float y = outputs[base + 1];
            float w = outputs[base + 2];
            float h = outputs[base + 3];

            // Centre/size -> corner coordinates, scaled to the original image.
            float left = imgScaleX * (x - w / 2);
            float top = imgScaleY * (y - h / 2);
            float right = imgScaleX * (x + w / 2);
            float bottom = imgScaleY * (y + h / 2);

            // Arg-max over the class scores stored after the first five values.
            float max = outputs[base + 5];
            int cls = 0;
            for (int j = 1; j < mOutputColumn - 5; j++){
                float candidate = outputs[base + 5 + j];
                if (candidate > max){
                    max = candidate;
                    cls = j;
                }
            }

            // Rectangle mapped into the display area.
            Rect rect = new Rect((int) (startX + ivScaleX * left), (int) (startY + top * ivScaleY),
                    (int) (startX + ivScaleX * right), (int) (startY + ivScaleY * bottom));
            // Rectangle in original-image pixels; NMS and cropping read this one.
            Rect rawRect = new Rect((int) left, (int) top, (int) right, (int) bottom);

            results.add(new ResultCAR(cls, score, rect, rawRect));
        }
    }
    return nonMaxSuppression(results, mNmsLimit, mThreshold);
}
4.2【获取推理结果】
// Scale factors and offsets computed by runCAR on every call and passed on to the output decoder.
private static float mImgScaleX, mImgScaleY, mIvScaleX, mIvScaleY, mStartX, mStartY;
// Last annotated crop produced by runCAR when image saving is enabled; null until then.
public static Bitmap resimg = null;
/**
 * Runs the detection model on a bitmap and returns the label of a detected class.
 *
 * <p>The bitmap is resized to the model input size, converted to a tensor and fed
 * to the module; the output is decoded and filtered by NMS. When
 * {@code isSaveImage} is true, every detection is cropped from the original
 * bitmap, annotated with its label, stored in {@code resimg} and saved to the
 * gallery.
 *
 * <p>This method overwrites the static scale/offset fields on each call.
 *
 * @param mBitmap       source image
 * @param mModuleCarTag loaded TorchScript module
 * @param context       context used when saving crops to the gallery
 * @param isSaveImage   whether to crop, annotate and save each detection
 * @return the label with the highest class index among the detections,
 *         or {@code null} when nothing was detected
 */
public static String runCAR(Bitmap mBitmap, Module mModuleCarTag , Context context, boolean isSaveImage) {
    final int bitmapWidth = mBitmap.getWidth();
    final int bitmapHeight = mBitmap.getHeight();

    mImgScaleX = (float) bitmapWidth / PrePostProcessor.mInputWidth;
    mImgScaleY = (float) bitmapHeight / PrePostProcessor.mInputHeight;
    // Both axes use the reciprocal of the longer side.
    mIvScaleX = (float) 1 / Math.max(bitmapWidth, bitmapHeight);
    mIvScaleY = mIvScaleX;
    mStartX = (1 - mIvScaleX * bitmapWidth) / 2;
    mStartY = (1 - mIvScaleY * bitmapHeight) / 2;

    // Resize the bitmap to the model input size.
    Bitmap resizedBitmap = Bitmap.createScaledBitmap(mBitmap, PrePostProcessor.mInputWidth, PrePostProcessor.mInputHeight, true);
    // Bitmap -> Tensor
    final Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(resizedBitmap, PrePostProcessor.NO_MEAN_RGB, PrePostProcessor.NO_STD_RGB);
    IValue[] outputTuple = mModuleCarTag.forward(IValue.from(inputTensor)).toTuple();
    final Tensor outputTensor = outputTuple[0].toTensor();
    final float[] outputs = outputTensor.getDataAsFloatArray();
    // Decode the raw output and apply non-maximum suppression.
    final ArrayList<ResultCAR> results = PrePostProcessor.outputsToNMSPredictions(outputs, mImgScaleX, mImgScaleY, mIvScaleX, mIvScaleY, mStartX, mStartY);

    Set<Integer> set = new HashSet<>();
    for (int i = 0; i < results.size(); i++) {
        ResultCAR result = results.get(i);
        set.add(result.classIndex);
        Log.e("置信度:", result.score + "");
        if (isSaveImage) {
            // Clamp the box to the bitmap: Bitmap.createBitmap throws
            // IllegalArgumentException when the region leaves the source image.
            Rect rect = result.raw_rect;
            int left = Math.max(0, rect.left);
            int top = Math.max(0, rect.top);
            int right = Math.min(bitmapWidth, rect.right);
            int bottom = Math.min(bitmapHeight, rect.bottom);
            if (right > left && bottom > top) {
                Bitmap corpBitmap = Bitmap.createBitmap(mBitmap, left, top, right - left, bottom - top);
                Mat mat = new Mat();
                Utils.bitmapToMat(corpBitmap, mat);
                Imgproc.putText(mat, PrePostProcessor.mClasses[result.classIndex], new org.opencv.core.Point(10, 10), 1, 1, new org.opencv.core.Scalar(0, 0, 255), 1);
                corpBitmap = ImageUtils.mat2Bitmap(mat);
                resimg = corpBitmap;
                SaveBitmap.saveImageToGallery(context, corpBitmap);
            }
        }
    }

    // Log every distinct class in ascending index order; the last one is returned.
    List<Integer> list = new ArrayList<>(set);
    Collections.sort(list);
    String resulthld = null;
    for (int i = 0; i < list.size(); i++) {
        String label = PrePostProcessor.mClasses[list.get(i)];
        Log.e("识别结果", label);
        resulthld = label;
    }
    return resulthld;
}
4.3【补充Result】
/**
 * One detection produced by the model post-processing.
 */
public class ResultCAR {
    /** Index of the predicted class; used to look up the label. */
    public int classIndex;
    /** Score of the detection; boxed so detections can be ordered with {@code compareTo}. */
    public Float score;
    /** Bounding box mapped into the display area. */
    public Rect rect;
    /** Bounding box read by NMS overlap checks and by cropping from the source bitmap. */
    public Rect raw_rect;

    /**
     * @param cls      predicted class index
     * @param output   detection score
     * @param rect     bounding box mapped into the display area
     * @param raw_rect bounding box used for overlap checks and cropping
     */
    public ResultCAR(int cls, Float output, Rect rect,Rect raw_rect){
        this.raw_rect = raw_rect;
        this.rect = rect;
        this.score = output;
        this.classIndex = cls;
    }
}
5.其他UI代码就自己搞定吧。ResultCAR中的rect存放的是目标映射到显示区域后的坐标,raw_rect存放的是目标在原图上的坐标(用于NMS和裁剪)。