triton部署yolov5笔记(三)

直达链接

Triton部署YOLOV5笔记(一)
Triton部署YOLOV5笔记(二)
triton部署yolov5笔记(三)
triton部署yolov5笔记(四)

拉取镜像

docker pull tienln/tensorrt:8.0.3_opencv 
# docker pull tienln/ubuntu:18.04_conda #根据自己需要是否拉取ubuntu镜像
docker pull nvcr.io/nvidia/tritonserver:21.09-py3

克隆代码

创建自己的工作路径

mkdir myworkspace # 接下来的工作在这个路径下进行

进入路径

cd myworkspace

下载代码

git clone -b v5.0 https://github.com/ultralytics/yolov5.git
git clone -b yolov5-v5.0 https://github.com/wang-xinyu/tensorrtx.git

创建wts文件

首先进入工作路径

cd myworkspace

复制gen_wts.py文件到yolov5的根目录下

cp tensorrtx/yolov5/gen_wts.py yolov5 

进入yolov5根目录

cd yolov5  
# docker run -it --rm --gpus all -v $PWD:/yolov5 tienln/ubuntu:18.04_conda /bin/bash  #这是容器指向当前目录,也就是yolov5这个目录,如果不拉取ubuntu镜像可以直接在你的服务器端进入目录
cd /yolov5  
conda activate yolov5  # 需要提前创建conda虚拟环境,安装目标检测所需要的包
python gen_wts.py -w yolov5s.pt -o yolov5s.wts #执行命令,生成yolov5s.wts文件

创建tensorrt engine文件

进入工作目录

cd myworkspace
cp yolov5/yolov5s.wts tensorrtx/yolov5 # 拷贝wts文件到tensorrt的yolov5路径下
cd tensorrtx/yolov5  # 进入该路径 

进入tensorrt容器进行下列操作
进入容器

docker run -it --rm --gpus all -v $PWD:/yolov5 tienln/tensorrt:8.0.3_opencv /bin/bash   # 进入容器,映射到tensorrtx/yolov5路径

在容器里执行操作

cd /yolov5  # 进入yolov5路径
mkdir build  
cd build   
cmake ..  
make -j16  
./yolov5 -s ../yolov5s.wts ../yolov5s.engine s # 生成engine文件

创建Triton Inference Server

cd myworkspace  # 回到前面创建的工作路径
mkdir -p triton_deploy/models/yolov5/1/  
mkdir triton_deploy/plugins  
cp tensorrtx/yolov5/yolov5s.engine triton_deploy/models/yolov5/1/model.plan  
cp tensorrtx/yolov5/build/libmyplugins.so triton_deploy/plugins/libmyplugins.so 

目录结构如下

.
├── models
│   └── yolov5
│       └── 1
│           └── model.plan
└── plugins
    └── libmyplugins.so

run triton

docker run \
--gpus all \
--rm \
-p9000:8000 -p9001:8001 -p9002:8002 \
-v $(pwd)/triton_deploy/models:/models \
-v $(pwd)/triton_deploy/plugins:/plugins \
--env LD_PRELOAD=/plugins/libmyplugins.so \
nvcr.io/nvidia/tritonserver:21.09-py3 tritonserver \
--model-repository=/models \
--strict-model-config=false \
--log-verbose 1

客户端测试

cd myworkspace   
cd clients/yolov5
docker run -it --rm --gpus all --network host -v $PWD:/client tienln/ubuntu:18.04_conda /bin/bash  
conda activate yolov5  
pip install tritonclient  
cd /client
python client.py -m yolov5 -u 127.0.0.1:9001 -o data/dog_result.jpg image data/dog.jpg  # 模型名与上面创建的目录yolov5一致,gRPC端口为上面映射的9001

client全部代码如下

将所有文件放一个目录下
client.py文件

#!/usr/bin/env python
# client.py
import argparse
import numpy as np
import sys
import cv2
import time

import tritonclient.grpc as grpcclient
from tritonclient.utils import InferenceServerException

from processing import preprocess, postprocess
from render import render_box, render_filled_box, get_text_size, render_text, RAND_COLORS, plot_one_box
from labels import COCOLabels

if __name__ == '__main__':
    # Command-line client for a YOLOv5 TensorRT model served by Triton over
    # gRPC. Three modes: 'dummy' (send an all-ones tensor), 'image' (one
    # image file) and 'video' (frame-by-frame over a video file).

    # ---- Command-line arguments ----
    parser = argparse.ArgumentParser()
    parser.add_argument('mode',
                        choices=['dummy', 'image', 'video'],
                        default='dummy',
                        help='Run mode. \'dummy\' will send an emtpy buffer to the server to test if inference works. \'image\' will process an image. \'video\' will process a video.')
    parser.add_argument('input',
                        type=str,
                        nargs='?',
                        help='Input file to load from in image or video mode')
    parser.add_argument('-m',
                        '--model',
                        type=str,
                        required=False,
                        default='model_0',
                        help='Inference model name, default model_0')
    parser.add_argument('--width',
                        type=int,
                        required=False,
                        default=640,
                        help='Inference model input width, default 640')
    parser.add_argument('--height',
                        type=int,
                        required=False,
                        default=640,
                        help='Inference model input height, default 640')
    parser.add_argument('-u',
                        '--url',
                        type=str,
                        required=False,
                        default='127.0.0.1:8081',
                        help='Inference server URL, default 127.0.0.1:8081')
    parser.add_argument('-o',
                        '--out',
                        type=str,
                        required=False,
                        default='',
                        help='Write output into file instead of displaying it')
    parser.add_argument('-c',
                        '--confidence',
                        type=float,
                        required=False,
                        default=0.5,
                        help='Confidence threshold for detected objects, default 0.5')
    parser.add_argument('-n',
                        '--nms',
                        type=float,
                        required=False,
                        default=0.45,
                        help='Non-maximum suppression threshold for filtering raw boxes, default 0.45')
    parser.add_argument('-f',
                        '--fps',
                        type=float,
                        required=False,
                        default=24.0,
                        help='Video output fps, default 24.0 FPS')
    parser.add_argument('-i',
                        '--model-info',
                        action="store_true",
                        required=False,
                        default=False,
                        help='Print model status, configuration and statistics')
    parser.add_argument('-v',
                        '--verbose',
                        action="store_true",
                        required=False,
                        default=False,
                        help='Enable verbose client output')
    parser.add_argument('-t',
                        '--client-timeout',
                        type=float,
                        required=False,
                        default=None,
                        help='Client timeout in seconds, default no timeout')
    parser.add_argument('-s',
                        '--ssl',
                        action="store_true",
                        required=False,
                        default=False,
                        help='Enable SSL encrypted channel to the server')
    parser.add_argument('-r',
                        '--root-certificates',
                        type=str,
                        required=False,
                        default=None,
                        help='File holding PEM-encoded root certificates, default none')
    parser.add_argument('-p',
                        '--private-key',
                        type=str,
                        required=False,
                        default=None,
                        help='File holding PEM-encoded private key, default is none')
    parser.add_argument('-x',
                        '--certificate-chain',
                        type=str,
                        required=False,
                        default=None,
                        help='File holding PEM-encoded certicate chain default is none')

    FLAGS = parser.parse_args()
    print(FLAGS)

    # preprocess() is called with [width, height] and returns a (3, H, W)
    # array, so the tensor sent to the server is NCHW: height before width.
    input_tensor_shape = [1, 3, FLAGS.height, FLAGS.width]

    # ---- Create the gRPC client ----
    try:
        triton_client = grpcclient.InferenceServerClient(
            url=FLAGS.url,
            verbose=FLAGS.verbose,
            ssl=FLAGS.ssl,
            root_certificates=FLAGS.root_certificates,
            private_key=FLAGS.private_key,
            certificate_chain=FLAGS.certificate_chain)
    except Exception as e:
        print("context creation failed: " + str(e))
        # Exit non-zero so callers can detect the failure.
        sys.exit(1)

    # ---- Health checks ----
    if not triton_client.is_server_live():
        print("FAILED : is_server_live")
        sys.exit(1)

    if not triton_client.is_server_ready():
        print("FAILED : is_server_ready")
        sys.exit(1)

    if not triton_client.is_model_ready(FLAGS.model):
        print("FAILED : is_model_ready")
        sys.exit(1)

    if FLAGS.model_info:
        # Model metadata
        try:
            metadata = triton_client.get_model_metadata(FLAGS.model)
            print(metadata)
        except InferenceServerException as ex:
            print("FAILED : get_model_metadata")
            # The server message is only echoed when it is something other
            # than the plain "unknown model" error.
            if "Request for unknown model" not in ex.message():
                print("Got: {}".format(ex.message()))
            sys.exit(1)

        # Model configuration
        try:
            config = triton_client.get_model_config(FLAGS.model)
            if not (config.config.name == FLAGS.model):
                print("FAILED: get_model_config")
                sys.exit(1)
            print(config)
        except InferenceServerException as ex:
            print("FAILED : get_model_config")
            print("Got: {}".format(ex.message()))
            sys.exit(1)

    # DUMMY MODE
    if FLAGS.mode == 'dummy':
        print("Running in 'dummy' mode")
        print("Creating emtpy buffer filled with ones...")
        inputs = []
        outputs = []
        inputs.append(grpcclient.InferInput('data', input_tensor_shape, "FP32"))
        inputs[0].set_data_from_numpy(np.ones(shape=tuple(input_tensor_shape), dtype=np.float32))
        outputs.append(grpcclient.InferRequestedOutput('prob'))

        print("Invoking inference...")
        results = triton_client.infer(model_name=FLAGS.model,
                                    inputs=inputs,
                                    outputs=outputs,
                                    client_timeout=FLAGS.client_timeout)
        if FLAGS.model_info:
            statistics = triton_client.get_inference_statistics(model_name=FLAGS.model)
            if len(statistics.model_stats) != 1:
                print("FAILED: get_inference_statistics")
                sys.exit(1)
            print(statistics)
        print("Done")

        result = results.as_numpy('prob')
        print(f"Received result buffer of size {result.shape}")
        print(f"Naive buffer sum: {np.sum(result)}")

    # IMAGE MODE
    if FLAGS.mode == 'image':
        print("Running in 'image' mode")
        if not FLAGS.input:
            print("FAILED: no input image")
            sys.exit(1)

        inputs = []
        outputs = []
        inputs.append(grpcclient.InferInput('data', input_tensor_shape, "FP32"))
        outputs.append(grpcclient.InferRequestedOutput('prob'))

        print("Creating buffer from image file...")
        input_image = cv2.imread(str(FLAGS.input))
        if input_image is None:
            print(f"FAILED: could not load input image {str(FLAGS.input)}")
            sys.exit(1)
        input_image_buffer = preprocess(input_image, [FLAGS.width, FLAGS.height])
        input_image_buffer = np.expand_dims(input_image_buffer, axis=0)
        inputs[0].set_data_from_numpy(input_image_buffer)

        print("Invoking inference...")
        t1 = time.time()
        results = triton_client.infer(model_name=FLAGS.model,
                                    inputs=inputs,
                                    outputs=outputs,
                                    client_timeout=FLAGS.client_timeout)
        # Stop the clock right after the request returns so the reported time
        # excludes drawing and the blocking cv2.waitKey(0) below.
        t2 = time.time()
        if FLAGS.model_info:
            statistics = triton_client.get_inference_statistics(model_name=FLAGS.model)
            if len(statistics.model_stats) != 1:
                print("FAILED: get_inference_statistics")
                sys.exit(1)
            print(statistics)
        print("Done")

        result = results.as_numpy('prob')
        print(f"Received result buffer of size {result.shape}")
        print(f"Naive buffer sum: {np.sum(result)}")

        detected_objects = postprocess(result, input_image.shape[1], input_image.shape[0], [FLAGS.width, FLAGS.height], FLAGS.confidence, FLAGS.nms)
        print(f"Detected objects: {len(detected_objects)}")

        for box in detected_objects:
            print(f"{COCOLabels(box.classID).name}: {box.confidence}")
            plot_one_box(box.box(), input_image, color=tuple(RAND_COLORS[box.classID % 64].tolist()), label=f"{COCOLabels(box.classID).name}:{box.confidence:.2f}",)

        if FLAGS.out:
            cv2.imwrite(FLAGS.out, input_image)
            print(f"Saved result to {FLAGS.out}")
        else:
            cv2.imshow('image', input_image)
            cv2.waitKey(0)
            cv2.destroyAllWindows()
        print('inference time is: {}ms'.format(1000 * (t2 - t1)))

    # VIDEO MODE
    if FLAGS.mode == 'video':
        print("Running in 'video' mode")
        if not FLAGS.input:
            print("FAILED: no input video")
            sys.exit(1)

        inputs = []
        outputs = []
        inputs.append(grpcclient.InferInput('data', input_tensor_shape, "FP32"))
        outputs.append(grpcclient.InferRequestedOutput('prob'))

        print("Opening input video stream...")
        cap = cv2.VideoCapture(FLAGS.input)
        if not cap.isOpened():
            print(f"FAILED: cannot open video {FLAGS.input}")
            sys.exit(1)

        counter = 0
        out = None
        print("Invoking inference...")
        while True:
            ret, frame = cap.read()
            if not ret:
                print("failed to fetch next frame")
                break

            # The writer is created lazily because the output frame size is
            # only known once the first frame has been read.
            if counter == 0 and FLAGS.out:
                print("Opening output video stream...")
                fourcc = cv2.VideoWriter_fourcc('M', 'P', '4', 'V')
                out = cv2.VideoWriter(FLAGS.out, fourcc, FLAGS.fps, (frame.shape[1], frame.shape[0]))

            input_image_buffer = preprocess(frame, [FLAGS.width, FLAGS.height])
            input_image_buffer = np.expand_dims(input_image_buffer, axis=0)
            inputs[0].set_data_from_numpy(input_image_buffer)

            results = triton_client.infer(model_name=FLAGS.model,
                                    inputs=inputs,
                                    outputs=outputs,
                                    client_timeout=FLAGS.client_timeout)

            result = results.as_numpy('prob')

            detected_objects = postprocess(result, frame.shape[1], frame.shape[0], [FLAGS.width, FLAGS.height], FLAGS.confidence, FLAGS.nms)
            print(f"Frame {counter}: {len(detected_objects)} objects")
            counter += 1

            for box in detected_objects:
                print(f"{COCOLabels(box.classID).name}:{box.confidence}")
                plot_one_box(box.box(), frame, color=tuple(RAND_COLORS[box.classID % 64].tolist()), label=f"{COCOLabels(box.classID).name}: {box.confidence:.2f}",)

            if FLAGS.out:
                out.write(frame)
            else:
                cv2.imshow('image', frame)
                if cv2.waitKey(1) == ord('q'):
                    break

        if FLAGS.model_info:
            statistics = triton_client.get_inference_statistics(model_name=FLAGS.model)
            if len(statistics.model_stats) != 1:
                print("FAILED: get_inference_statistics")
                sys.exit(1)
            print(statistics)
        print("Done")

        cap.release()
        if FLAGS.out:
            # `out` is still None when the very first read failed.
            if out is not None:
                out.release()
        else:
            cv2.destroyAllWindows()
        print("Done")

boundingbox.py文件

# boundingbox.py
class BoundingBox:
    """One detection: class id, confidence and an axis-aligned box.

    Corners are kept both in absolute form (x1, y1, x2, y2) and in normalised
    form (u1, v1, u2, v2), where u = x / image_width and v = y / image_height.
    Note the constructor takes both x values before both y values.
    """

    def __init__(self, classID, confidence, x1, x2, y1, y2, image_width, image_height):
        self.classID, self.confidence = classID, confidence
        # Absolute corner coordinates.
        self.x1, self.x2 = x1, x2
        self.y1, self.y2 = y1, y2
        # Corner coordinates divided by the image dimensions.
        self.u1, self.u2 = x1 / image_width, x2 / image_width
        self.v1, self.v2 = y1 / image_height, y2 / image_height

    def box(self):
        """Return the absolute corners as (x1, y1, x2, y2)."""
        return (self.x1, self.y1, self.x2, self.y2)

    def width(self):
        """Return the absolute box width, x2 - x1."""
        return self.x2 - self.x1

    def height(self):
        """Return the absolute box height, y2 - y1."""
        return self.y2 - self.y1

    def center_absolute(self):
        """Return the box centre (x, y) in absolute coordinates."""
        mid_x = 0.5 * (self.x1 + self.x2)
        mid_y = 0.5 * (self.y1 + self.y2)
        return (mid_x, mid_y)

    def center_normalized(self):
        """Return the box centre (u, v) in normalised coordinates."""
        mid_u = 0.5 * (self.u1 + self.u2)
        mid_v = 0.5 * (self.v1 + self.v2)
        return (mid_u, mid_v)

    def size_absolute(self):
        """Return (width, height) in absolute coordinates."""
        return (self.x2 - self.x1, self.y2 - self.y1)

    def size_normalized(self):
        """Return (width, height) in normalised coordinates."""
        return (self.u2 - self.u1, self.v2 - self.v1)

labels.py文件

from enum import Enum

# Enum mapping a detector class id to its label name. Built with the
# functional Enum API: a member's value is its position in the tuple below,
# counting from 0 (start=0), so COCOLabels(16).name == 'dog'.
COCOLabels = Enum(
    'COCOLabels',
    (
        # 0 - 9
        'person', 'bicycle', 'car', 'motorcycle', 'airplane',
        'bus', 'train', 'truck', 'boat', 'traffic_light',
        # 10 - 19
        'fire_hydrant', 'stop_sign', 'parking_meter', 'bench', 'bird',
        'cat', 'dog', 'horse', 'sheep', 'cow',
        # 20 - 29
        'elephant', 'bear', 'zebra', 'giraffe', 'backpack',
        'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
        # 30 - 39
        'skis', 'snowboard', 'sports_ball', 'kite', 'baseball_bat',
        'baseball_glove', 'skateboard', 'surfboard', 'tennis_racket', 'bottle',
        # 40 - 49
        'wine_glass', 'cup', 'fork', 'knife', 'spoon',
        'bowl', 'banana', 'apple', 'sandwich', 'orange',
        # 50 - 59
        'broccoli', 'carrot', 'hot_dog', 'pizza', 'donut',
        'cake', 'chair', 'sofa', 'pottedplant', 'bed',
        # 60 - 69
        'diningtable', 'toilet', 'tvmonitor', 'laptop', 'mouse',
        'remote', 'keyboard', 'cell_phone', 'microwave', 'oven',
        # 70 - 79
        'toaster', 'sink', 'refrigerator', 'book', 'clock',
        'vase', 'scissors', 'teddy_bear', 'hair_drier', 'toothbrush',
    ),
    start=0,
)

processing.py文件

# processing.py
from boundingbox import BoundingBox

import cv2
import numpy as np

def preprocess(raw_bgr_image, input_shape):
    """
    description: Preprocess an image before TRT YOLO inferencing.
                 Convert BGR image to RGB, resize it keeping the aspect
                 ratio, pad it to the target size with grey (128, 128, 128),
                 normalize to [0,1] and transform to CHW layout.
    param:
        raw_bgr_image: BGR image array of shape (img_h, img_w, 3)
        input_shape: a pair (W, H) giving the network input width and height
    return:
        image:  the processed image, float32 numpy array of shape (3, H, W)
    """
    # NOTE: the order is (width, height), matching the callers in client.py.
    input_w, input_h = input_shape
    image_raw = raw_bgr_image
    h, w = image_raw.shape[:2]
    image = cv2.cvtColor(image_raw, cv2.COLOR_BGR2RGB)
    # Calculate width, height and paddings
    r_w = input_w / w
    r_h = input_h / h
    if r_h > r_w:
        # Width is the limiting side: fill it and pad top/bottom.
        tw = input_w
        th = int(r_w * h)
        tx1 = tx2 = 0
        ty1 = int((input_h - th) / 2)
        ty2 = input_h - th - ty1
    else:
        # Height is the limiting side: fill it and pad left/right.
        tw = int(r_h * w)
        th = input_h
        tx1 = int((input_w - tw) / 2)
        tx2 = input_w - tw - tx1
        ty1 = ty2 = 0
    # Resize the image with long side while maintaining ratio
    image = cv2.resize(image, (tw, th))
    # Pad the short side with (128,128,128). The colour must be passed as the
    # `value` keyword: the positional slot after borderType is `dst`, so a
    # positional tuple there is not used as the border colour.
    image = cv2.copyMakeBorder(
        image, ty1, ty2, tx1, tx2, cv2.BORDER_CONSTANT, value=(128, 128, 128)
    )
    image = image.astype(np.float32)
    # Normalize to [0,1]
    image /= 255.0
    # HWC to CHW format:
    image = np.transpose(image, [2, 0, 1])
    return image


def xywh2xyxy(x, origin_h, origin_w, input_w, input_h):
    """
    description:    Convert boxes from centre format [cx, cy, w, h], given in
                    the padded network-input frame, to corner format
                    [x1, y1, x2, y2] in the original image frame. The padding
                    on the non-limiting axis is subtracted and the result is
                    divided by the resize ratio.
    param:
        x:          A boxes numpy array, each row is [center_x, center_y, w, h]
        origin_h:   height of original image
        origin_w:   width of original image
        input_w:    width of the network input
        input_h:    height of the network input
    return:
        A boxes numpy array like x, each row is [x1, y1, x2, y2]
    """
    scale_w = input_w / origin_w
    scale_h = input_h / origin_h
    half_w = x[:, 2] / 2
    half_h = x[:, 3] / 2

    corners = np.zeros_like(x)
    if scale_h > scale_w:
        # Width was the limiting side, so padding sits above and below.
        pad_y = (input_h - scale_w * origin_h) / 2
        corners[:, 0] = x[:, 0] - half_w
        corners[:, 2] = x[:, 0] + half_w
        corners[:, 1] = x[:, 1] - half_h - pad_y
        corners[:, 3] = x[:, 1] + half_h - pad_y
        scale = scale_w
    else:
        # Height was the limiting side, so padding sits left and right.
        pad_x = (input_w - scale_h * origin_w) / 2
        corners[:, 0] = x[:, 0] - half_w - pad_x
        corners[:, 2] = x[:, 0] + half_w - pad_x
        corners[:, 1] = x[:, 1] - half_h
        corners[:, 3] = x[:, 1] + half_h
        scale = scale_h

    # Map from network-input pixels back to original-image pixels.
    corners /= scale
    return corners

def bbox_iou(box1, box2, x1y1x2y2=True):
    """
    description: compute the IoU of bounding boxes given as 2-D arrays with
                 one box per row; the usual numpy broadcasting applies between
                 box1 and box2. Extents use the inclusive-pixel convention
                 (x2 - x1 + 1).
    param:
        box1: box array, rows are (x1, y1, x2, y2) or (x, y, w, h)
        box2: box array, rows are (x1, y1, x2, y2) or (x, y, w, h)
        x1y1x2y2: True for corner format, False for centre/size format
    return:
        iou: computed iou
    """
    if x1y1x2y2:
        # Corners are given directly.
        a_x1, a_y1, a_x2, a_y2 = box1[:, 0], box1[:, 1], box1[:, 2], box1[:, 3]
        b_x1, b_y1, b_x2, b_y2 = box2[:, 0], box2[:, 1], box2[:, 2], box2[:, 3]
    else:
        # Derive corners from centre and size.
        a_x1 = box1[:, 0] - box1[:, 2] / 2
        a_x2 = box1[:, 0] + box1[:, 2] / 2
        a_y1 = box1[:, 1] - box1[:, 3] / 2
        a_y2 = box1[:, 1] + box1[:, 3] / 2
        b_x1 = box2[:, 0] - box2[:, 2] / 2
        b_x2 = box2[:, 0] + box2[:, 2] / 2
        b_y1 = box2[:, 1] - box2[:, 3] / 2
        b_y2 = box2[:, 1] + box2[:, 3] / 2

    # Overlap extents, floored at zero for disjoint boxes.
    overlap_w = np.clip(np.minimum(a_x2, b_x2) - np.maximum(a_x1, b_x1) + 1, 0, None)
    overlap_h = np.clip(np.minimum(a_y2, b_y2) - np.maximum(a_y1, b_y1) + 1, 0, None)
    inter_area = overlap_w * overlap_h

    area_a = (a_x2 - a_x1 + 1) * (a_y2 - a_y1 + 1)
    area_b = (b_x2 - b_x1 + 1) * (b_y2 - b_y1 + 1)

    # The small epsilon keeps the denominator away from zero.
    union_area = area_a + area_b - inter_area + 1e-16
    return inter_area / union_area

def non_max_suppression(prediction, origin_h, origin_w, input_w, input_h, conf_thres=0.5, nms_thres=0.4):
    """
    description: Removes detections with lower object confidence score than 'conf_thres' and performs
    per-class Non-Maximum Suppression to further filter detections.
    param:
        prediction: detections, one row per box: (center_x, center_y, w, h, conf, cls_id)
                    with coordinates in the network-input frame
        origin_h: original image height
        origin_w: original image width
        input_w: network input width
        input_h: network input height
        conf_thres: a confidence threshold to filter detections
        nms_thres: a iou threshold to filter detections
    return:
        boxes: rows (x1, y1, x2, y2, conf, cls_id) in original-image coordinates,
               sorted by descending confidence; an empty 1-D array when nothing is kept
    """
    # Keep the boxes that score >= conf_thres. Boolean indexing returns a
    # copy, so the in-place edits below do not touch `prediction`.
    boxes = prediction[prediction[:, 4] >= conf_thres]
    # Transform bbox from [center_x, center_y, w, h] to [x1, y1, x2, y2]
    boxes[:, :4] = xywh2xyxy(boxes[:, :4], origin_h, origin_w, input_w, input_h)
    # Clip the coordinates to the original image
    boxes[:, 0] = np.clip(boxes[:, 0], 0, origin_w - 1)
    boxes[:, 2] = np.clip(boxes[:, 2], 0, origin_w - 1)
    boxes[:, 1] = np.clip(boxes[:, 1], 0, origin_h - 1)
    boxes[:, 3] = np.clip(boxes[:, 3], 0, origin_h - 1)
    # Sort by descending confidence
    boxes = boxes[np.argsort(-boxes[:, 4])]
    # Perform non-maximum suppression
    keep_boxes = []
    while boxes.shape[0]:
        large_overlap = bbox_iou(np.expand_dims(boxes[0, :4], 0), boxes[:, :4]) > nms_thres
        label_match = boxes[0, -1] == boxes[:, -1]
        # Boxes with lower confidence scores, large IOUs and matching labels
        invalid = large_overlap & label_match
        # Always consume the current top box. Relying on its self-IoU to
        # exceed nms_thres loops forever when nms_thres >= 1 or when the box
        # is degenerate / contains NaN.
        invalid[0] = True
        keep_boxes.append(boxes[0])
        boxes = boxes[~invalid]
    return np.stack(keep_boxes, 0) if keep_boxes else np.array([])

def postprocess(output, origin_w, origin_h, input_shape, conf_th=0.5, nms_threshold=0.5, letter_box=False):
    """Postprocess TensorRT outputs.
    # Args
        output: batched detection buffer; only the first batch entry is used.
                Its flattened layout is
                [num_boxes, cx,cy,w,h,conf,cls_id, cx,cy,w,h,conf,cls_id, ...]
        origin_w: width of the original image
        origin_h: height of the original image
        input_shape: network input size as (width, height)
        conf_th: confidence threshold
        nms_threshold: IoU threshold used for non-maximum suppression
        letter_box: unused; kept so existing callers keep working
    # Returns
        list of bounding boxes with all detections above threshold and after nms, see class BoundingBox
    """
    # Only the first row of the batch is used (batch_size = 1). Flatten it so
    # trailing singleton dimensions in the server output do not matter.
    flat = np.asarray(output[0]).ravel()
    # The first value is the number of valid boxes in the buffer.
    num = int(flat[0])
    # Reshape to a two dimensional ndarray, one box per row
    pred = np.reshape(flat[1:], (-1, 6))[:num, :]

    # Do nms
    boxes = non_max_suppression(pred, origin_h, origin_w, input_shape[0], input_shape[1], conf_thres=conf_th, nms_thres=nms_threshold)
    if not len(boxes):
        return []

    result_boxes = boxes[:, :4]
    result_scores = boxes[:, 4]
    # `np.int` was removed in NumPy 1.24; the builtin int is the replacement.
    result_classid = boxes[:, 5].astype(int)

    # BoundingBox takes (x1, x2, y1, y2, image_width, image_height): width
    # first, so that u = x / width and v = y / height.
    return [
        BoundingBox(label, score, box[0], box[2], box[1], box[3], origin_w, origin_h)
        for box, score, label in zip(result_boxes, result_scores, result_classid)
    ]

render.py文件

# render.py
import numpy as np

import cv2

from math import sqrt

# Divisor for automatic line thickness in render_box():
# thickness = round(img_h * img_w / _LINE_THICKNESS_SCALING**2), at least 1.
_LINE_THICKNESS_SCALING = 500.0

# Fixed seed so the colour table below is identical on every run.
# NOTE: this seeds NumPy's global RNG, so it also affects any later
# np.random call in the process (e.g. the fallback colour in plot_one_box).
np.random.seed(69)
# 80 x 3 table of random integers in [10, 255), one colour row per class id.
RAND_COLORS = np.random.randint(10, 255, (80, 3), "int")  # used for class visualisation

def render_box(img, box, color=(200, 200, 200)):
    """
    Draw a rectangle outline whose line thickness scales with the image area.
    :param img: image to render into
    :param box: (x1, y1, x2, y2) - box coordinates
    :param color: (b, g, r) - box color
    :return: updated image
    """
    left, top, right, bottom = box
    pixel_count = img.shape[0] * img.shape[1]
    # Thickness grows with the pixel count but never drops below 1.
    line_width = max(
        1,
        int(round(pixel_count / (_LINE_THICKNESS_SCALING * _LINE_THICKNESS_SCALING))),
    )
    top_left = (int(left), int(top))
    bottom_right = (int(right), int(bottom))
    img = cv2.rectangle(img, top_left, bottom_right, color, thickness=line_width)
    return img

def render_filled_box(img, box, color=(200, 200, 200)):
    """
    Draw a solid (filled) rectangle.
    :param img: image to render into
    :param box: (x1, y1, x2, y2) - box coordinates
    :param color: (b, g, r) - fill color
    :return: updated image
    """
    left, top, right, bottom = box
    top_left = (int(left), int(top))
    bottom_right = (int(right), int(bottom))
    # cv2.FILLED as thickness fills the rectangle instead of outlining it.
    img = cv2.rectangle(img, top_left, bottom_right, color, thickness=cv2.FILLED)
    return img

# Divisor for automatic text stroke thickness in get_text_size()/render_text():
# thickness is derived from img_h * img_w / _TEXT_THICKNESS_SCALING**2.
_TEXT_THICKNESS_SCALING = 700.0
# Divisor for automatic font scale: scale = img_h / _TEXT_SCALING.
_TEXT_SCALING = 520.0


def get_text_size(img, text, normalised_scaling=1.0):
    """
    Measure the box a text would occupy when drawn with FONT_HERSHEY_SIMPLEX
    at the scale and thickness derived from the image size.
    :param img: image reference, used to determine appropriate text scaling
    :param text: text to display
    :param normalised_scaling: additional normalised scaling. Default 1.0.
    :return: (width, height) - width and height of text box
    """
    pixel_count = img.shape[0] * img.shape[1]
    # The area ratio is rounded first, then scaled and truncated to an int.
    base_thickness = round(
        pixel_count / (_TEXT_THICKNESS_SCALING * _TEXT_THICKNESS_SCALING)
    )
    stroke = max(1, int(base_thickness * normalised_scaling))
    font_scale = img.shape[0] / _TEXT_SCALING * normalised_scaling
    text_size, _baseline = cv2.getTextSize(
        text, cv2.FONT_HERSHEY_SIMPLEX, font_scale, stroke
    )
    return text_size


def render_text(img, text, pos, color=(200, 200, 200), normalised_scaling=1.0):
    """
    Draw text into the image with FONT_HERSHEY_SIMPLEX; font scale and stroke
    thickness are derived from the image size.
    :param img: image to render into
    :param text: text to display
    :param pos: (x, y) - upper left coordinates of render position
    :param color: (b, g, r) - text color
    :param normalised_scaling: additional normalised scaling. Default 1.0.
    :return: updated image
    """
    left, top = pos
    pixel_count = img.shape[0] * img.shape[1]
    base_thickness = round(
        pixel_count / (_TEXT_THICKNESS_SCALING * _TEXT_THICKNESS_SCALING)
    )
    # NOTE(review): the minimum stroke here is 2 while get_text_size() uses a
    # minimum of 1, so the measured size may be for a thinner stroke.
    stroke = max(2, int(base_thickness * normalised_scaling))
    font_scale = img.shape[0] / _TEXT_SCALING * normalised_scaling
    text_w_h = get_text_size(img, text, normalised_scaling)
    # putText anchors at the bottom-left of the text, so shift down by the
    # text height to make `pos` the upper-left corner.
    anchor = (int(left), int(top + text_w_h[1]))
    cv2.putText(
        img,
        text,
        anchor,
        cv2.FONT_HERSHEY_SIMPLEX,
        font_scale,
        color,
        thickness=stroke,
    )
    return img

def plot_one_box(x, img, color=None, label=None, line_thickness=None):
    """
    description: Plots one bounding box on image img,
                 this function comes from YoLov5 project.
    param: 
        x:      a box likes [x1,y1,x2,y2]
        img:    a opencv image object, drawn into in place
        color:  color to draw rectangle, such as (0,255,0); a random color
                is picked when None
        label:  str, optional caption drawn on a filled tag above the
                top-left corner of the box
        line_thickness: int, optional; derived from the image size when
                not given
    return:
        no return

    """
    tl = (
        line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1
    )  # line/font thickness
    # Identity test: `== None` would compare element-wise (and then fail in
    # the `if`) should a numpy array be passed as color.
    if color is None:
        color = [np.random.randint(0, 255) for _ in range(3)]
    c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3]))

    cv2.rectangle(img, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA)
    if label:
        tf = max(tl - 1, 1)  # font thickness
        t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
        # Filled tag sized to the text, sitting just above the box corner.
        c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3
        cv2.rectangle(img, c1, c2, color, -1, cv2.LINE_AA)  # filled
        cv2.putText(
            img,
            label,
            (c1[0], c1[1] - 2),
            0,
            tl / 3,
            [225, 255, 255],
            thickness=tf,
            lineType=cv2.LINE_AA,
        )

参考文献:大佬的github链接
火焰烟雾检测链接robmarkcole/fire-detection-from-images
spacewalk01/yolov5-fire-detection
zk2ly/Smoke_Fire_Detection
imsaksham-c/Fire-Smoke-Detection
DeepQuestAI/Fire-Smoke-Dataset
Tatowo/Smoke_Fire_Detection

评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值