import java.io.*;
import org.tensorflow.Tensor;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.types.UInt8;
import java.awt.image.BufferedImage;
import java.awt.image.DataBufferByte;
import javax.imageio.ImageIO;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.List;
/**
 * Command-line driver that runs a TensorFlow object-detection SavedModel on a single image file.
 *
 * <p>Every byte read from standard input triggers one inference pass over {@link #FILE_PATH}.
 * Uses the generic {@code Tensor<T>} / {@code SavedModelBundle} TensorFlow Java API.
 */
public class Main {
    // NOTE(review): not referenced anywhere in this class; presumably meant to map detection
    // class ids to human-readable names — confirm and use, or remove.
    private static final String[] LABELS = {"label1", "label2", "label3", "label4", "label5", "label6", "label7"};
    private static final String SAVED_MODEL_PATH = "/usr/repositories/resources/models/ilats-targets-4/saved_model";
    private static final String FILE_PATH = "/tmp/inference_image.jpeg";

    // Names of the input and output tensors in the detection graph.
    private static final String IMAGE_TENSOR_NAME = "image_tensor:0";
    // Each box marks a region of the image where an object was detected.
    private static final String DETECTION_BOXES_NAME = "detection_boxes:0";
    // Each score is the model's confidence for the corresponding detection.
    private static final String DETECTION_SCORES_NAME = "detection_scores:0";
    // Class id predicted for each detection.
    private static final String DETECTION_CLASSES_NAME = "detection_classes:0";

    /**
     * Runs one inference pass for every byte read from standard input, until end of stream.
     *
     * <p>NOTE(review): input is consumed one byte at a time, so a line such as {@code "go\n"}
     * triggers three inference runs — confirm this matches the upstream protocol.
     *
     * @throws UncheckedIOException if reading standard input fails
     */
    public static void main(String[] args) {
        try {
            while (System.in.read() != -1) {
                performInference();
            }
        } catch (IOException e) {
            // Previously swallowed silently; surface it so failures are visible to the operator.
            throw new UncheckedIOException("Failed to read inference trigger from stdin", e);
        }
    }

    /**
     * Loads the SavedModel, feeds it the image at {@link #FILE_PATH}, and fetches the detection
     * scores, classes and boxes.
     *
     * <p>The model, the input tensor and all output tensors are closed before this method returns,
     * including when an exception is thrown.
     *
     * @throws UncheckedIOException if the image file cannot be read or has an unsupported encoding
     */
    public static void performInference() {
        // NOTE(review): the model is reloaded from disk on every call; consider loading it once
        // if this process serves many requests.
        // Closing the bundle also closes the Session and Graph that belong to the model.
        try (SavedModelBundle model = SavedModelBundle.load(SAVED_MODEL_PATH, "serve");
             Tensor<UInt8> imageTensor = makeImageTensor(FILE_PATH)) {
            List<Tensor<?>> outputs = model
                    .session()
                    .runner()
                    .feed(IMAGE_TENSOR_NAME, imageTensor)
                    .fetch(DETECTION_SCORES_NAME)
                    .fetch(DETECTION_CLASSES_NAME)
                    .fetch(DETECTION_BOXES_NAME)
                    .run();
            // NOTE(review): the fetched tensors are released without being read, so the
            // detection results are currently discarded.
            closeAll(outputs);
        } catch (IOException e) {
            throw new UncheckedIOException("Failed to build input tensor from " + FILE_PATH, e);
        }
    }

    /** Closes every non-null tensor in {@code tensors}. */
    private static void closeAll(List<Tensor<?>> tensors) {
        for (Tensor<?> tensor : tensors) {
            if (tensor != null) {
                tensor.close();
            }
        }
    }

    /**
     * Reads an image file and converts it into a {@code [1, height, width, 3]} uint8 tensor in RGB
     * channel order.
     *
     * <p>The caller owns the returned tensor and must close it.
     *
     * @param filename path of the image file to read
     * @return a batch-of-one RGB image tensor
     * @throws IOException if the file cannot be read, no ImageIO reader recognises it, or the
     *     decoded image is not {@link BufferedImage#TYPE_3BYTE_BGR}
     */
    public static Tensor<UInt8> makeImageTensor(String filename) throws IOException {
        BufferedImage img = ImageIO.read(new File(filename));
        if (img == null) {
            // ImageIO.read returns null (instead of throwing) when no registered reader can decode the file.
            throw new IOException("No ImageIO reader could decode the image (file: " + filename + ")");
        }
        if (img.getType() != BufferedImage.TYPE_3BYTE_BGR) {
            throw new IOException(
                String.format("Expected 3-byte BGR encoding in BufferedImage, found %d (file: %s). This code could be made more robust", img.getType(), filename)
            );
        }
        // getData() returns a copy of the pixel data, so the in-place channel swap below does not touch img.
        byte[] data = ((DataBufferByte) img.getData().getDataBuffer()).getData();
        // TYPE_3BYTE_BGR stores pixels as B,G,R bytes; the model is assumed to expect RGB.
        bgr2rgb(data);
        final long batchSize = 1;
        final long channels = 3;
        long[] shape = new long[] {batchSize, img.getHeight(), img.getWidth(), channels};
        Tensor<UInt8> imageTensor = Tensor.create(UInt8.class, shape, ByteBuffer.wrap(data));
        img.flush();
        return imageTensor;
    }

    /**
     * Swaps the first and third byte of every 3-byte pixel in place, converting packed BGR to RGB
     * (and vice versa).
     *
     * @param data packed 3-byte pixels; its length must be a multiple of 3
     * @throws IllegalArgumentException if {@code data.length} is not a multiple of 3 (the array is
     *     left unmodified)
     */
    public static void bgr2rgb(byte[] data) {
        // Validate before mutating so a malformed buffer is never left half-converted.
        if (data.length % 3 != 0) {
            throw new IllegalArgumentException(
                "Expected a multiple of 3 bytes (packed 3-byte pixels), got " + data.length);
        }
        for (int i = 0; i < data.length; i += 3) {
            byte first = data[i];
            data[i] = data[i + 2];
            data[i + 2] = first;
        }
    }
}
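// Reference: https://github.com/tensorflow/tensorflow/issues/17930
// Calling a TensorFlow model from Java (java 调用tensorflow模型)
// 最新推荐文章于 2024-05-10 15:54:26 发布 (source-page footer: "latest recommended article published 2024-05-10 15:54:26")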