大家在使用onnx格式的图像分类模型或者其他需要输入图像的模型时会遇到输入图像格式问题,这点onnx官方给出的案例模板中已经有现成的kotlin代码,但是萌新使用java进行安卓开发时不会使用kotlin,所以我写了一份java版的处理代码给大家参考
目录
一、图像处理类
import android.graphics.Bitmap;
import java.nio.FloatBuffer;
public class ImagePreProcessor {
private static final int DIM_BATCH_SIZE = 1;
private static final int DIM_PIXEL_SIZE = 3;
private static final int IMAGE_SIZE_X = 224;
private static final int IMAGE_SIZE_Y = 224;
public static FloatBuffer preProcess(Bitmap bitmap) {
FloatBuffer imgData = FloatBuffer.allocate(
DIM_BATCH_SIZE
* DIM_PIXEL_SIZE
* IMAGE_SIZE_X
* IMAGE_SIZE_Y
);
imgData.rewind();
int stride = IMAGE_SIZE_X * IMAGE_SIZE_Y;
int[] bmpData = new int[stride];
bitmap.getPixels(bmpData, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
for (int i = 0; i < IMAGE_SIZE_X; i++) {
for (int j = 0; j < IMAGE_SIZE_Y; j++) {
int idx = IMAGE_SIZE_Y * i + j;
int pixelValue = bmpData[idx];
imgData.put(idx, (((pixelValue >> 16 & 0xFF) / 255f - 0.485f) / 0.229f));
imgData.put(idx + stride, (((pixelValue >> 8 & 0xFF) / 255f - 0.456f) / 0.224f));
imgData.put(idx + stride * 2, (((pixelValue & 0xFF) / 255f - 0.406f) / 0.225f));
}
}
imgData.rewind();
return imgData;
}
}
二、使用方法
我使用一个示例方法来解释使用过程,如下,我们创建onnx环境和会话,从assets目录获取模型名称并转化成InputStream,随后进行会话实例化,再根据图片uri地址将图片转化成InputStream形式,接下来就可以对数据进行处理,下面代码中的224是大小,长宽裁剪都是如此,再将图片转化成Bitmap格式、FloatBuffer格式、OnnxTensor格式等,最后转化成inputMap格式,步骤下面都有,随后使用session会话的run方法进行模型推理,最后可以得到outputData结果。
private float[] Run(Uri uri, Bitmap mybitmap) throws IOException {
float[][] outputData = new float[0][];
OrtEnvironment environment = OrtEnvironment.getEnvironment();
AssetManager assetManager = getAssets();
OrtSession.SessionOptions options = new OrtSession.SessionOptions();
try {
// 读取模型
InputStream stream = assetManager.open(modelname);
ByteArrayOutputStream byteStream = new ByteArrayOutputStream();
byte[] buffer = new byte[4096];
int bytesRead;
while ((bytesRead = stream.read(buffer)) != -1) {
byteStream.write(buffer, 0, bytesRead);
}
byteStream.flush();
byte[] bytes = byteStream.toByteArray();
OrtSession session = environment.createSession(bytes, options);
//数据处理
Bitmap bitmap = null;
if (mybitmap==null&&uri!=null){
InputStream inputStream =getContentResolver().openInputStream(uri);
bitmap = BitmapFactory.decodeStream(inputStream);
}else {
bitmap=mybitmap;
}
// bitmap = BitmapFactory.decodeStream(getActivity().getAssets().open("ricebacterial_leaf_blight.jpg"));
Bitmap scaledBitmap = Bitmap.createScaledBitmap(bitmap, 224, 224, true);
FloatBuffer inputData= ImagePreProcessor.preProcess(scaledBitmap);
OnnxTensor inputTensor = OnnxTensor.createTensor(environment, inputData, new long[]{1, 3, 224, 224});
Map<String, OnnxTensor> inputMap = new HashMap<>();
inputMap.put("input", inputTensor);
System.out.println(inputMap);
OrtSession.Result output = session.run(inputMap);
//推理执行
OnnxTensor outputTensor = (OnnxTensor) output.get(0);
outputData = (float[][]) outputTensor.getValue();
for (int i = 0; i < outputData.length; i++) {
for (int j = 0; j < outputData[0].length; j++) {
System.out.print(outputData[i][j] + " ");
}
}
} catch (IOException | OrtException e) {
e.printStackTrace();
}
float[] result=outputData[0];
return result;
}