java-yolov3.weights-tfserving

将 Darknet 训练得到的 yolov3.weights 转换为 yolov3.pb,再导出为 SavedModel,使用 TensorFlow Serving 部署,并通过 Java 客户端(gRPC)调用。

1. 将 yolov3.weights 转换为 .pb 模型的 GitHub 代码:

https://github.com/mystic123/tensorflow-yolo-v3

2. 将 .pb 模型导出为 SavedModel(saved_model.pb)的代码:

from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import tag_constants
import tensorflow as tf
import sys
sys.path.append("./")

# Converts the frozen GraphDef from step 1 into a SavedModel that TensorFlow Serving can load.
# export_dir is <model name>/<version>; the docker mount in run.sh points at the "model-hot" directory,
# and "1" is presumably the numeric version subdirectory TF Serving picks up.
export_dir = 'model-hot/1'
graph_pb = 'hot-yolov3-blob.pb'

# Tensor names inside the frozen graph. NOTE(review): these are the names produced by the
# tensorflow-yolo-v3 converter referenced above; confirm them if the .pb was generated differently.
input_tensor_name = "inputs:0"
output_tensor_name = "detector/yolo-v3/detections:0"

builder = tf.saved_model.builder.SavedModelBuilder(export_dir)

# Read the serialized frozen graph.
with tf.gfile.GFile(graph_pb, "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

sigs = {}

with tf.Session(graph=tf.Graph()) as sess:
    # name="" is important: it imports the nodes without the default "import/" prefix,
    # so the tensor names above resolve exactly as written.
    tf.import_graph_def(graph_def, name="")
    g = sess.graph

    image_tensor = g.get_tensor_by_name(input_tensor_name)
    detections = g.get_tensor_by_name(output_tensor_name)
    print("input:", image_tensor)
    print("output:", detections)

    # DEFAULT_SERVING_SIGNATURE_DEF_KEY is "serving_default". The "in" / "bboxs" keys are the
    # input and output names clients use in their Predict requests, so keep them in sync with clients.
    sigs[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = \
        tf.saved_model.signature_def_utils.predict_signature_def(
            inputs={"in": image_tensor},
            outputs={"bboxs": detections})

    # Records the graph and signature under the SERVING tag. NOTE(review): a frozen graph presumably
    # holds its weights as constants, so there should be no variables to save here — confirm.
    builder.add_meta_graph_and_variables(sess,
                                         [tag_constants.SERVING],
                                         signature_def_map=sigs)
builder.save()

3. TensorFlow Serving 部署脚本 run.sh(容器内 gRPC 端口 8500 映射到宿主机 8600,REST 端口 8501 映射到宿主机 8601):

# Serve the SavedModel directory with the official tensorflow/serving image.
# Host 8601 -> container REST port 8501; host 8600 -> container gRPC port 8500.
# The host directory is mounted at /models/<model name>, and MODEL_NAME tells the server which model to load.
model_name=model-hot
host_model_dir=/home/ygwl/tensorflow-serving/model-hot
docker run \
  -p 8601:8501 \
  -p 8600:8500 \
  -v "${host_model_dir}:/models/${model_name}" \
  -e "MODEL_NAME=${model_name}" \
  -t tensorflow/serving

4. 在 Java Maven 工程中调用
a. pom.xml 依赖配置

<!-- Build settings: Java 8 source/target level, UTF-8 source files. -->
<properties>
    <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
    <maven.compiler.source>1.8</maven.compiler.source>
    <maven.compiler.target>1.8</maven.compiler.target>
  </properties>

  <!-- Extra repository, presumably for the javaxt-core artifact below.
       NOTE(review): Maven 3.8.1+ blocks plain-http repositories by default
       (maven-default-http-blocker); use an https URL or a mirror if resolution fails. -->
  <repositories>
    <repository>
      <id>javaxt.com</id>
      <url>http://www.javaxt.com/maven</url>
    </repository>
  </repositories>

  <dependencies>
    <!-- Presumably provides the tensorflow.serving.* gRPC stubs (PredictionServiceGrpc, Predict, Model)
         imported by the client code. -->
    <dependency>
      <groupId>com.yesup.oss</groupId>
      <artifactId>tensorflow-client</artifactId>
      <version>1.4-2</version>
    </dependency>
    <!-- Image processing library (Thumbnails), used to resize the input image to 416x416. -->
    <dependency>
      <groupId>net.coobird</groupId>
      <artifactId>thumbnailator</artifactId>
      <version>0.4.8</version>
    </dependency>
    <!-- gRPC transport behind ManagedChannelBuilder. -->
    <dependency>
      <groupId>io.grpc</groupId>
      <artifactId>grpc-netty</artifactId>
      <version>1.7.0</version>
    </dependency>
    <!-- Native TLS support for Netty. NOTE(review): the client connects with usePlaintext(true),
         so this may be unnecessary — confirm before removing. -->
    <dependency>
      <groupId>io.netty</groupId>
      <artifactId>netty-tcnative-boringssl-static</artifactId>
      <version>2.0.7.Final</version>
    </dependency>
    <!-- NOTE(review): no javaxt class is imported by the client code; presumably unused. -->
    <dependency>
      <groupId>javaxt</groupId>
      <artifactId>javaxt-core</artifactId>
      <version>1.8.0</version>
    </dependency>
    <!-- NOTE(review): no org.bytedeco class is imported by the client code; presumably unused.
         The -platform artifact presumably pulls native binaries for every platform. -->
    <dependency>
      <groupId>org.bytedeco</groupId>
      <artifactId>javacv-platform</artifactId>
      <version>1.5.2</version>
    </dependency>
    <!-- NOTE(review): the client code imports only org.tensorflow.framework.* protos, not the
         TensorFlow Java runtime (org.tensorflow.Session etc.); presumably unused. -->
    <dependency>
      <groupId>org.tensorflow</groupId>
      <artifactId>tensorflow</artifactId>
      <version>1.13.1</version>
    </dependency>
    <!-- Presumably supplies the org.tensorflow.framework.* protobuf classes
         (TensorProto, TensorShapeProto, DataType). -->
    <dependency>
      <groupId>org.tensorflow</groupId>
      <artifactId>proto</artifactId>
      <version>1.13.1</version>
    </dependency>
  </dependencies>

b. 调用代码

package org.example;

import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;

import java.awt.image.BufferedImage;
import java.awt.image.Raster;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.TimeUnit;

import net.coobird.thumbnailator.Thumbnails;

import org.tensorflow.framework.DataType;
import org.tensorflow.framework.TensorProto;
import org.tensorflow.framework.TensorShapeProto;

import tensorflow.serving.Model;
import tensorflow.serving.Predict;
import tensorflow.serving.PredictionServiceGrpc;

import javax.imageio.ImageIO;


/**
 * Minimal gRPC client: sends one image to the "model-hot" model on TensorFlow Serving and prints the
 * raw detection output followed by the rows whose score reaches {@link #SCORE_THRESHOLD}.
 */
public class Demo {

    /** Width and height (pixels) the image is resized to; must match the served graph's input. */
    private static final int INPUT_SIZE = 416;

    /** Samples per pixel; the request tensor is shaped [1, INPUT_SIZE, INPUT_SIZE, CHANNELS]. */
    private static final int CHANNELS = 3;

    /**
     * Floats per detection row in the flattened "bboxs" output. NOTE(review): presumably
     * 4 box values + 1 objectness score + 3 class scores for a 3-class model — confirm against the
     * exported graph's output shape.
     */
    private static final int VALUES_PER_BOX = 8;

    /** Position within a detection row of the score compared against {@link #SCORE_THRESHOLD}. */
    private static final int SCORE_INDEX = 4;

    /** Detection rows whose score is below this value are dropped. */
    private static final double SCORE_THRESHOLD = 0.9;

    /** Seconds to wait for the gRPC channel to terminate after shutdown. */
    private static final long SHUTDOWN_TIMEOUT_SECONDS = 5;

    /**
     * Runs a single prediction and prints the results and the elapsed RPC time.
     *
     * @param args optional; {@code args[0]} overrides the default image path
     * @throws IllegalArgumentException if the resized image does not have exactly {@link #CHANNELS} bands
     * @throws Exception if the image cannot be read or the RPC fails
     */
    public static void main(String[] args) throws Exception {
        String modelName = "model-hot";
        String signatureName = "serving_default";
        String filename = args.length > 0 ? args[0] : "F:\\tfm\\src\\main\\21.jpg";

        BufferedImage im = Thumbnails.of(filename).forceSize(INPUT_SIZE, INPUT_SIZE).outputFormat("bmp").asBufferedImage();
        Raster raster = im.getRaster();
        // The tensor shape declared below assumes exactly CHANNELS bands; a grayscale or alpha image
        // would otherwise send a payload whose length does not match that shape.
        if (raster.getNumBands() != CHANNELS) {
            throw new IllegalArgumentException("Expected " + CHANNELS + " colour bands but image has "
                    + raster.getNumBands() + ": " + filename);
        }
        // Raw 0-255 sample values are sent. NOTE(review): assumes the served graph normalises
        // its input itself — confirm.
        float[] pixels = raster.getPixels(0, 0, raster.getWidth(), raster.getHeight(),
                new float[raster.getWidth() * raster.getHeight() * raster.getNumBands()]);
        List<Float> floatList = new ArrayList<>(pixels.length);
        for (float pixel : pixels) {
            floatList.add(pixel);
        }

        long t = System.currentTimeMillis();
        // usePlaintext(true): connect without TLS.
        ManagedChannel channel = ManagedChannelBuilder.forAddress("172.20.112.102", 8600).usePlaintext(true).build();
        try {
            // Blocking stub: predict() waits for the server's response.
            PredictionServiceGrpc.PredictionServiceBlockingStub stub = PredictionServiceGrpc.newBlockingStub(channel);

            Predict.PredictRequest.Builder predictRequestBuilder = Predict.PredictRequest.newBuilder();
            // Select model and signature. Without a version the latest version is used; to pin one,
            // also set the ModelSpec version field (setVersion).
            Model.ModelSpec.Builder modelSpecBuilder = Model.ModelSpec.newBuilder();
            modelSpecBuilder.setName(modelName);
            modelSpecBuilder.setSignatureName(signatureName);
            predictRequestBuilder.setModelSpec(modelSpecBuilder);

            // Input tensor: DT_FLOAT, shape [1, 416, 416, 3] -> 1 * 416 * 416 * 3 = 519168 values.
            TensorShapeProto.Builder tensorShapeBuilder = TensorShapeProto.newBuilder();
            tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(1));
            tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(INPUT_SIZE));
            tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(INPUT_SIZE));
            tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(CHANNELS));
            TensorProto inputTensor = TensorProto.newBuilder()
                    .setDtype(DataType.DT_FLOAT)
                    .setTensorShape(tensorShapeBuilder.build())
                    .addAllFloatVal(floatList)
                    .build();
            // "in" must match the input key of the model's serving signature.
            predictRequestBuilder.putInputs("in", inputTensor);

            Predict.PredictResponse predictResponse = stub.predict(predictRequestBuilder.build());
            // "bboxs" must match the output key of the model's serving signature.
            List<Float> boxes = predictResponse.getOutputsOrThrow("bboxs").getFloatValList();
            System.out.println(boxes);

            List<List<Float>> bb = new ArrayList<>();
            for (List<Float> box : getSplitList(VALUES_PER_BOX, boxes)) {
                // A trailing partial row (output length not a multiple of VALUES_PER_BOX) has no score:
                // skip it instead of throwing IndexOutOfBoundsException. NaN scores are also dropped.
                if (box.size() > SCORE_INDEX && box.get(SCORE_INDEX) >= SCORE_THRESHOLD) {
                    bb.add(box);
                }
            }
            System.out.println("=====================================");
            System.out.println(bb);
            System.out.println("cost time: " + (System.currentTimeMillis() - t));
        } finally {
            // Release the channel's connections and threads even if the RPC fails.
            channel.shutdown();
            channel.awaitTermination(SHUTDOWN_TIMEOUT_SECONDS, TimeUnit.SECONDS);
        }
    }

    /**
     * Splits {@code list} into consecutive chunks of {@code splitNum} elements; when the size is not a
     * multiple of {@code splitNum}, the last chunk holds the remainder.
     *
     * <p>Chunks are {@link List#subList} views backed by {@code list}, not copies.
     *
     * @param splitNum chunk length; must be positive
     * @param list     list to split; an empty list yields an empty result
     * @return the chunks, in order
     * @throws IllegalArgumentException if {@code splitNum <= 0}
     */
    private static List<List<Float>> getSplitList(int splitNum, List<Float> list) {
        if (splitNum <= 0) {
            throw new IllegalArgumentException("splitNum must be positive: " + splitNum);
        }
        int size = list.size();
        List<List<Float>> splitList = new ArrayList<>(size / splitNum + 1);
        int from = 0;
        while (from < size) {
            // Computed as a difference so from + splitNum cannot overflow.
            int to = from + Math.min(splitNum, size - from);
            splitList.add(list.subList(from, to));
            from = to;
        }
        return splitList;
    }

}

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值