将 Darknet 训练得到的 yolov3.weights 转换为 yolov3.pb,并用 TensorFlow Serving(tfserving)部署。
1. 将 weights 转换为 pb 模型的 GitHub 代码:
https://github.com/mystic123/tensorflow-yolo-v3
2. 将冻结的 pb 模型转换为 SavedModel(saved_model.pb)的代码:
"""Wrap a frozen YOLOv3 GraphDef (.pb) in a TensorFlow Serving SavedModel.

Reads the frozen graph ``hot-yolov3-blob.pb`` and exposes two tensors under the
default serving signature:

* ``inputs:0`` as input ``"in"``
* ``detector/yolo-v3/detections:0`` as output ``"bboxs"``

The SavedModel is written to ``model-hot/1``; the numeric ``1`` directory is
the model version that the serving container is pointed at.
"""
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import tag_constants
import tensorflow as tf
import sys

sys.path.append("./")

export_dir = 'model-hot/1'
graph_pb = 'hot-yolov3-blob.pb'

# SavedModelBuilder refuses to write into an existing, non-empty directory:
# delete the previous export (or bump the version number) before re-running.
builder = tf.saved_model.builder.SavedModelBuilder(export_dir)

# Load the serialized, frozen GraphDef from disk.
with tf.gfile.GFile(graph_pb, "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

sigs = {}

with tf.Session(graph=tf.Graph()) as sess:
    # name="" is important to ensure we don't get spurious prefixing
    # (otherwise every node would be renamed "import/...").
    tf.import_graph_def(graph_def, name="")
    g = tf.get_default_graph()
    image_tensor = g.get_tensor_by_name("inputs:0")
    detection_sbbox = g.get_tensor_by_name("detector/yolo-v3/detections:0")
    # Print the tensor itself (name, shape, dtype) rather than just its Python type,
    # so the exported output can be checked against what the client expects.
    print(detection_sbbox)

    # Signature keys "in" / "bboxs" are the names clients use in PredictRequest.inputs
    # and PredictResponse.outputs.
    sigs[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = \
        tf.saved_model.signature_def_utils.predict_signature_def(
            {"in": image_tensor},
            {"bboxs": detection_sbbox})
    builder.add_meta_graph_and_variables(sess,
                                         [tag_constants.SERVING],
                                         signature_def_map=sigs)

builder.save()
3. TensorFlow Serving 部署脚本 run.sh
# Start TensorFlow Serving for the "model-hot" SavedModel.
#   host 8601 -> container 8501 (TF Serving's REST port by default)
#   host 8600 -> container 8500 (gRPC; the Java client connects to 8600)
# The mounted directory holds the exported version sub-directory (model-hot/1);
# MODEL_NAME must equal the model name the client puts in its ModelSpec.
docker run --publish 8601:8501 --publish 8600:8500 \
  --volume /home/ygwl/tensorflow-serving/model-hot:/models/model-hot \
  --env MODEL_NAME=model-hot --tty tensorflow/serving
4. 在 Java Maven 工程中调用
a. pom.xml
<!-- Fragment of the Maven pom.xml for the Java TF Serving client (org.example.Demo). -->
<!-- Build with UTF-8 sources at the Java 8 language level. -->
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<maven.compiler.source>1.8</maven.compiler.source>
<maven.compiler.target>1.8</maven.compiler.target>
</properties>
<!-- Extra repository, presumably needed to resolve javaxt-core - confirm. -->
<!-- NOTE(review): plain http:// URL; Maven 3.8.1+ blocks http repositories by default,
     so switch to https:// if the host supports it, or the build may fail to resolve. -->
<repositories>
<repository>
<id>javaxt.com</id>
<url>http://www.javaxt.com/maven</url>
</repository>
</repositories>
<dependencies>
<!-- Presumably supplies the tensorflow.serving gRPC stubs (PredictionServiceGrpc, Predict, Model)
     imported by Demo - confirm. -->
<dependency>
<groupId>com.yesup.oss</groupId>
<artifactId>tensorflow-client</artifactId>
<version>1.4-2</version>
</dependency>
<!-- Image-processing library (Demo uses Thumbnails to resize the input image). -->
<dependency>
<groupId>net.coobird</groupId>
<artifactId>thumbnailator</artifactId>
<version>0.4.8</version>
</dependency>
<!-- gRPC Netty transport used by ManagedChannelBuilder in Demo. -->
<dependency>
<groupId>io.grpc</groupId>
<artifactId>grpc-netty</artifactId>
<version>1.7.0</version>
</dependency>
<!-- Native BoringSSL bindings for TLS. NOTE(review): Demo opens a plaintext channel,
     so this may be unnecessary - verify before removing. -->
<dependency>
<groupId>io.netty</groupId>
<artifactId>netty-tcnative-boringssl-static</artifactId>
<version>2.0.7.Final</version>
</dependency>
<!-- NOTE(review): javaxt-core is not referenced by Demo; verify it is still needed. -->
<dependency>
<groupId>javaxt</groupId>
<artifactId>javaxt-core</artifactId>
<version>1.8.0</version>
</dependency>
<!-- NOTE(review): javacv-platform is not referenced by Demo, and the -platform artifact
     presumably pulls native binaries for many platforms; verify it is still needed. -->
<dependency>
<groupId>org.bytedeco</groupId>
<artifactId>javacv-platform</artifactId>
<version>1.5.2</version>
</dependency>
<!-- NOTE(review): the TensorFlow Java runtime is not referenced by Demo (inference runs
     inside TF Serving); verify it is still needed. -->
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow</artifactId>
<version>1.13.1</version>
</dependency>
<!-- Presumably supplies org.tensorflow.framework (DataType, TensorProto, TensorShapeProto) used by Demo.
     NOTE(review): check for duplicate classes against tensorflow-client. -->
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>proto</artifactId>
<version>1.13.1</version>
</dependency>
</dependencies>
b. 调用代码
package org.example;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import java.awt.image.BufferedImage;
import java.awt.image.Raster;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.TimeUnit;
import javax.imageio.ImageIO;
import net.coobird.thumbnailator.Thumbnails;
import org.tensorflow.framework.DataType;
import org.tensorflow.framework.TensorProto;
import org.tensorflow.framework.TensorShapeProto;
import tensorflow.serving.Model;
import tensorflow.serving.Predict;
import tensorflow.serving.PredictionServiceGrpc;
/**
 * Command-line client that sends one image to the "model-hot" TensorFlow Serving model over gRPC
 * and prints the detection rows whose score (row index {@link #SCORE_INDEX}) is at least
 * {@link #CONFIDENCE_THRESHOLD}.
 *
 * <p>Usage: {@code Demo [imagePath] [host] [grpcPort]}. Any omitted argument falls back to the
 * default hard-coded below.
 */
public class Demo {

    /** Side length in pixels the input image is force-resized to (the request shape is 416x416). */
    private static final int INPUT_SIZE = 416;

    /** Number of colour bands the request tensor carries (last dimension of the input shape). */
    private static final int CHANNELS = 3;

    /** Values per detection row, used only when the response tensor carries no usable shape. */
    private static final int DEFAULT_ROW_WIDTH = 8;

    /** Index within a detection row of the value that is filtered on (presumably objectness - confirm). */
    private static final int SCORE_INDEX = 4;

    /** Rows whose score is below this are discarded; a double so the comparison matches {@code < 0.9}. */
    private static final double CONFIDENCE_THRESHOLD = 0.9;

    public static void main(String[] args) throws Exception {
        String modelName = "model-hot";
        String signatureName = "serving_default";
        // Optional positional overrides; with no arguments the built-in defaults are used.
        String filename = args.length > 0 ? args[0] : "F:\\tfm\\src\\main\\21.jpg";
        String host = args.length > 1 ? args[1] : "172.20.112.102";
        int port = args.length > 2 ? Integer.parseInt(args[2]) : 8600;

        BufferedImage im = Thumbnails.of(filename)
                .forceSize(INPUT_SIZE, INPUT_SIZE)
                .outputFormat("bmp")
                .asBufferedImage();
        TensorProto input = toInputTensor(im);

        long t = System.currentTimeMillis();
        // usePlaintext(true): plain (non-TLS) connection to TF Serving's gRPC endpoint.
        ManagedChannel channel = ManagedChannelBuilder.forAddress(host, port).usePlaintext(true).build();
        try {
            // A blocking stub is enough for a single synchronous request.
            PredictionServiceGrpc.PredictionServiceBlockingStub stub = PredictionServiceGrpc.newBlockingStub(channel);

            // No version is set, so the server's default (latest) version is used. To pin a version,
            // set it on the ModelSpec - not on the input TensorProto.
            Model.ModelSpec.Builder modelSpecBuilder = Model.ModelSpec.newBuilder()
                    .setName(modelName)
                    .setSignatureName(signatureName);
            Predict.PredictRequest request = Predict.PredictRequest.newBuilder()
                    .setModelSpec(modelSpecBuilder)
                    .putInputs("in", input)
                    .build();

            Predict.PredictResponse predictResponse = stub.predict(request);
            TensorProto output = predictResponse.getOutputsOrThrow("bboxs");
            // Values are read from float_val; presumably the server fills that field rather than
            // tensor_content - confirm if no rows come back.
            List<List<Float>> rows = getSplitList(rowWidth(output), output.getFloatValList());
            System.out.println("detection rows: " + rows.size());

            List<List<Float>> bb = filterByConfidence(rows);
            System.out.println("=====================================");
            System.out.println(bb);
            System.out.println("cost time: " + (System.currentTimeMillis() - t));
        } finally {
            // Release the channel explicitly so the JVM exits normally (status 0) after printing
            // the results, instead of being forced down with a failure status.
            channel.shutdown().awaitTermination(5, TimeUnit.SECONDS);
        }
    }

    /**
     * Packs {@code image} into a DT_FLOAT tensor of shape [1, height, width, bands].
     *
     * <p>Raw sample values (0-255 for 8-bit images) are sent unnormalised; presumably the served
     * graph scales them itself - confirm.
     *
     * @param image the already-resized image
     * @return the request tensor
     * @throws IllegalArgumentException if the image does not have exactly {@link #CHANNELS} bands
     *     (e.g. it has an alpha channel or is greyscale), which the model input cannot accept
     */
    private static TensorProto toInputTensor(BufferedImage image) {
        Raster raster = image.getRaster();
        int width = raster.getWidth();
        int height = raster.getHeight();
        int bands = raster.getNumBands();
        if (bands != CHANNELS) {
            throw new IllegalArgumentException(
                    "expected a " + CHANNELS + "-band image, got " + bands + " bands");
        }
        // getPixels returns every sample of the rectangle pixel by pixel (row-major), with each
        // pixel's bands adjacent - i.e. the NHWC order the shape below declares.
        float[] pixels = raster.getPixels(0, 0, width, height, new float[width * height * bands]);

        TensorShapeProto.Builder shape = TensorShapeProto.newBuilder();
        shape.addDim(TensorShapeProto.Dim.newBuilder().setSize(1));
        shape.addDim(TensorShapeProto.Dim.newBuilder().setSize(height));
        shape.addDim(TensorShapeProto.Dim.newBuilder().setSize(width));
        shape.addDim(TensorShapeProto.Dim.newBuilder().setSize(bands));

        TensorProto.Builder tensor = TensorProto.newBuilder();
        tensor.setDtype(DataType.DT_FLOAT);
        tensor.setTensorShape(shape.build());
        for (float pixel : pixels) {
            tensor.addFloatVal(pixel);
        }
        return tensor.build();
    }

    /**
     * Returns the size of the last dimension of {@code output}, i.e. the number of values per
     * detection row, falling back to {@link #DEFAULT_ROW_WIDTH} when the shape is missing or invalid.
     */
    private static int rowWidth(TensorProto output) {
        TensorShapeProto shape = output.getTensorShape();
        int rank = shape.getDimCount();
        if (rank == 0) {
            return DEFAULT_ROW_WIDTH;
        }
        long last = shape.getDim(rank - 1).getSize();
        return (last > 0 && last <= Integer.MAX_VALUE) ? (int) last : DEFAULT_ROW_WIDTH;
    }

    /**
     * Keeps the rows whose value at {@link #SCORE_INDEX} is at least {@link #CONFIDENCE_THRESHOLD}.
     * Rows too short to have that index (e.g. a truncated trailing chunk) and NaN scores are dropped.
     */
    private static List<List<Float>> filterByConfidence(List<List<Float>> rows) {
        List<List<Float>> kept = new ArrayList<>();
        for (List<Float> row : rows) {
            if (row.size() > SCORE_INDEX && row.get(SCORE_INDEX) >= CONFIDENCE_THRESHOLD) {
                kept.add(row);
            }
        }
        return kept;
    }

    /**
     * Splits {@code list} into consecutive chunks of {@code splitNum} elements. The last chunk holds
     * the remainder when the size is not a multiple of {@code splitNum}.
     *
     * @param splitNum chunk size; must be positive
     * @param list the values to split; chunks are {@link List#subList} views backed by it
     * @return the chunks in order (empty when {@code list} is empty)
     * @throws IllegalArgumentException if {@code splitNum <= 0}
     */
    private static List<List<Float>> getSplitList(int splitNum, List<Float> list) {
        if (splitNum <= 0) {
            throw new IllegalArgumentException("splitNum must be positive: " + splitNum);
        }
        int size = list.size();
        List<List<Float>> splitList = new ArrayList<>();
        for (int from = 0; from < size; ) {
            // Computed as from + min(...) so a very large splitNum cannot overflow the end index.
            int to = from + Math.min(splitNum, size - from);
            splitList.add(list.subList(from, to));
            from = to;
        }
        return splitList;
    }
}