import os, argparse
import importlib
import json
import time
import cv2
import numpy as np
import mxnet as mx
from core.detection_module import DetModule
from utils.load_model import load_checkpoint
from utils.patch_config import patch_config_as_nothrow
import time
from datetime import datetime
# Replace with the class names of your own dataset
coco = (
    "queya",
    "loujiao",
    "broken",
    "xiaomi",
    "queya2",
    "banbianya"
)
# Replace with your dataset's classes: one BGR colour per class for the drawn boxes
colors = {"queya":(0,255,255),        # yellow
          "loujiao":(0,255,0),        # bright green
          "broken": (255, 255, 0),    # cyan
          "xiaomi": (255, 144, 30),   # blue
          "queya2": (0, 97, 255),     # orange
          "banbianya": (203, 192, 255), # pink
          }
class Timer(object):
    """Accumulating stopwatch: pair each tic() with a toc().

    toc() returns either the running average over all intervals (default)
    or the duration of the most recent tic/toc interval.
    """

    def __init__(self):
        self.total_time = 0.
        self.calls = 0
        self.start_time = 0.
        self.diff = 0.
        self.average_time = 0.

    def tic(self):
        """Start (or restart) the stopwatch."""
        # time.time() instead of time.clock(): time.clock does not
        # normalize for multithreading.
        self.start_time = time.time()

    def toc(self, average=True):
        """Stop timing and return the average interval, or the last one."""
        now = time.time()
        self.diff = now - self.start_time
        self.total_time += self.diff
        self.calls += 1
        self.average_time = self.total_time / self.calls
        return self.average_time if average else self.diff
class TDNDetector:
    """Single-image detector wrapping a SimpleDet/TDN MXNet model.

    Loads the model described by a python config file, merges batch-norm into
    the symbol, binds the module for a single 800x1200 input, and exposes
    ``__call__(imgFilename) -> (detections, None)`` where each detection is a
    dict ``{'cls': class_name, 'box': xyxy ndarray, 'score': float}``.
    """

    def __init__(self, configFn, ctx, outFolder, threshold, epoch=None):
        """Build and bind the detection module.

        Args:
            configFn: path of the python config file (e.g. 'config/xxx.py').
            ctx: mx.gpu(i) or mx.cpu() context to run inference on.
            outFolder: folder where the test symbol json is dumped.
            threshold: score threshold applied both before and after NMS.
            epoch: checkpoint epoch to load. Defaults to the global CLI
                ``args.epoch`` for backward compatibility with the original
                script, which read the global directly.
        """
        os.environ["MXNET_CUDNN_AUTOTUNE_DEFAULT"] = "0"
        config = importlib.import_module(configFn.replace('.py', '').replace('/', '.'))
        _, _, _, _, _, _, self.__pModel, _, self.__pTest, self.transform, _, _, _ = config.get_config(is_train=False)
        self.__pModel = patch_config_as_nothrow(self.__pModel)
        self.__pTest = patch_config_as_nothrow(self.__pTest)
        # (short, long) side the network input is bound to
        self.resizeParam = (800, 1200)
        # nms.type may be a factory (callable) or a plain config value
        if callable(self.__pTest.nms.type):
            self.__nms = self.__pTest.nms.type(self.__pTest.nms.thr)
        else:
            from operator_py.nms import py_nms_wrapper
            self.__nms = py_nms_wrapper(self.__pTest.nms.thr)
        if epoch is None:
            # NOTE(review): the original code read the global CLI `args`
            # directly; keep that as the fallback so existing callers
            # behave identically. Pass `epoch` explicitly to avoid the
            # global dependency.
            epoch = args.epoch
        arg_params, aux_params = load_checkpoint(self.__pTest.model.prefix, epoch)
        sym = self.__pModel.test_symbol
        from utils.graph_optimize import merge_bn
        sym, arg_params, aux_params = merge_bn(sym, arg_params, aux_params)
        self.__mod = DetModule(sym, data_names=['data', 'im_info', 'im_id', 'rec_id'], context=ctx)
        self.__mod.bind(data_shapes=[('data', (1, 3, self.resizeParam[0], self.resizeParam[1])),
                                     ('im_info', (1, 3)),
                                     ('im_id', (1,)),
                                     ('rec_id', (1,))], for_training=False)
        self.__mod.set_params(arg_params, aux_params, allow_extra=False)
        self.__saveSymbol(sym, outFolder, self.__pTest.model.prefix.split('/')[-1])
        self.__threshold = threshold

    def __call__(self, imgFilename):
        """Detect objects in one image file.

        Returns:
            (detections, None): detections is a list of dicts with keys
            'cls' (class name), 'box' (xyxy in original image coordinates)
            and 'score'.
        """
        roi_record, scale = self.__readImg(imgFilename)
        h, w = roi_record['data'][0].shape
        # rebuild a (1, 3, h, w) batch from the three per-channel planes
        im_c1 = roi_record['data'][0].reshape(1, 1, h, w)
        im_c2 = roi_record['data'][1].reshape(1, 1, h, w)
        im_c3 = roi_record['data'][2].reshape(1, 1, h, w)
        im_data = np.concatenate((im_c1, im_c2, im_c3), axis=1)
        im_info, im_id, rec_id = [(h, w, scale)], [1], [1]
        data = mx.io.DataBatch(data=[mx.nd.array(im_data),
                                     mx.nd.array(im_info),
                                     mx.nd.array(im_id),
                                     mx.nd.array(rec_id)])
        self.__mod.forward(data, is_train=False)
        # extract results
        outputs = self.__mod.get_outputs(merge_multi_context=False)
        rid, iid, info, cls, box = [x[0].asnumpy() for x in outputs]
        rid, iid, info, cls, box = rid.squeeze(), iid.squeeze(), info.squeeze(), cls.squeeze(), box.squeeze()
        cls = cls[:, 1:]   # drop the background column
        box = box / scale  # map boxes back to original image coordinates
        output_record = dict(rec_id=rid, im_id=iid, im_info=info, bbox_xyxy=box, cls_score=cls)
        output_record = self.__pTest.process_output([output_record], None)[0]
        final_result = self.__do_nms(output_record)
        # convert {class: (N, 5) array} into a flat, representable list
        detections = []
        for cid, bbox in final_result.items():
            idx = np.where(bbox[:, -1] > self.__threshold)[0]
            for i in idx:
                final_box = bbox[i][:4]
                score = bbox[i][-1]
                detections.append({'cls': cid, 'box': final_box, 'score': score})
        return detections, None

    def __do_nms(self, all_output):
        """Per-class score thresholding + NMS.

        Returns:
            dict mapping class name -> (N, 5) float32 array [x1,y1,x2,y2,score].
        """
        box = all_output['bbox_xyxy']
        score = all_output['cls_score']
        final_dets = {}
        for cid in range(score.shape[1]):
            score_cls = score[:, cid]
            valid_inds = np.where(score_cls > self.__threshold)[0]
            # skip classes with no surviving boxes before any indexing work
            if valid_inds.shape[0] == 0:
                continue
            box_cls = box[valid_inds]
            score_cls = score_cls[valid_inds]
            det = np.concatenate((box_cls, score_cls.reshape(-1, 1)), axis=1).astype(np.float32)
            det = self.__nms(det)
            final_dets[coco[cid]] = det
        return final_dets

    def __readImg(self, imgFilename):
        """Load an image and run the config's test-time transforms on it.

        Returns:
            (roi_record, scale): roi_record['data'] holds the transformed
            channel planes; scale maps network coords back to the original.

        Raises:
            IOError: if the file cannot be read as an image.
        """
        img = cv2.imread(imgFilename, cv2.IMREAD_COLOR)
        if img is None:
            # fail fast with a clear message instead of the AttributeError
            # that img.shape would raise below
            raise IOError("cannot read image: %s" % imgFilename)
        height, width, channels = img.shape
        roi_record = {'gt_bbox': np.array([[0., 0., 0., 0.]]), 'gt_class': np.array([0])}
        roi_record['image_url'] = imgFilename
        roi_record['resize_long'] = width
        roi_record['resize_short'] = height
        for trans in self.transform:
            trans.apply(roi_record)
        img_shape = [roi_record['resize_long'], roi_record['resize_short']]
        shorts, longs = min(img_shape), max(img_shape)
        scale = min(self.resizeParam[0] / shorts, self.resizeParam[1] / longs)
        return roi_record, scale

    def __saveSymbol(self, sym, outFolder, fnPrefix):
        """Dump the (bn-merged) test symbol as json into outFolder."""
        if not os.path.exists(outFolder):
            os.makedirs(outFolder)
        resFilename = os.path.join(outFolder, fnPrefix + "_symbol_test.json")
        sym.save(resFilename)
def parse_args(argv=None):
    """Parse command-line options for the detection test script.

    Args:
        argv: optional list of argument strings; when None (the default)
            argparse falls back to sys.argv[1:], preserving the original
            behaviour.

    Returns:
        argparse.Namespace with config/ctx/img_input/output/jsonPath/
        threshold/epoch attributes.
    """
    parser = argparse.ArgumentParser(description='Test Detection')
    # alternative: 'config/tridentnet_r101v2c4_c5_1x.py'
    parser.add_argument('--config', type=str, default='config/cascade_r101v1_fpn_1x.py', help='config file path')
    parser.add_argument('--ctx', type=int, default=0, help='GPU index. Set negative value to use CPU')
    # folder containing the images to test
    parser.add_argument('--img_input', help='the image path', type=str, default='./data/coco/images/test/')
    # folder where the drawn results are stored
    parser.add_argument('--output', type=str, default='./data/coco/images/draw_result/', help='Where to store results')
    # path of the test-set json, used to draw ground truth onto the images
    parser.add_argument('--jsonPath', type=str, default='./data/coco/annotations/instances_test.json', help='instances_test.json path')
    parser.add_argument('--threshold', type=float, default=0.5, help='Detector threshold')
    # which saved epoch's checkpoint to test
    parser.add_argument('--epoch', help='override test epoch specified by config', type=int, default=5)
    return parser.parse_args(argv)
def draw(img, dets, gt, outFolder=None):
    """Draw ground-truth and detection boxes onto an image and save it.

    Args:
        img: path of the source image; the file name without its 4-char
            extension must be the integer image id used in the annotations.
        dets: list of {'cls', 'box', 'score'} dicts from TDNDetector.
        gt: list of COCO annotation dicts (the "annotations" list of the
            test-set json).
        outFolder: folder the result is written to; defaults to the global
            CLI ``args.output`` for backward compatibility with the
            original script, which read the global directly.
    """
    if outFolder is None:
        outFolder = args.output
    out_path = outFolder + img.split("/")[-1]
    # dataset category-id -> name; must match your dataset's category ids
    label_dict = {1: "queya", 2: "loujiao", 3: "broken", 4: "xiaomi", 5: "queya2", 6: "banbianya"}
    img_id = int(img.split("/")[-1][:-4])
    canvas = cv2.imread(img)
    if canvas is None:
        raise IOError("cannot read image: %s" % img)
    # Draw ground-truth boxes (green) read from the test-set json so the
    # detection quality can be judged visually; comment out if unwanted.
    for ann in gt:
        if ann["image_id"] != img_id:
            continue
        gt_box = ann["bbox"]
        gt_label = ann["category_id"]
        p1 = (int(gt_box[0]), int(gt_box[1]))
        p2 = (int(gt_box[0]) + int(gt_box[2]), int(int(gt_box[1]) + gt_box[3]))
        cv2.rectangle(canvas, p1, p2, (127, 255, 0), thickness=1, lineType=cv2.LINE_AA)
        cv2.putText(canvas, label_dict[gt_label], (p2[0], p2[1] + 2), 0, 2 / 3, [127, 255, 0],
                    thickness=2, lineType=cv2.LINE_AA)
    # draw detection boxes with per-class colours plus a "<cls> <score>" tag
    for det in dets:
        bbox = det['box']
        label = '%s %.2f' % (det["cls"], det["score"])
        tl = 2  # line thickness
        color = colors[det["cls"]]
        c1, c2 = (int(bbox[0]), int(bbox[1])), (int(bbox[2]), int(bbox[3]))
        cv2.rectangle(canvas, c1, c2, color, thickness=tl)
        tf = max(tl - 1, 1)  # font thickness
        t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
        # background rectangle sized to the text, filled (-1) with the class colour
        c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3
        cv2.rectangle(canvas, c1, c2, color, -1)
        cv2.putText(canvas, label, (c1[0], c1[1] - 2), 0, tl / 3, [0, 0, 0],
                    thickness=tf, lineType=cv2.LINE_AA)
    cv2.imwrite(out_path, canvas)
if __name__ == "__main__":
    print("1:", datetime.fromtimestamp(time.time()))
    args = parse_args()
    # BUG FIX: the CPU fallback used to call args.cpu(), which does not exist
    # on an argparse Namespace and would raise AttributeError; use mx.cpu().
    ctx = mx.gpu(args.ctx) if args.ctx >= 0 else mx.cpu()
    imgFilenames = os.listdir(args.img_input)
    imgFilePaths = [args.img_input + i for i in imgFilenames]
    print("2:", datetime.fromtimestamp(time.time()))
    detector = TDNDetector(args.config, ctx, args.output, args.threshold)
    # the test-set json only needs to be read once
    with open(args.jsonPath, "r") as load_f:
        load_dict = json.load(load_f)
    gt_list = load_dict["annotations"]
    print("3:", datetime.fromtimestamp(time.time()))
    _t = {'im_detect': Timer(), 'misc': Timer()}
    total_dectime = 0
    for i, imgFilePath in enumerate(imgFilePaths):
        _t['im_detect'].tic()
        dets, _ = detector(imgFilePath)
        draw(imgFilePath, dets, gt_list)
        detect_time = _t['im_detect'].toc(average=False)
        print('im_detect: {:d}/{:d} {:.3f}s'.format(i + 1, len(imgFilePaths), detect_time))
        total_dectime += detect_time
    print("测试结束!!!")
    print("total_dectime = ", total_dectime)