Java 调用 TensorFlow 模型

import java.io.*;
import org.tensorflow.Tensor;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.types.UInt8;
import java.awt.image.BufferedImage;
import java.awt.image.DataBufferByte;
import javax.imageio.ImageIO;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.List;

/**
 * Command-line driver that runs a TensorFlow SavedModel detection graph on the image stored at
 * {@link #FILE_PATH}, once for every byte read from standard input.
 */
public class Main {
  // NOTE(review): LABELS is not referenced anywhere in this class.
  private static final String[] LABELS = {"label1", "label2", "label3", "label4", "label5", "label6", "label7"};
  /** Directory of the exported SavedModel; it is loaded with the "serve" tag. */
  private static final String SAVED_MODEL_PATH = "/usr/repositories/resources/models/ilats-targets-4/saved_model";
  /** Image file that is re-read from disk for every inference request. */
  private static final String FILE_PATH = "/tmp/inference_image.jpeg";

  /** Leading (batch) dimension of the image tensor: one image per run. */
  private static final long BATCH_SIZE = 1;
  /** Trailing (channel) dimension of the image tensor: three bytes per pixel. */
  private static final long CHANNELS = 3;

  // Names of the input and output tensors in the detection graph.
  private static final String IMAGE_TENSOR_NAME = "image_tensor:0";

  // Each box represents a part of the image where a particular object was detected.
  private static final String DETECTION_BOXES_NAME = "detection_boxes:0";

  // Each score is the confidence level for the corresponding detected object.
  private static final String DETECTION_SCORES_NAME = "detection_scores:0";

  private static final String DETECTION_CLASSES_NAME = "detection_classes:0";

  /**
   * Loads the model once, then runs one inference per byte read from standard input until
   * end-of-stream.
   *
   * <p>Every byte triggers a run, including newline characters. On the first failure, the
   * error is reported on stderr and the process exits with status 1.
   *
   * @param args ignored
   */
  public static void main(String[] args) {
    // The model is loop-invariant: load it once instead of re-reading it from disk per request.
    try (SavedModelBundle model = SavedModelBundle.load(SAVED_MODEL_PATH, "serve")) {
      while (System.in.read() != -1) {
        performInference(model);
      }
    } catch (IOException | RuntimeException e) {
      // Previously swallowed silently; make the failure visible and the exit status meaningful.
      System.err.println("Inference loop aborted: " + e);
      System.exit(1);
    }
  }

  /**
   * Loads the SavedModel, runs a single inference on {@link #FILE_PATH}, and closes the model.
   *
   * <p>Kept for existing callers. Prefer {@link #performInference(SavedModelBundle)} when running
   * repeatedly, so the model is not reloaded each time.
   *
   * @throws RuntimeException if loading the model, reading the image, or running the graph fails
   */
  public static void performInference() {
    // Closing the bundle also closes the Session and Graph that belong to it.
    try (SavedModelBundle model = SavedModelBundle.load(SAVED_MODEL_PATH, "serve")) {
      performInference(model);
    }
  }

  /**
   * Runs the detection graph of an already-loaded model on the image at {@link #FILE_PATH}.
   *
   * <p>Fetches the scores, classes and boxes outputs. Every tensor created here is closed before
   * returning. The caller keeps ownership of {@code model}.
   *
   * @param model loaded SavedModel whose session contains the detection graph
   * @throws UncheckedIOException if the image cannot be read or has an unsupported encoding
   * @throws RuntimeException if running the graph fails
   */
  public static void performInference(SavedModelBundle model) {
    List<Tensor<?>> outputs = null;
    try (Tensor<UInt8> imageTensor = makeImageTensor(FILE_PATH)) {
      outputs = model
        .session()
        .runner()
        .feed(IMAGE_TENSOR_NAME, imageTensor)
        .fetch(DETECTION_SCORES_NAME)
        .fetch(DETECTION_CLASSES_NAME)
        .fetch(DETECTION_BOXES_NAME)
        .run();
      // NOTE(review): the fetched outputs are not consumed here; add result handling as needed.
    } catch (IOException e) {
      throw new UncheckedIOException("Failed to prepare image " + FILE_PATH + ": " + e.getMessage(), e);
    } finally {
      closeAll(outputs);
    }
  }

  /** Closes every non-null tensor in {@code tensors}; does nothing when the list is null. */
  private static void closeAll(List<Tensor<?>> tensors) {
    if (tensors == null) {
      return;
    }
    for (Tensor<?> tensor : tensors) {
      if (tensor != null) {
        tensor.close();
      }
    }
  }

  /**
   * Reads an image file and converts it to a {@code UInt8} tensor of shape
   * {@code [1, height, width, 3]} with channels reordered from BGR to RGB.
   *
   * @param filename path of the image to decode
   * @return a new tensor; the caller must close it
   * @throws IOException if the file cannot be read, no ImageIO reader can decode it, or the
   *     decoded image is not {@code TYPE_3BYTE_BGR}
   */
  public static Tensor<UInt8> makeImageTensor(String filename) throws IOException {
    BufferedImage img = ImageIO.read(new File(filename));
    // ImageIO.read returns null (rather than throwing) when no registered reader handles the file.
    if (img == null) {
      throw new IOException("No registered ImageIO reader could decode " + filename);
    }
    if (img.getType() != BufferedImage.TYPE_3BYTE_BGR) {
      throw new IOException(
        String.format("Expected 3-byte BGR encoding in BufferedImage, found %d (file: %s). This code could be made more robust", img.getType(), filename)
      );
    }

    // getData() returns a copy of the raster, so swapping channels below leaves img untouched.
    byte[] data = ((DataBufferByte) img.getData().getDataBuffer()).getData();
    // The image bytes are in BGR order (checked above), but the model is fed RGB.
    bgr2rgb(data);
    long[] shape = {BATCH_SIZE, img.getHeight(), img.getWidth(), CHANNELS};
    Tensor<UInt8> imageTensor = Tensor.create(UInt8.class, shape, ByteBuffer.wrap(data));
    img.flush();
    return imageTensor;
  }

  /**
   * Swaps the first and third byte of every 3-byte pixel in place (BGR to RGB, or back).
   *
   * @param data interleaved 3-byte pixels
   * @throws IllegalArgumentException if {@code data.length} is not a multiple of 3; the array is
   *     left unmodified in that case
   */
  public static void bgr2rgb(byte[] data) {
    if (data.length % 3 != 0) {
      throw new IllegalArgumentException(
        "Pixel data length " + data.length + " is not a multiple of 3");
    }
    for (int i = 0; i < data.length; i += 3) {
      byte tmp = data[i];
      data[i] = data[i + 2];
      data[i + 2] = tmp;
    }
  }

}
参考:https://github.com/tensorflow/tensorflow/issues/17930
  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值