simpledet-tridentnet: drawing detection boxes on a test dataset with class labels and confidence scores
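
The full script is below. A typical invocation, assuming it is saved as detect_draw.py (the filename is arbitrary, an assumption for illustration) in the SimpleDet repository root:

python detect_draw.py --config config/cascade_r101v1_fpn_1x.py --img_input ./data/coco/images/test/ --output ./data/coco/images/draw_result/ --jsonPath ./data/coco/annotations/instances_test.json --threshold 0.5 --epoch 5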

import os, argparse
import importlib
import json
import time
import cv2
import numpy as np
import mxnet as mx
from core.detection_module import DetModule
from utils.load_model import load_checkpoint
from utils.patch_config import patch_config_as_nothrow
from datetime import datetime

# Replace these with the classes in your own dataset
coco = (
    "queya",
    "loujiao",
    "broken",
    "xiaomi",
    "queya2",
    "banbianya"
)

# Replace with your own classes as well; each class gets its own box color.
# Note that OpenCV expects colors in BGR order.
colors = {"queya": (0, 255, 255),       # yellow
          "loujiao": (0, 255, 0),       # bright green
          "broken": (255, 255, 0),      # cyan
          "xiaomi": (255, 144, 30),     # blue
          "queya2": (0, 97, 255),       # orange
          "banbianya": (203, 192, 255), # pink
          }

class Timer(object):
    """A simple timer."""
    def __init__(self):
        self.total_time = 0.
        self.calls = 0
        self.start_time = 0.
        self.diff = 0.
        self.average_time = 0.

    def tic(self):
        # use time.time instead of time.clock because time.clock
        # does not normalize for multithreading
        self.start_time = time.time()

    def toc(self, average=True):
        self.diff = time.time() - self.start_time
        self.total_time += self.diff
        self.calls += 1
        self.average_time = self.total_time / self.calls
        if average:
            return self.average_time
        else:
            return self.diff

class TDNDetector:
    def __init__(self, configFn, ctx, outFolder, threshold, epoch):
        os.environ["MXNET_CUDNN_AUTOTUNE_DEFAULT"] = "0"
        config = importlib.import_module(configFn.replace('.py', '').replace('/', '.'))
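        # get_config returns a 13-tuple; only the model config, the test config and the data transform pipeline are needed here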
        _, _, _, _, _, _, self.__pModel, _, self.__pTest, self.transform, _, _, _ = config.get_config(is_train=False)
        self.__pModel = patch_config_as_nothrow(self.__pModel)
        self.__pTest = patch_config_as_nothrow(self.__pTest)
        self.resizeParam = (800, 1200)
        if callable(self.__pTest.nms.type):
            self.__nms = self.__pTest.nms.type(self.__pTest.nms.thr)
        else:
            from operator_py.nms import py_nms_wrapper
            self.__nms = py_nms_wrapper(self.__pTest.nms.thr)
        arg_params, aux_params = load_checkpoint(self.__pTest.model.prefix, epoch)  # epoch overrides self.__pTest.model.epoch
        sym = self.__pModel.test_symbol
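        # merge_bn folds BatchNorm parameters into the adjacent convolution weights, speeding up inference without changing the outputs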
        from utils.graph_optimize import merge_bn
        sym, arg_params, aux_params = merge_bn(sym, arg_params, aux_params)
        self.__mod = DetModule(sym, data_names=['data', 'im_info', 'im_id', 'rec_id'], context=ctx)
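        # bind with a fixed batch of one image at the 800x1200 test resolution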
        self.__mod.bind(data_shapes=[('data', (1, 3, self.resizeParam[0], self.resizeParam[1])),
                                     ('im_info', (1, 3)),
                                     ('im_id', (1,)),
                                     ('rec_id', (1,))], for_training=False)
        self.__mod.set_params(arg_params, aux_params, allow_extra=False)
        self.__saveSymbol(sym, outFolder, self.__pTest.model.prefix.split('/')[-1])
        self.__threshold = threshold

    def __call__(self, imgFilename):  # run detection on a single image
        roi_record, scale = self.__readImg(imgFilename)
        h, w = roi_record['data'][0].shape

        im_c1 = roi_record['data'][0].reshape(1, 1, h, w)
        im_c2 = roi_record['data'][1].reshape(1, 1, h, w)
        im_c3 = roi_record['data'][2].reshape(1, 1, h, w)
        im_data = np.concatenate((im_c1, im_c2, im_c3), axis=1)

        im_info, im_id, rec_id = [(h, w, scale)], [1], [1]
        data = mx.io.DataBatch(data=[mx.nd.array(im_data),
                                     mx.nd.array(im_info),
                                     mx.nd.array(im_id),
                                     mx.nd.array(rec_id)])
        self.__mod.forward(data, is_train=False)
        # extract results
        outputs = self.__mod.get_outputs(merge_multi_context=False)
        rid, id, info, cls, box = [x[0].asnumpy() for x in outputs]
        rid, id, info, cls, box = rid.squeeze(), id.squeeze(), info.squeeze(), cls.squeeze(), box.squeeze()
        cls = cls[:, 1:]  # remove background
        box = box / scale
        output_record = dict(rec_id=rid, im_id=id, im_info=info, bbox_xyxy=box, cls_score=cls)
        output_record = self.__pTest.process_output([output_record], None)[0]
        final_result = self.__do_nms(output_record)
        # convert to a simple list of {'cls', 'box', 'score'} dicts, keeping only boxes above the threshold
        detections = []
        for cid, bbox in final_result.items():
            idx = np.where(bbox[:, -1] > self.__threshold)[0]
            for i in idx:
                final_box = bbox[i][:4]
                score = bbox[i][-1]
                detections.append({'cls': cid, 'box': final_box, 'score': score})
        return detections, None

    def __do_nms(self, all_output):
        box = all_output['bbox_xyxy']
        score = all_output['cls_score']
        final_dets = {}  # maps class name -> Nx5 array of [x1, y1, x2, y2, score]
        for cid in range(score.shape[1]):
            score_cls = score[:, cid]
            valid_inds = np.where(score_cls > self.__threshold)[0]
            box_cls = box[valid_inds]
            score_cls = score_cls[valid_inds]
            if valid_inds.shape[0] == 0:
                continue
            det = np.concatenate((box_cls, score_cls.reshape(-1, 1)), axis=1).astype(np.float32)
            det = self.__nms(det)
            cls = coco[cid]
            final_dets[cls] = det
        return final_dets

    def __readImg(self, imgFilename):
        img = cv2.imread(imgFilename, cv2.IMREAD_COLOR)
        height, width, channels = img.shape
        roi_record = {'gt_bbox': np.array([[0., 0., 0., 0.]]), 'gt_class': np.array([0])}  # dummy gt fields required by the transform pipeline
        roi_record['image_url'] = imgFilename
        roi_record['resize_long'] = width
        roi_record['resize_short'] = height

        for trans in self.transform:
            trans.apply(roi_record)
        img_shape = [roi_record['resize_long'], roi_record['resize_short']]
        shorts, longs = min(img_shape), max(img_shape)
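        # recover the scale applied by the resize transform: short side towards 800, long side capped at 1200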
        scale = min(self.resizeParam[0] / shorts, self.resizeParam[1] / longs)

        return roi_record, scale

    def __saveSymbol(self, sym, outFolder, fnPrefix):
        if not os.path.exists(outFolder):
            os.makedirs(outFolder)
        resFilename = os.path.join(outFolder, fnPrefix + "_symbol_test.json")
        sym.save(resFilename)


def parse_args():
    parser = argparse.ArgumentParser(description='Test Detection')
    #parser.add_argument('--config', type=str, default='config/tridentnet_r101v2c4_c5_1x.py', help='config file path')
    parser.add_argument('--config', type=str, default='config/cascade_r101v1_fpn_1x.py', help='config file path')
    parser.add_argument('--ctx', type=int, default=0, help='GPU index. Set negative value to use CPU')
    # folder containing the images to run detection on
    parser.add_argument('--img_input', help='the image path', type=str, default='./data/coco/images/test/')
    # folder where the annotated result images are written
    parser.add_argument('--output', type=str, default='./data/coco/images/draw_result/', help='Where to store results')
    # path to the test-set json; its ground-truth boxes are drawn onto the test images
    parser.add_argument('--jsonPath', type=str, default='./data/coco/annotations/instances_test.json', help='instances_test.json path')
    parser.add_argument('--threshold', type=float, default=0.5, help='Detector threshold')
    # which saved epoch's checkpoint to load
    parser.add_argument('--epoch', help='override test epoch specified by config', type=int, default=5)
    return parser.parse_args()

def draw(img, dets, gt):
    # build the output file path
    outPath = args.output + img.split("/")[-1]
    # replace with your own classes; the keys must match the category ids in your dataset
    label_dict = {1: "queya", 2: "loujiao", 3: "broken", 4: "xiaomi", 5: "queya2", 6: "banbianya"}

    # the image id is recovered from the file name and used to look up its ground truth in the json
    img_id = int(img.split("/")[-1][:-4])
    img = cv2.imread(img)
    # draw the ground-truth boxes from the test-set json for easy visual comparison;
    # comment this loop out if you do not want them drawn
    for i in gt:
        if i["image_id"] == img_id:
            gt_box = i["bbox"]  # COCO format: [x, y, w, h]
            gt_label = i["category_id"]
            x1 = (int(gt_box[0]), int(gt_box[1]))
            x2 = (int(gt_box[0] + gt_box[2]), int(gt_box[1] + gt_box[3]))
            cv2.rectangle(img, x1, x2, (127, 255, 0), thickness=1, lineType=cv2.LINE_AA)  # ground-truth box
            cv2.putText(img, label_dict[gt_label], (x2[0], x2[1] + 2), 0, 2 / 3, [127, 255, 0], thickness=2, lineType=cv2.LINE_AA)

    # draw the detection boxes with class label and confidence score
    for i in range(len(dets)):
        bbox = dets[i]['box']
        label = '%s %.2f' % (dets[i]["cls"], dets[i]["score"])
        tl = 2  # line thickness; round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1 also works
        color = colors[dets[i]["cls"]]
        c1, c2 = (int(bbox[0]), int(bbox[1])), (int(bbox[2]), int(bbox[3]))
        cv2.rectangle(img, c1, c2, color, thickness=tl)  # detection box
        tf = max(tl - 1, 1)  # font thickness
        t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]  # measure the label text
        c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3  # note: c2 is reused here for the label background
        cv2.rectangle(img, c1, c2, color, -1)  # -1 fills the label background
        cv2.putText(img, label, (c1[0], c1[1] - 2), 0, tl / 3, [0, 0, 0], thickness=tf, lineType=cv2.LINE_AA)  # label text

    cv2.imwrite(outPath, img)  # save the annotated image


if __name__ == "__main__":
    print("1:",datetime.fromtimestamp(time.time()))
    args = parse_args()
    ctx = mx.gpu(args.ctx) if args.ctx >= 0 else mx.cpu()
    imgFilenames = os.listdir(args.img_input)
    imgFilePaths = [args.img_input + i for i in imgFilenames]
    print("2:",datetime.fromtimestamp(time.time()))
    detector = TDNDetector(args.config, ctx, args.output, args.threshold, args.epoch)
    # the test json only needs to be read once
    with open(args.jsonPath, "r") as load_f:
        load_dict = json.load(load_f)
    gt_list = load_dict["annotations"]
    print("3:", datetime.fromtimestamp(time.time()))
    _t = {'im_detect': Timer(), 'misc': Timer()}
    total_dectime = 0
    for i, imgFilePath in enumerate(imgFilePaths):
        _t['im_detect'].tic()
        dets, _ = detector(imgFilePath)
        draw(imgFilePath, dets, gt_list)
        detect_time = _t['im_detect'].toc(average=False)
        print('im_detect: {:d}/{:d} {:.3f}s'.format(i + 1, len(imgFilePaths), detect_time))
        total_dectime += detect_time
    print("测试结束!!!")
    print("total_dectime = ", total_dectime)

 
