Java-tfserving-maskrcnn

1. Maven 工程的 pom.xml（依赖配置片段）

<properties>
    <!-- Sources are UTF-8 (App.java contains a non-ASCII file path); compile for Java 8. -->
    <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
    <maven.compiler.source>1.8</maven.compiler.source>
    <maven.compiler.target>1.8</maven.compiler.target>
  </properties>

  <repositories>
    <!-- Presumably needed to resolve the javaxt:javaxt-core artifact below — confirm.
         NOTE(review): this is a plain-HTTP URL; Maven 3.8.1+ blocks external HTTP
         repositories by default, so the build may fail until it is switched to HTTPS
         or mirrored. -->
    <repository>
      <id>javaxt.com</id>
      <url>http://www.javaxt.com/maven</url>
    </repository>
  </repositories>

  <dependencies>
    <!-- Presumably supplies the generated TensorFlow Serving gRPC stubs
         (tensorflow.serving.Model / Predict / PredictionServiceGrpc) used by App — confirm. -->
    <dependency>
      <groupId>com.yesup.oss</groupId>
      <artifactId>tensorflow-client</artifactId>
      <version>1.4-2</version>
    </dependency>
    <!-- Image-processing library. NOTE(review): App does not import net.coobird.*, so this may be unused. -->
    <dependency>
      <groupId>net.coobird</groupId>
      <artifactId>thumbnailator</artifactId>
      <version>0.4.8</version>
    </dependency>
    <!-- Netty transport for gRPC; App builds its channel with io.grpc.ManagedChannelBuilder. -->
    <dependency>
      <groupId>io.grpc</groupId>
      <artifactId>grpc-netty</artifactId>
      <version>1.7.0</version>
    </dependency>
    <!-- Native BoringSSL bindings for Netty. NOTE(review): App connects with usePlaintext(true),
         so TLS support may not be needed — confirm before removing. -->
    <dependency>
      <groupId>io.netty</groupId>
      <artifactId>netty-tcnative-boringssl-static</artifactId>
      <version>2.0.7.Final</version>
    </dependency>
    <!-- Provides javaxt.io.Image, which App uses to decode the input image. -->
    <dependency>
      <groupId>javaxt</groupId>
      <artifactId>javaxt-core</artifactId>
      <version>1.8.0</version>
    </dependency>
    <!-- JavaCV bundle; App uses its org.bytedeco.opencv.* classes to draw boxes and masks. -->
    <dependency>
      <groupId>org.bytedeco</groupId>
      <artifactId>javacv-platform</artifactId>
      <version>1.5.2</version>
    </dependency>
    <!-- NOTE(review): from the org.tensorflow packages App imports only org.tensorflow.framework.*;
         confirm whether this artifact is actually needed. -->
    <dependency>
      <groupId>org.tensorflow</groupId>
      <artifactId>tensorflow</artifactId>
      <version>1.13.1</version>
    </dependency>
    <!-- Presumably supplies org.tensorflow.framework.{TensorProto, TensorShapeProto, DataType}
         used by App — confirm; may overlap with classes bundled in tensorflow-client. -->
    <dependency>
      <groupId>org.tensorflow</groupId>
      <artifactId>proto</artifactId>
      <version>1.13.1</version>
    </dependency>
  </dependencies>

2. 代码（App.java）

package org.example;

import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import javaxt.io.Image;
import org.bytedeco.javacpp.indexer.FloatRawIndexer;
import org.bytedeco.javacpp.indexer.UByteRawIndexer;
import org.bytedeco.opencv.global.opencv_core;
import org.bytedeco.opencv.global.opencv_imgcodecs;
import org.bytedeco.opencv.global.opencv_imgproc;
import org.bytedeco.opencv.opencv_core.*;
import org.tensorflow.framework.DataType;
import org.tensorflow.framework.TensorProto;
import org.tensorflow.framework.TensorShapeProto;
import tensorflow.serving.Model;
import tensorflow.serving.Predict;
import tensorflow.serving.PredictionServiceGrpc;

import java.awt.image.BufferedImage;
import java.io.File;
import java.util.ArrayList;
import java.util.List;

/**
 * Minimal TensorFlow Serving client for a Mask R-CNN detection model.
 *
 * <p>Sends one image over gRPC, then writes a copy of it in which every detection scoring above
 * {@link #SCORE_THRESHOLD} has its mask pixels painted black and its box outlined in yellow.
 */
public class App {

    /** TensorFlow Serving host (its REST endpoint on port 8501 is not used here). */
    private static final String HOST = "172.20.112.102";
    /** TensorFlow Serving gRPC port. */
    private static final int GRPC_PORT = 8500;
    private static final String MODEL_NAME = "model-group";
    private static final String SIGNATURE_NAME = "serving_default";

    /** Image processed when no path is given on the command line. */
    private static final String DEFAULT_IMAGE_FILE = "F:\\新建文件夹\\tfm\\src\\main\\1908.jpg";
    /** Annotated output image, relative to the working directory. */
    private static final String OUTPUT_FILE = "target/maskrcnn.jpg";

    /** Only detections whose score is strictly above this are drawn. */
    private static final double SCORE_THRESHOLD = 0.9;
    /** Resized mask values strictly above this count as foreground. */
    private static final double MASK_THRESHOLD = 0.3;
    /** Per-detection mask resolution used when the response carries no usable mask shape. */
    private static final int DEFAULT_MASK_SIZE = 15;

    /**
     * Runs inference on one image and writes the annotated result to {@link #OUTPUT_FILE}.
     *
     * @param args optional; {@code args[0]} overrides the input image path
     * @throws Exception if the image cannot be decoded, the RPC fails, or the output cannot be written
     */
    public static void main(String[] args) throws Exception {
        String file = args.length > 0 ? args[0] : DEFAULT_IMAGE_FILE;
        BufferedImage image = new Image(new File(file)).getBufferedImage();
        if (image == null) {
            throw new IllegalArgumentException("Cannot decode image: " + file);
        }
        int width = image.getWidth();
        int height = image.getHeight();
        byte[] bgr = toBgrBytes(image.getRGB(0, 0, width, height, null, 0, width));

        Predict.PredictResponse response;
        ManagedChannel channel = ManagedChannelBuilder.forAddress(HOST, GRPC_PORT).usePlaintext(true).build();
        try {
            PredictionServiceGrpc.PredictionServiceBlockingStub stub = PredictionServiceGrpc.newBlockingStub(channel);
            response = stub.predict(buildRequest(bgr, width, height));
        } finally {
            // Release the channel's connections and threads even when the call fails.
            channel.shutdownNow();
        }

        List<Float> boxes = response.getOutputsOrThrow("detection_boxes").getFloatValList();
        List<Float> scores = response.getOutputsOrThrow("detection_scores").getFloatValList();
        List<Float> classes = response.getOutputsOrThrow("detection_classes").getFloatValList();
        TensorProto maskTensor = response.getOutputsOrThrow("detection_masks");
        List<Float> masks = maskTensor.getFloatValList();
        int[] maskShape = maskShape(maskTensor);
        int maskRows = maskShape[0];
        int maskCols = maskShape[1];
        int maskLength = maskRows * maskCols;

        // Draw on exactly the pixels that were sent to the model, so the normalized boxes line up
        // with the picture and the input file is decoded only once.
        Mat canvas = new Mat(height, width, opencv_core.CV_8UC3);
        canvas.data().put(bgr);
        UByteRawIndexer canvasIndexer = canvas.createIndexer();
        for (int i = 0; i < scores.size(); i++) {
            if (scores.get(i) > SCORE_THRESHOLD) {
                List<Float> box = boxes.subList(i * 4, i * 4 + 4);
                System.out.println("index " + i + " score " + scores.get(i) + " class " + classes.get(i)
                        + " box " + box);
                List<Float> mask = masks.subList(i * maskLength, (i + 1) * maskLength);
                drawDetection(canvas, canvasIndexer, box, mask, maskRows, maskCols);
            }
        }

        File output = new File(OUTPUT_FILE).getAbsoluteFile();
        File parent = output.getParentFile();
        if (parent != null) {
            // imwrite does not create missing directories; it just returns false.
            parent.mkdirs();
        }
        if (!opencv_imgcodecs.imwrite(output.getPath(), canvas)) {
            throw new IllegalStateException("Failed to write " + output);
        }
    }

    /**
     * Unpacks {@link BufferedImage#getRGB} ARGB ints into interleaved B, G, R bytes (alpha dropped),
     * row-major, three bytes per pixel.
     *
     * <p>NOTE(review): TensorFlow Object Detection exports usually expect RGB input; confirm this
     * model really was trained on BGR before relying on this channel order.
     */
    private static byte[] toBgrBytes(int[] pixels) {
        byte[] bgr = new byte[pixels.length * 3];
        for (int i = 0, j = 0; i < pixels.length; i++, j += 3) {
            int argb = pixels[i];
            bgr[j] = (byte) argb;               // blue
            bgr[j + 1] = (byte) (argb >> 8);    // green
            bgr[j + 2] = (byte) (argb >> 16);   // red
        }
        return bgr;
    }

    /**
     * Builds a Predict request whose {@code "inputs"} entry is a uint8 tensor of shape
     * {@code [1, height, width, 3]} holding {@code bgr}.
     */
    private static Predict.PredictRequest buildRequest(byte[] bgr, int width, int height) {
        List<Integer> values = new ArrayList<>(bgr.length);
        for (byte b : bgr) {
            values.add(b & 0xff);
        }
        TensorShapeProto shape = TensorShapeProto.newBuilder()
                .addDim(TensorShapeProto.Dim.newBuilder().setSize(1))
                .addDim(TensorShapeProto.Dim.newBuilder().setSize(height))
                .addDim(TensorShapeProto.Dim.newBuilder().setSize(width))
                .addDim(TensorShapeProto.Dim.newBuilder().setSize(3))
                .build();
        TensorProto input = TensorProto.newBuilder()
                .setDtype(DataType.DT_UINT8)
                .setTensorShape(shape)
                .addAllIntVal(values)
                .build();
        Model.ModelSpec spec = Model.ModelSpec.newBuilder()
                .setName(MODEL_NAME)
                .setSignatureName(SIGNATURE_NAME)
                .build();
        return Predict.PredictRequest.newBuilder()
                .setModelSpec(spec)
                .putInputs("inputs", input)
                .build();
    }

    /**
     * Returns {@code {rows, cols}} of one detection mask: the last two dimensions of the
     * {@code detection_masks} tensor, or {@link #DEFAULT_MASK_SIZE} squared when the shape is missing.
     */
    private static int[] maskShape(TensorProto masks) {
        TensorShapeProto shape = masks.getTensorShape();
        int rank = shape.getDimCount();
        if (rank >= 3) {
            long rows = shape.getDim(rank - 2).getSize();
            long cols = shape.getDim(rank - 1).getSize();
            if (rows > 0 && cols > 0) {
                return new int[] {(int) rows, (int) cols};
            }
        }
        return new int[] {DEFAULT_MASK_SIZE, DEFAULT_MASK_SIZE};
    }

    /**
     * Paints one detection onto {@code canvas}. Mask pixels above {@link #MASK_THRESHOLD} are set
     * to black, then the box is outlined in yellow.
     *
     * @param canvas        3-channel 8-bit image to draw on
     * @param canvasIndexer indexer over {@code canvas}'s pixel data
     * @param box           normalized {@code [ymin, xmin, ymax, xmax]}, scaled here by the canvas size
     * @param mask          row-major {@code maskRows x maskCols} values covering the box
     */
    private static void drawDetection(Mat canvas, UByteRawIndexer canvasIndexer, List<Float> box,
                                      List<Float> mask, int maskRows, int maskCols) {
        int width = canvas.cols();
        int height = canvas.rows();
        int top = Math.round(box.get(0) * height);
        int left = Math.round(box.get(1) * width);
        int bottom = Math.round(box.get(2) * height);
        int right = Math.round(box.get(3) * width);
        int boxWidth = right - left;
        int boxHeight = bottom - top;

        // resize() cannot produce an empty image, so degenerate boxes only get their outline.
        if (boxWidth > 0 && boxHeight > 0) {
            Mat small = new Mat(maskRows, maskCols, opencv_core.CV_32F);
            FloatRawIndexer smallIndexer = small.createIndexer();
            for (int r = 0; r < maskRows; r++) {
                for (int c = 0; c < maskCols; c++) {
                    smallIndexer.put(r, c, mask.get(r * maskCols + c));
                }
            }
            // Stretch the low-resolution mask over the whole box.
            Mat scaled = new Mat(boxHeight, boxWidth, opencv_core.CV_32F);
            opencv_imgproc.resize(small, scaled, new Size(boxWidth, boxHeight));
            FloatRawIndexer scaledIndexer = scaled.createIndexer();
            for (int r = 0; r < boxHeight; r++) {
                int y = top + r;
                if (y < 0 || y >= height) {
                    continue; // box row lies outside the image
                }
                for (int c = 0; c < boxWidth; c++) {
                    int x = left + c;
                    if (x >= 0 && x < width && scaledIndexer.get(r, c) > MASK_THRESHOLD) {
                        canvasIndexer.put(y, x, 0, 0);
                        canvasIndexer.put(y, x, 1, 0);
                        canvasIndexer.put(y, x, 2, 0);
                    }
                }
            }
        }
        Rect rect = new Rect(new Point(left, top), new Point(right, bottom));
        opencv_imgproc.rectangle(canvas, rect, AbstractScalar.YELLOW);
    }
}
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值