ONNX格式下图像分类模型输入标准化(java版)

大家在使用onnx格式的图像分类模型或者其他需要输入图像的模型时会遇到输入图像格式问题,这点onnx官方给出的案例模板中已经有现成的kotlin代码,但是萌新使用java进行安卓开发时不会使用kotlin,所以我写了一份java版的处理代码给大家参考

目录

一、图像处理类

二、使用方法


一、图像处理类

import android.graphics.Bitmap;

import java.nio.FloatBuffer;

public class ImagePreProcessor {

    private static final int DIM_BATCH_SIZE = 1;
    private static final int DIM_PIXEL_SIZE = 3;
    private static final int IMAGE_SIZE_X = 224;
    private static final int IMAGE_SIZE_Y = 224;

    public static FloatBuffer preProcess(Bitmap bitmap) {
        FloatBuffer imgData = FloatBuffer.allocate(
                DIM_BATCH_SIZE
                        * DIM_PIXEL_SIZE
                        * IMAGE_SIZE_X
                        * IMAGE_SIZE_Y
        );
        imgData.rewind();
        int stride = IMAGE_SIZE_X * IMAGE_SIZE_Y;
        int[] bmpData = new int[stride];
        bitmap.getPixels(bmpData, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
        for (int i = 0; i < IMAGE_SIZE_X; i++) {
            for (int j = 0; j < IMAGE_SIZE_Y; j++) {
                int idx = IMAGE_SIZE_Y * i + j;
                int pixelValue = bmpData[idx];
                imgData.put(idx, (((pixelValue >> 16 & 0xFF) / 255f - 0.485f) / 0.229f));
                imgData.put(idx + stride, (((pixelValue >> 8 & 0xFF) / 255f - 0.456f) / 0.224f));
                imgData.put(idx + stride * 2, (((pixelValue & 0xFF) / 255f - 0.406f) / 0.225f));
            }
        }

        imgData.rewind();
        return imgData;
    }
}

二、使用方法

我使用一个示例方法来解释使用过程,如下,我们创建onnx环境和会话,从assets目录获取模型名称并转化成InputStream,随后进行会话实例化,再根据图片uri地址将图片转化成InputStream形式,接下来就可以对数据进行处理,下面代码中的224是大小,长宽裁剪都是如此,再将图片转化成Bitmap格式、FloatBuffer格式、OnnxTensor格式等,最后转化成inputMap格式,步骤下面都有,随后使用session会话的run方法进行模型推理,最后可以得到outputData结果。

    private float[] Run(Uri uri, Bitmap mybitmap) throws IOException {
        float[][] outputData = new float[0][];
        OrtEnvironment environment = OrtEnvironment.getEnvironment();
        AssetManager assetManager = getAssets();
        OrtSession.SessionOptions options = new OrtSession.SessionOptions();
        try {
            // 读取模型
            InputStream stream = assetManager.open(modelname);
            ByteArrayOutputStream byteStream = new ByteArrayOutputStream();
            byte[] buffer = new byte[4096];
            int bytesRead;
            while ((bytesRead = stream.read(buffer)) != -1) {
                byteStream.write(buffer, 0, bytesRead);
            }
            byteStream.flush();
            byte[] bytes = byteStream.toByteArray();
            OrtSession session = environment.createSession(bytes, options);

            //数据处理
            Bitmap bitmap = null;
            if (mybitmap==null&&uri!=null){
                InputStream inputStream =getContentResolver().openInputStream(uri);
                bitmap = BitmapFactory.decodeStream(inputStream);
            }else {
                bitmap=mybitmap;
            }

//            bitmap = BitmapFactory.decodeStream(getActivity().getAssets().open("ricebacterial_leaf_blight.jpg"));
            Bitmap scaledBitmap = Bitmap.createScaledBitmap(bitmap, 224, 224, true);
            FloatBuffer inputData= ImagePreProcessor.preProcess(scaledBitmap);
            OnnxTensor inputTensor = OnnxTensor.createTensor(environment, inputData, new long[]{1, 3, 224, 224});
            Map<String, OnnxTensor> inputMap = new HashMap<>();
            inputMap.put("input", inputTensor);
            System.out.println(inputMap);
            OrtSession.Result output = session.run(inputMap);

            //推理执行
            OnnxTensor outputTensor = (OnnxTensor) output.get(0);
            outputData = (float[][]) outputTensor.getValue();

            for (int i = 0; i < outputData.length; i++) {
                for (int j = 0; j < outputData[0].length; j++) {
                    System.out.print(outputData[i][j] + " ");
                }
            }
        } catch (IOException | OrtException e) {
            e.printStackTrace();
        }
        float[] result=outputData[0];
        return result;
    }

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

焚詩作薪

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值