多模型部署

1.安装docker
win10安装借鉴:https://www.jianshu.com/p/0f927dc8b23d

2.拉取tensorflow-serving镜像

docker pull tensorflow/serving

3.测试tfserving的例子

docker run -p 8701:8501 -p 8700:8500 \
-v C:/tmp/tfserving/serving/tensorflow_serving/servables/tensorflow/testdata/saved_model_half_plus_two_cpu:/models/half_plus_two \
-e MODEL_NAME=half_plus_two -t tensorflow/serving

4.转换自己模型成save_model

from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import tag_constants
import tensorflow as tf
import sys
sys.path.append("./")

export_dir = 'model-best-0619/1'
graph_pb = '0619_best_model.pb'
builder = tf.saved_model.builder.SavedModelBuilder(export_dir)

with tf.gfile.GFile(graph_pb, "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

sigs = {}

with tf.Session(graph=tf.Graph()) as sess:   # name="" is important to ensure we don't get spurious prefixing
    tf.import_graph_def(graph_def, name="")
    g = tf.get_default_graph()

    image_tensor = g.get_tensor_by_name("inputs:0")
    detection_sbbox = g.get_tensor_by_name("detector/yolo-v3/detections:0")
    # detection_mbbox = g.get_tensor_by_name("pred_mbbox/concat_2:0")
    # detection_lbbox = g.get_tensor_by_name("pred_lbbox/concat_2:0")
    print(type(detection_sbbox))
    # out = {'sbbox':detection_sbbox,'mbbox':detection_mbbox,'lbbox':detection_lbbox}
    # out = {''}
    # print(type(out))
    sigs[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = \
        tf.saved_model.signature_def_utils.predict_signature_def(
            {"in": image_tensor},
            {"bboxs":detection_sbbox})

            # {'sbbox':detection_sbbox,'mbbox':detection_mbbox,'lbbox':detection_lbbox},)
            # {"scores":detection_scores},
            # {"classes":detection_classes},
            # {"nums":num_detections})

    builder.add_meta_graph_and_variables(sess,
                                         [tag_constants.SERVING],
                                         signature_def_map = sigs)
builder.save()

5.部署
单模型:

docker run -p 8701:8501 -p 8700:8500 \
-v /home/ygwl/tensorflow-serving/model-fish:/models/model-fish \
-e MODEL_NAME=model-fish -t tensorflow/serving

多模型:

docker run -p 8901:8501 -p 8900:8500 -v G:/Gits/multimodels/:/models -e MODEL_NAME=models -t tensorflow/serving --model_config_file=/models/model.config

6.模型结构查看
http:///localhost:8701/v1/models/model-fish/metadata

7.REST和GRPC API客户端调用

import requests
import numpy as np
import cv2
import sys, os, shutil
import json

sys.path.append("./")
from PIL import Image

# input_img = r"E:\data\starin-diode\diode-0605\images\20200605_5.jpg"
# input_img = "diode.jpg"
input_imgs = r"E:\data\diode-opt\imgs/"
files = os.listdir(input_imgs)
save_path = r"out_imgs"
if os.path.exists(save_path):
    shutil.rmtree(save_path)
os.makedirs(save_path)
for file in files:
    # input_img = r"E:\data\diode-opt\imgs\20200611_84.jpg"
    input_img = input_imgs + file
    img = Image.open(input_img)
    print(np.array(img))
    img1 = img.resize((416, 416))
    image_np = np.array(img1)
    print(image_np)
    # print(image_np[np.newaxis,:].shape)
    img_data = image_np[np.newaxis, :].tolist()
    data = {"instances": img_data}

    # http://172.20.112.102:8701/v1/models/model-fish/metadata
    preds1 = requests.post("http://172.20.112.102:9201/v1/models/model-best-0619:predict", json=data)
    predictions1 = json.loads(preds1.content.decode('utf-8'))["predictions"][0]

    preds = requests.post("http://172.20.112.102:9101/v1/models/model-diode:predict", json=data)
    predictions = json.loads(preds.content.decode('utf-8'))["predictions"][0]
    print(predictions)
    # print(np.array(predictions)[:,4].max())
    pred = np.array(predictions)
    # print(pred.shape)
    # exit()
    a = pred[:, 4] > 0.25
    print(pred[a])
    print(len(pred[a]))
    # exit()
    im = cv2.imread(input_img)
    # print(im.shape) # hwc
    h_s = im.shape[0] / 416
    w_s = im.shape[1] / 416

    box = []
    for i in range(len(pred[a])):
        # print(pred[a][i])
        x1 = pred[a][i][0]
        y1 = pred[a][i][1]
        x2 = pred[a][i][2]
        y2 = pred[a][i][3]
        xx1 = (x1 - x2 / 2)
        yy1 = (y1 - y2 / 2)
        xx2 = (x1 + x2 / 2)
        yy2 = (y1 + y2 / 2)
        box.append([xx1, yy1, xx2, yy2, pred[a][i][4], pred[a][i][5:]])

        cv2.rectangle(im, (int(xx1 * w_s), int(yy1 * h_s)), (int(xx2 * w_s), int(yy2 * h_s)), (255, 0, 0))
    # print(box)
    # cv2.imshow('ss1', im)
    cv2.imwrite(save_path + '/' + file, im)
    # cv2.waitKey(0)

package org.example;

import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;

import java.awt.image.BufferedImage;
import java.awt.image.Raster;
import java.util.ArrayList;
import java.util.List;

import net.coobird.thumbnailator.Thumbnails;

import org.tensorflow.framework.DataType;
import org.tensorflow.framework.TensorProto;
import org.tensorflow.framework.TensorShapeProto;

import tensorflow.serving.Model;
import tensorflow.serving.Predict;
import tensorflow.serving.PredictionServiceGrpc;

import javax.imageio.ImageIO;


public class Demo {
    public static void main(String[] args) throws Exception {
        String modelName = "model-hot";
        String signatureName = "serving_default";
        String filename = "F:\\tfm\\src\\main\\21.jpg";

        BufferedImage im = Thumbnails.of(filename).forceSize(416, 416).outputFormat("bmp").asBufferedImage();
        Raster raster = im.getRaster();
        List<Float> floatList = new ArrayList<>();
        float[] tmp = new float[raster.getWidth() * raster.getHeight() * raster.getNumBands()];
        float[] pixels = raster.getPixels(0, 0, raster.getWidth(), raster.getHeight(), tmp);
        for (float pixel : pixels) {
            floatList.add(pixel);
        }

        long t = System.currentTimeMillis();
        //创建连接,注意usePlaintext设置为true表示用非SSL连接
        ManagedChannel channel = ManagedChannelBuilder.forAddress("172.20.112.102", 8600).usePlaintext(true).build();
        //这里还是先用block模式
        PredictionServiceGrpc.PredictionServiceBlockingStub stub = PredictionServiceGrpc.newBlockingStub(channel);
        //创建请求
        Predict.PredictRequest.Builder predictRequestBuilder = Predict.PredictRequest.newBuilder();
        //模型名称和模型方法名预设
        Model.ModelSpec.Builder modelSpecBuilder = Model.ModelSpec.newBuilder();
        modelSpecBuilder.setName(modelName);
        modelSpecBuilder.setSignatureName(signatureName);
        predictRequestBuilder.setModelSpec(modelSpecBuilder);
        //设置入参,访问默认是最新版本,如果需要特定版本可以使用tensorProtoBuilder.setVersionNumber方法
        TensorProto.Builder tensorProtoBuilder = TensorProto.newBuilder();
        tensorProtoBuilder.setDtype(DataType.DT_FLOAT);
        TensorShapeProto.Builder tensorShapeBuilder = TensorShapeProto.newBuilder();
        tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(1));
        //150528 = 224 * 224 * 3
        tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(416));
        tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(416));
        tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(3));

        tensorProtoBuilder.setTensorShape(tensorShapeBuilder.build());
        tensorProtoBuilder.addAllFloatVal(floatList);
        predictRequestBuilder.putInputs("in", tensorProtoBuilder.build());
        //访问并获取结果
        Predict.PredictResponse predictResponse = stub.predict(predictRequestBuilder.build());
        List<Float> boxes = predictResponse.getOutputsOrThrow("bboxs").getFloatValList();
        System.out.println(boxes);

        List<List<Float>> bbox = getSplitList(8, boxes);

        List<List<Float>> bb = new ArrayList<>();
        for (int i = 0; i < bbox.size(); i++) {
//                System.out.println(bbox.get(i));
//                System.out.println(bbox.get(i).size());
//                System.out.println(bbox.get(i).get(4));
            if (bbox.get(i).get(4) < 0.9) {
                continue;
            }
            bb.add(bbox.get(i));
        }
        System.out.println("=====================================");
        System.out.println(bb);
        System.exit(1);
        System.out.println("cost time: " + (System.currentTimeMillis() - t));
    }

    private static List<List<Float>> getSplitList(int splitNum, List<Float> list) {
        List<List<Float>> splitList = new ArrayList<>();
        int groupFlag = list.size() % splitNum == 0 ? (list.size() / splitNum) : (list.size() / splitNum + 1);
        for (int j = 0; j < groupFlag; j++) {
            if ((j * splitNum + splitNum) <= list.size()) {
                splitList.add(list.subList(j * splitNum, j * splitNum + splitNum));
            } else if ((j * splitNum + splitNum) > list.size()) {
                splitList.add(list.subList(j * splitNum, list.size()));
            } else if (list.size() < splitNum) {
                splitList.add(list.subList(0, list.size()));
            }
        }
        return splitList;
    }

}

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值