Android studio项目加载pytorch模型文件

1.首先把你YOLO模型转为torchscript格式。

2.然后把模型文件放在你的安卓项目的 assets 资源目录下,并同时添加对应的标签(类别名)文件。

3.添加依赖库【Gradle app文件下】 

    implementation 'org.pytorch:pytorch_android_lite:1.9.0'
    implementation 'org.pytorch:pytorch_android_torchvision_lite:1.9.0'

4.编写加载模型的代码类【不详细解释-直接放代码】

        4.1【推理】

public class PrePostProcessor {
    public static float[] NO_MEAN_RGB = new float[] {0.0f, 0.0f, 0.0f};
    public static float[] NO_STD_RGB = new float[] {1.0f, 1.0f, 1.0f};
    public static int mInputWidth = 320;
    public static int mInputHeight = 320;

    private static final int mOutputRow = 6300; 
    private static final int mOutputColumn = 9; 
    private static final float mThreshold = 0.35f;
    private static final int mNmsLimit = 5;

    public static  String[] mClasses;

    /**
     * Intersection-over-Union (Jaccard overlap) of two axis-aligned boxes.
     *
     * @param a first box (left/top/right/bottom pixel coordinates)
     * @param b second box
     * @return overlap ratio in [0, 1]; 0 when either box is empty or inverted
     */
    static float IOU(Rect a, Rect b){
        // Compute areas in float so very large boxes cannot overflow the
        // int multiplication the original performed before widening.
        float areaA = (float) (a.right - a.left) * (a.bottom - a.top);
        if (areaA <= 0.0f) return 0.0f;

        float areaB = (float) (b.right - b.left) * (b.bottom - b.top);
        if (areaB <= 0.0f) return 0.0f;

        // Intersection rectangle; width/height clamp to 0 when the boxes are disjoint.
        float interW = Math.max(Math.min(a.right, b.right) - Math.max(a.left, b.left), 0);
        float interH = Math.max(Math.min(a.bottom, b.bottom) - Math.max(a.top, b.top), 0);
        float intersectionArea = interW * interH;
        return intersectionArea / (areaA + areaB - intersectionArea);
    }

    /**
     * Greedy non-maximum suppression: keeps the highest-scoring boxes and drops
     * any remaining box whose IoU with an already-kept box exceeds {@code threshold}.
     *
     * @param boxes     candidate detections (sorted in place)
     * @param limit     maximum number of boxes to keep
     * @param threshold IoU above which a lower-scoring box is suppressed
     * @return kept detections, best score first
     */
    static ArrayList<ResultCAR> nonMaxSuppression(ArrayList<ResultCAR> boxes, int limit, float threshold){
        // BUG FIX: sort DESCENDING by confidence. The previous ascending sort made
        // NMS keep the LOWEST-confidence boxes and suppress the best detections.
        Collections.sort(boxes,
                new Comparator<ResultCAR>(){
                    @Override
                    public int compare(ResultCAR o1, ResultCAR o2){
                        return o2.score.compareTo(o1.score);
                    }
                });
        ArrayList<ResultCAR> selected = new ArrayList<>();
        boolean[] active = new boolean[boxes.size()];
        Arrays.fill(active, true);
        int numActive = active.length;

        boolean done = false;
        for (int i=0; i<boxes.size() && !done; i++){
            if (active[i]){
                // Highest-scoring surviving box wins; suppress overlapping rivals.
                ResultCAR boxA = boxes.get(i);
                selected.add(boxA);
                if (selected.size() >= limit) break;

                for(int j = i+1; j<boxes.size();j++){
                    if(active[j]){
                        ResultCAR boxB = boxes.get(j);
                        if (IOU(boxA.raw_rect, boxB.raw_rect)>threshold){
                            active[j] = false;
                            numActive -= 1;
                            if  (numActive <= 0){
                                done = true;
                                break;
                            }
                        }
                    }
                }
            }
        }
        return selected;
    }

    /**
     * Converts the raw YOLO output (flattened [mOutputRow x mOutputColumn] rows of
     * cx, cy, w, h, objectness, classScores...) into detections and runs NMS.
     *
     * @param imgScaleX/imgScaleY map model-input coords to original-image pixels
     * @param ivScaleX/ivScaleY/startX/startY map image coords to the display view
     * @return detections surviving non-maximum suppression
     */
    public static ArrayList<ResultCAR> outputsToNMSPredictions(float[] outputs, float imgScaleX, float imgScaleY, float ivScaleX, float ivScaleY,float startX, float startY){
        ArrayList<ResultCAR> results = new ArrayList<>();
        for (int i=0; i<mOutputRow; i++){
            // Column 4 holds the objectness score; skip low-confidence rows early.
            if (outputs[i* mOutputColumn +4]>mThreshold){
                float x = outputs[i* mOutputColumn];
                float y = outputs[i* mOutputColumn +1];
                float w = outputs[i* mOutputColumn +2];
                float h = outputs[i* mOutputColumn +3];

                // Center/size -> corners, scaled into original-image pixels.
                float left = imgScaleX * (x - w/2);
                float top = imgScaleY * (y-h/2);
                float right = imgScaleX * (x + w/2);
                float bottom = imgScaleY * (y + h/2);

                // argmax over the class scores in columns 5..mOutputColumn-1.
                float max = outputs[i* mOutputColumn +5];
                int cls = 0;
                for (int j=0; j<mOutputColumn-5;j++){
                    if (outputs[i* mOutputColumn +5+j] > max){
                        max = outputs[i * mOutputColumn +5+j];
                        cls = j;
                    }
                }
                // View-space rect for drawing on the display surface.
                Rect rect = new Rect((int)(startX + ivScaleX*left),(int)(startY+top*ivScaleY),
                        (int)(startX+ivScaleX*right), (int) (startY+ivScaleY*bottom));
                // BUG FIX: ResultCAR's constructor takes FOUR arguments; the original
                // passed three and did not compile. Also supply the image-space rect,
                // which nonMaxSuppression() and runCAR()'s cropping read (raw_rect).
                Rect rawRect = new Rect((int) left, (int) top, (int) right, (int) bottom);
                ResultCAR result = new ResultCAR(cls, outputs[i * mOutputColumn+4], rect, rawRect);
                results.add(result);
            }
        }
        return nonMaxSuppression(results, mNmsLimit, mThreshold);
    }

         4.2【获取推理结果】

    // Scale factors mapping model-input coordinates back to the source bitmap
    // (mImgScale*) and to the display surface (mIvScale*, mStart*).
    private static float  mImgScaleX, mImgScaleY, mIvScaleX, mIvScaleY, mStartX, mStartY;
    // Last cropped + annotated detection, exposed for the UI layer to display.
    public  static Bitmap resimg  = null;

    /**
     * Runs the TorchScript detector on {@code mBitmap} and returns the name of a
     * detected class (the highest class index among unique detections), or
     * {@code null} when nothing passes the confidence threshold.
     *
     * @param mBitmap       source image
     * @param mModuleCarTag loaded TorchScript module
     * @param context       Android context, used only when saving crops
     * @param isSaveImage   when true, each detection is cropped, labelled and saved
     * @return a detected class name, or {@code null} if there were no detections
     */
    public static String runCAR(Bitmap mBitmap, Module mModuleCarTag , Context context, boolean isSaveImage) {
        Bitmap corpBitmap = null;
        String resulthld = null;
        // CONSISTENCY FIX: the original mixed PrePostProcessor and PrePostProcessorCar
        // references; everything now goes through PrePostProcessor (defined above).
        mImgScaleX = (float) mBitmap.getWidth() / PrePostProcessor.mInputWidth;
        mImgScaleY = (float) mBitmap.getHeight() / PrePostProcessor.mInputHeight;
        // NOTE(review): these divide 1 by the bitmap size — this looks like a port of
        // the PyTorch demo where the numerator was the ImageView size; confirm intent.
        mIvScaleX = (mBitmap.getWidth() > mBitmap.getHeight() ? (float) 1 / mBitmap.getWidth() : (float) 1 / mBitmap.getHeight());
        mIvScaleY = (mBitmap.getHeight() > mBitmap.getWidth() ? (float) 1 / mBitmap.getHeight() : (float) 1 / mBitmap.getWidth());
        mStartX = (1 - mIvScaleX * mBitmap.getWidth()) / 2;
        mStartY = (1 - mIvScaleY * mBitmap.getHeight()) / 2;
        // Resize to the model's fixed input resolution, convert to a float tensor,
        // run inference and pull the first element of the output tuple.
        Bitmap resizedBitmap = Bitmap.createScaledBitmap(mBitmap, PrePostProcessor.mInputWidth, PrePostProcessor.mInputHeight, true);
        final Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(resizedBitmap, PrePostProcessor.NO_MEAN_RGB, PrePostProcessor.NO_STD_RGB);
        IValue[] outputTuple = mModuleCarTag.forward(IValue.from(inputTensor)).toTuple();
        final Tensor outputTensor = outputTuple[0].toTensor();
        final float[] outputs = outputTensor.getDataAsFloatArray();
        // Decode + non-maximum suppression.
        final ArrayList<ResultCAR> results = PrePostProcessor.outputsToNMSPredictions(outputs, mImgScaleX, mImgScaleY, mIvScaleX, mIvScaleY, mStartX, mStartY);
        Set<Integer> set = new HashSet<>();
        for (ResultCAR r : results) {
            set.add(r.classIndex);
            Log.e("置信度:", r.score + "");
            if (isSaveImage) {
                // ROBUSTNESS FIX: clamp the crop rect to the bitmap bounds —
                // Bitmap.createBitmap throws when the region spills outside the source.
                Rect rect = r.raw_rect;
                int left = Math.max(0, rect.left);
                int top = Math.max(0, rect.top);
                int right = Math.min(mBitmap.getWidth(), rect.right);
                int bottom = Math.min(mBitmap.getHeight(), rect.bottom);
                if (right <= left || bottom <= top) {
                    continue; // degenerate box after clamping — nothing to crop
                }
                corpBitmap = Bitmap.createBitmap(mBitmap, left, top, right - left, bottom - top);
                Mat mat = new Mat();
                Utils.bitmapToMat(corpBitmap, mat);
                Imgproc.putText(mat, PrePostProcessor.mClasses[r.classIndex], new org.opencv.core.Point(10, 10), 1, 1, new org.opencv.core.Scalar(0, 0, 255), 1);
                corpBitmap = ImageUtils.mat2Bitmap(mat);
                resimg = corpBitmap;
                SaveBitmap.saveImageToGallery(context, corpBitmap);
            }
        }
        // Log every unique detected class, sorted by class index.
        List<Integer> list = new ArrayList<>(set);
        Collections.sort(list);
        for (int idx : list) {
            Log.e("识别结果", PrePostProcessor.mClasses[idx]);
        }
        // Preserves original behavior: the returned value is the class name of the
        // LAST (highest-index) entry; the original loop overwrote it each iteration.
        if (!list.isEmpty()) {
            resulthld = PrePostProcessor.mClasses[list.get(list.size() - 1)];
        }
        return resulthld;
    }

        4.3【补充Result】

/**
 * A single detection produced by the YOLO post-processing step.
 *
 * <p>Holds the predicted class, its confidence, and the bounding box in two
 * coordinate systems.
 */
public class ResultCAR {
    // Index into PrePostProcessor.mClasses identifying the predicted class.
    public int classIndex;
    // Detection confidence; boxed so NMS can compare via Float.compareTo.
    public Float score;
    // Bounding box in view/display coordinates.
    public Rect rect;
    // Bounding box used by IoU/NMS and cropping — presumably original-image
    // coordinates; confirm against the producer.
    public Rect raw_rect;

    /**
     * @param cls      predicted class index
     * @param output   confidence score for this detection
     * @param rect     box in view/display coordinates
     * @param raw_rect box used for NMS overlap tests and cropping
     */
    public ResultCAR(int cls, Float output, Rect rect, Rect raw_rect){
        classIndex = cls;
        score = output;
        this.rect = rect;
        this.raw_rect = raw_rect;
    }
}

5.其他UI代码就自己搞定吧,ResultCAR中rect存放的就是目标的坐标。

  • 2
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
好的,以下是使用Python和YOLO模型进行水果识别的代码:

```python
# 导入相关的库
import cv2
import numpy as np
import argparse

# 设置命令行参数
ap = argparse.ArgumentParser()
ap.add_argument("-i", "--image", required=True, help="path to input image")
ap.add_argument("-c", "--config", required=True, help="path to yolo config file")
ap.add_argument("-w", "--weights", required=True, help="path to yolo pre-trained weights")
ap.add_argument("-cl", "--classes", required=True, help="path to text file containing class names")
args = vars(ap.parse_args())

# 加载类别名
classes = None
with open(args["classes"], 'r') as f:
    classes = [line.strip() for line in f.readlines()]

# 为每个类别生成一个随机颜色(修复:原代码使用了未定义的 COLORS)
COLORS = np.random.randint(0, 255, size=(len(classes), 3), dtype="uint8")

# 加载模型配置和权重
net = cv2.dnn.readNetFromDarknet(args["config"], args["weights"])

# 加载输入图像并进行预处理
image = cv2.imread(args["image"])
blob = cv2.dnn.blobFromImage(image, 1/255.0, (416, 416), swapRB=True, crop=False)

# 设置模型的输入和输出节点
net.setInput(blob)
output_layers_names = net.getUnconnectedOutLayersNames()
layerOutputs = net.forward(output_layers_names)

# 初始化输出结果
boxes = []
confidences = []
classIDs = []

# 循环遍历每个输出层,提取检测结果
for output in layerOutputs:
    for detection in output:
        scores = detection[5:]
        classID = np.argmax(scores)
        confidence = scores[classID]
        # 过滤掉置信度低的检测结果
        if confidence > 0.5:
            box = detection[0:4] * np.array([image.shape[1], image.shape[0], image.shape[1], image.shape[0]])
            (centerX, centerY, width, height) = box.astype("int")
            # 计算边框的左上角坐标
            x = int(centerX - (width / 2))
            y = int(centerY - (height / 2))
            # 更新输出结果
            boxes.append([x, y, int(width), int(height)])
            confidences.append(float(confidence))
            classIDs.append(classID)

# 对输出结果进行NMS处理,去除冗余的检测结果
idxs = cv2.dnn.NMSBoxes(boxes, confidences, 0.5, 0.4)

# 在图像上绘制检测结果
if len(idxs) > 0:
    for i in idxs.flatten():
        (x, y) = (boxes[i][0], boxes[i][1])
        (w, h) = (boxes[i][2], boxes[i][3])
        color = [int(c) for c in COLORS[classIDs[i]]]
        cv2.rectangle(image, (x, y), (x + w, y + h), color, 2)
        text = "{}: {:.4f}".format(classes[classIDs[i]], confidences[i])
        cv2.putText(image, text, (x, y - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

# 显示输出图像
cv2.imshow("Image", image)
cv2.waitKey(0)
cv2.destroyAllWindows()
```

使用时可以在命令行中运行以下命令:

```
python fruit_detection.py --image input_image.jpg --config yolo.cfg --weights yolo.weights --classes classes.txt
```

其中,`input_image.jpg`是要识别的图像,`yolo.cfg`和`yolo.weights`是YOLO模型的配置文件和权重文件,`classes.txt`是包含类别名的文本文件

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值