mxnet复现SSD之预测结果

mxnet复现SSD系列文章目录

一、数据集的导入.
二、SSD模型架构.
三、训练脚本的实现.
四、损失、评价函数.
五、预测结果.



前言

本项目按照 Pascal VOC 格式读取数据集，数据集为 Kaggle 官网提供的口罩检测数据集（地址：Face Mask Detection），模型架构参考自 GluonCV 中 ssd_300_vgg16_atrous_voc 的源码。


一、读取单张图片进行预测

代码实现

import os
import argparse
import matplotlib.pyplot as plt
import mxnet as mx
from mxnet import image, nd
from tools.tools import try_gpu, import_module


# 读取单张测试图片
def single_image_data_loader(filename, test_image_size=300):
    """Build a reader that yields one preprocessed test image.

    Test images carry no ground-truth labels, so only the image tensor is
    produced.

    Args:
        filename: path of the image file to load.
        test_image_size: side length (pixels) the image is resized to.

    Returns:
        A zero-argument callable; calling it returns a generator that yields
        a single NDArray laid out as (1, C, test_image_size, test_image_size),
        normalized with per-channel mean/std (the usual ImageNet statistics).
    """

    def reader():
        side = test_image_size
        raw = image.imread(os.path.join(filename))
        resized = image.imresize(raw, side, side, 3).astype('float32')

        channel_mean = nd.array([0.485, 0.456, 0.406]).reshape((1, 1, -1))
        channel_std = nd.array([0.229, 0.224, 0.225]).reshape((1, 1, -1))
        normalized = (resized / 255.0 - channel_mean) / channel_std

        # HWC -> CHW, then prepend a batch axis.
        batch = normalized.transpose((2, 0, 1)).expand_dims(axis=0)
        yield batch

    return reader


# 预测目标
def predict(test_image, net, img, labels, threshold=0.3):
    """Run the SSD network on one image and plot the surviving detections.

    Args:
        test_image: preprocessed input batch (as yielded by the image reader).
        net: network whose call returns (anchors, bbox_preds, cls_preds).
        img: original, un-normalized image NDArray used for display.
        labels: class names indexed by predicted class id.
        threshold: minimum score for a detection to be drawn by display().
            NOTE(review): MultiBoxDetection below is called with
            threshold=0.5, which marks lower-scoring detections as class -1,
            so values below 0.5 here likely have no visible effect — confirm.

    Returns:
        True if at least one detection survived NMS (and was displayed),
        otherwise False.
    """
    anchors, bbox_preds, cls_preds = net(test_image)
    # Move the class scores onto axis 1 so mode='channel' normalizes over classes.
    cls_probs = nd.SoftmaxActivation(cls_preds.transpose((0, 2, 1)), mode='channel')
    output = nd.contrib.MultiBoxDetection(cls_probs, bbox_preds, anchors,
                                          force_suppress=True, clip=True,
                                          threshold=0.5, nms_threshold=.45)

    # Copy the class-id column to host once instead of calling asscalar()
    # (a blocking device sync) for every anchor row. -1 marks discarded rows.
    class_ids = output[0][:, 0].asnumpy()
    keep = [i for i, cid in enumerate(class_ids) if cid != -1]
    if not keep:
        return False
    display(img, labels, output[0, keep], threshold=threshold)
    return True


# 显示多个边界框
def show_bboxes(axes, bboxes, labels=None):
    """Draw each box in *bboxes* as a white, unfilled rectangle on *axes*.

    Args:
        axes: matplotlib Axes to draw on.
        bboxes: iterable of NDArrays whose first four entries are
            (x_min, y_min, x_max, y_max) in pixel coordinates.
        labels: optional text; when truthy, the SAME text is written at the
            top-left corner of every box.
    """
    for box in bboxes:
        coords = box.asnumpy()
        left, top, right, bottom = coords[0], coords[1], coords[2], coords[3]
        patch = plt.Rectangle(xy=(left, top), width=right - left, height=bottom - top,
                              fill=False, linewidth=2, color='w')
        axes.add_patch(patch)
        if not labels:
            continue
        anchor_x, anchor_y = patch.xy[0], patch.xy[1]
        axes.text(anchor_x, anchor_y, labels,
                  horizontalalignment='center', verticalalignment='center',
                  fontsize=8, color='k', bbox=dict(facecolor='w', alpha=1))


def display(img, labels, output, threshold):
    """Show *img* with every detection in *output* scoring at least *threshold*.

    Args:
        img: image NDArray in HWC layout (height = shape[0], width = shape[1]).
        labels: class names indexed by class id.
        output: detection rows [class_id, score, x1, y1, x2, y2]; the box
            coordinates are multiplied by (width, height, width, height)
            here, so they are presumably normalized to [0, 1].
        threshold: rows whose score is below this value are skipped.
    """
    canvas = plt.imshow(img.asnumpy())
    height, width = img.shape[0:2]
    for det in output:
        score = det[1].asscalar()
        if score < threshold:
            continue
        scale = nd.array((width, height, width, height), ctx=det.context)
        name = labels[int(det[0].asscalar())]
        show_bboxes(canvas.axes, [det[2:6] * scale], '%s-%.2f' % (name, score))
    plt.show()


def parse_args():
    """Parse command-line options for single-image prediction.

    Returns:
        argparse.Namespace with attributes img_path, model, model_params,
        class_names and image_shape.
    """
    parser = argparse.ArgumentParser(description='predict the single image')
    # There is no sensible default image: without one the script would only
    # fail later inside image.imread(None) with an obscure error.
    parser.add_argument('--image-path', dest='img_path', help='image path',
                        required=True, type=str)
    parser.add_argument('--model', dest='model', help='choice model to use',
                        default='resnet_ssd', type=str)
    parser.add_argument('--model-params', dest='model_params', help='choice model params to use',
                        default='mask_resnet18_SSD_model.params', type=str)
    parser.add_argument('--class-names', dest='class_names', help='choice class to use',
                        default='without_mask,with_mask,mask_weared_incorrect', type=str)
    parser.add_argument('--image-shape', dest='image_shape', help='image shape',
                        default=512, type=int)
    return parser.parse_args()


if __name__ == '__main__':
    args = parse_args()
    # Use try_gpu() instead to run on a GPU when one is available.
    ctx = mx.cpu()

    # Original-resolution image, used only for display.
    img = image.imread(args.img_path).as_in_context(ctx)
    reader = single_image_data_loader(args.img_path, args.image_shape)
    # Strip each name so "a, b" and "a,b" yield the same labels.
    labels = [name.strip() for name in args.class_names.strip().split(',')]
    class_nums = len(labels)

    model_path = os.path.join('model', args.model_params)
    net = import_module('model.' + args.model).get_model(
        class_nums, pretrained_model=model_path, pretrained=True, ctx=ctx)

    for x in reader():
        # The reader builds its tensor on the default context; move it to the
        # model's context so switching ctx to a GPU does not fail.
        found = predict(x.as_in_context(ctx), net, img, labels)
        if not found:
            print('not found!')

结果展示

在这里插入图片描述

二、实时检测

1.代码实现

与单张图片检测不同，实时检测需要从摄像头持续读取画面，因此创建两个线程：一个负责读取并显示图像，另一个负责模型推理。若只用一个线程，读取画面会被推理阻塞，画面非常卡顿，无法达到实时检测的目的。

import cv2
import time
import threading
from collections import deque
import mxnet as mx
from tools.tools import try_gpu, import_module
# Shared by the camera thread and the inference thread below; guards the
# frame deque and the one-slot output list that both threads access.
lock = threading.Lock()


def img_transform(img, img_size=500):
    """Resize a frame and normalize it for the detector.

    Args:
        img: image NDArray in HWC layout (the caller converts BGR to RGB
            before calling — assumed from the only visible call site).
        img_size: side length the frame is resized to.

    Returns:
        (batch, preview): ``batch`` is the normalized float32 tensor laid out
        as (1, C, img_size, img_size); ``preview`` is the resized frame as a
        uint8 numpy array (HWC).
    """
    resized = mx.image.imresize(img, img_size, img_size, 3).astype('float32')
    preview = resized.asnumpy().astype('uint8')

    channel_mean = mx.nd.array([0.485, 0.456, 0.406]).reshape((1, 1, -1))
    channel_std = mx.nd.array([0.229, 0.224, 0.225]).reshape((1, 1, -1))
    normalized = (resized / 255.0 - channel_mean) / channel_std

    # HWC -> CHW, then prepend a batch axis.
    chw = normalized.transpose((2, 0, 1))
    return chw.expand_dims(axis=0), preview


# 预测目标
def predict(test_image, net):
    """Run the detector on one preprocessed batch.

    Args:
        test_image: normalized input batch (see img_transform).
        net: network whose call returns (anchors, bbox_preds, cls_preds).

    Returns:
        NDArray of detection rows [class_id, score, x1, y1, x2, y2] for the
        first image, restricted to rows MultiBoxDetection kept
        (class_id != -1), or None when nothing was detected.
    """
    anchors, bbox_preds, cls_preds = net(test_image)
    # Move the class scores onto axis 1 so mode='channel' normalizes over classes.
    cls_probs = mx.nd.SoftmaxActivation(cls_preds.transpose((0, 2, 1)), mode='channel')
    output = mx.nd.contrib.MultiBoxDetection(cls_probs, bbox_preds, anchors,
                                             force_suppress=True, clip=True,
                                             threshold=0.5, nms_threshold=.45)

    # Copy the class-id column to host once; calling asscalar() per anchor
    # row is a blocking device sync each time and dominates per-frame cost.
    class_ids = output[0][:, 0].asnumpy()
    keep = [i for i, cid in enumerate(class_ids) if cid != -1]
    if not keep:
        return None
    return output[0, keep]


# 摄像头的显示
class WebcamThread(threading.Thread):
    """Camera thread: grab frames, queue them for inference, and display them.

    Every captured frame is pushed onto the right end of the shared deque
    that the inference thread consumes. The displayed copy is resized to
    (img_width, img_height) and overlaid with the most recent detections
    stored in ``output[0]``. Press ``q`` in the window to stop.

    Args:
        input: deque shared with the inference thread.
        output: one-element list whose item is the latest detections NDArray
            (rows [class_id, score, x1, y1, x2, y2]) or None.
        img_height: height of the displayed frame in pixels.
        img_width: width of the displayed frame in pixels.
        threshold: minimum score for a detection to be drawn.
        labels: optional class names indexed by class id, drawn at each box.
    """

    # Maximum number of frames kept in the shared queue; the oldest is dropped first.
    _MAX_PENDING = 10
    # Report the measured frame rate every this many frames.
    _FPS_WINDOW = 60

    def __init__(self, input, output, img_height, img_width, threshold=0.5, labels=None):
        # Initialise threading.Thread before setting our own attributes.
        super().__init__()
        self._jobq = input
        self._output = output
        self._num = 0
        self.cap = cv2.VideoCapture(0)
        self.img_height = img_height
        self.img_width = img_width
        self.labels = labels
        self.threshold = threshold

    def _push_frame(self, frame):
        """Append *frame* to the shared queue, evicting the oldest when full."""
        with lock:
            if len(self._jobq) >= self._MAX_PENDING:
                self._jobq.popleft()
            self._jobq.append(frame)

    def _draw_detections(self, frame, detections):
        """Draw boxes (and labels) scoring >= self.threshold onto *frame* in place."""
        font = cv2.FONT_HERSHEY_TRIPLEX
        color = (205, 0, 0)
        for det in detections.asnumpy():
            score = float(det[1])
            if score < self.threshold:
                continue
            # Box coordinates are scaled to the display size; OpenCV drawing
            # functions require integer pixel coordinates.
            x1 = int(det[2] * self.img_width)
            y1 = int(det[3] * self.img_height)
            x2 = int(det[4] * self.img_width)
            y2 = int(det[5] * self.img_height)
            cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
            if self.labels:
                label = self.labels[int(det[0])]
                cv2.putText(frame, label, (x1, y1), font, 0.8, color, 1, cv2.LINE_8)

    def run(self):
        start = time.time()
        cv2.namedWindow('camera', flags=cv2.WINDOW_NORMAL | cv2.WINDOW_FREERATIO)
        if not self.cap.isOpened():
            print('摄像头打开失败')
        try:
            while self.cap.isOpened():
                # Frame-rate report.
                if self._num < self._FPS_WINDOW:
                    self._num += 1
                else:
                    end = time.time()
                    fps = self._num / (end - start)

                    start = time.time()
                    self._num = 0
                    print('fps:', fps)

                ret, frame = self.cap.read()
                if not ret:
                    # Read failed (camera lost / stream ended): frame is None
                    # and cannot be resized or displayed.
                    break
                self._push_frame(frame)

                frame = cv2.resize(frame, (self.img_width, self.img_height))
                detections = self._output[0]
                if detections is not None:
                    self._draw_detections(frame, detections)
                cv2.imshow('camera', frame)
                # Mask to the low byte: some platforms set extra bits in waitKey's result.
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    # Quit.
                    break
        finally:
            print("实时读取线程退出!!!!")
            cv2.destroyWindow('camera')
            with lock:
                # Empty the queue when capture stops so the consumer can finish.
                self._jobq.clear()
            self.cap.release()


# 处理摄像头传来的数据
class ModelDealhread(threading.Thread):
    """Inference thread: take the newest queued frame, run the detector, publish results.

    Args:
        input: deque of BGR frames filled by the camera thread.
        output: one-element list; item 0 is replaced with each new prediction
            (NDArray of detections, or None when nothing was found).
        img_size: side length frames are resized to before inference.
        ctx: MXNet context the input batch is moved to.
        net: detector to run. When None (the default), the module-level
            ``net`` is used, as before.
        idle_timeout: seconds the queue may stay empty, after at least one
            frame has been processed, before the thread exits.
    """

    # Back-off while waiting for frames so the loop does not spin a CPU core.
    _POLL_INTERVAL = 0.005

    def __init__(self, input, output, img_size, ctx=mx.cpu(), net=None, idle_timeout=1.0):
        # Initialise threading.Thread before setting our own attributes.
        super().__init__()
        self._jobq = input
        self._output = output
        self.img_size = img_size
        self.ctx = ctx
        self.net = net
        self.idle_timeout = idle_timeout

    def _take_latest(self):
        """Pop the newest frame under the lock, or return None if the queue is empty."""
        with lock:
            return self._jobq.pop() if self._jobq else None

    def run(self):
        model = self.net if self.net is not None else net
        last_done = None  # when the last inference finished; None until the first frame
        while True:
            frame = self._take_latest()
            if frame is None:
                # An empty queue can be transient (inference outpacing the
                # camera), so only stop once it has stayed empty for a while.
                if last_done is not None and time.time() - last_done >= self.idle_timeout:
                    break
                time.sleep(self._POLL_INTERVAL)
                continue

            rgb = mx.nd.array(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)).astype('uint8')
            batch, _ = img_transform(rgb, img_size=self.img_size)
            result = predict(batch.as_in_context(self.ctx), model)

            with lock:
                self._output[0] = result
            last_done = time.time()

        print("间隔1s获取图像线程退出!!!!")


if __name__ == "__main__":
    ctx = try_gpu()
    # Alternative backbone:
    #   vgg_ssd.get_model(3, pretrained_model='model/mask_SSD_model.params', pretrained=True, ctx=ctx)
    # `net` must stay a module-level name: ModelDealhread.run looks it up as a global.
    net = import_module('model.resnet_ssd').get_model(
        3, pretrained_model='model/mask_resnet18_SSD_model.params', pretrained=True, ctx=ctx)
    net.hybridize()

    labels = ['without_mask', 'with_mask', 'mask_weared_incorrect']
    frames = deque([], 10)   # most recent camera frames (at most 10)
    latest = [None]          # single slot holding the newest model output

    camera = WebcamThread(frames, latest, 500, 500, labels=labels)
    worker = ModelDealhread(frames, latest, 500, ctx=ctx)

    # Start both threads, then wait for both to finish.
    for thread in (camera, worker):
        thread.start()
    for thread in (camera, worker):
        thread.join()

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值