mmdetection中进行测试
使用自己的数据进行测试,保存到本地并且能够按照置信度进行区分。
1、mmcv中带有show_result_pyplot的api但是仅仅能够显示图片,不能保存到本地,可以调用他的包进行修改后保存到本地。
2、他的api是将多所有的图片进行保存,在图片中仅显示检出大于设定置信度的框,低于的框不进行显示。
3、在mmdetection可以调用api中的inference_detector信息返回results,result是nmi的数组,其中n是你数据中class类别的数量,m是每个的类比对应的检测框的数量,i代表每个检测框的数组,i是长度为5的数组,i[0-3]是检测框的坐标,i[4]是检测框的置信度,而每个框是按照置信度从高到低排列的。
4、基于以上的分析就可以将每个类按照置信度进行分为检出图和没有检出的图。
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from argparse import ArgumentParser
import os
from mmdet.apis import inference_detector, init_detector #, show_result_pyplot
import cv2
import time
import numpy as np
def show_result_pyplot(model, img, result, score_thr=0.01, fig_size=(15, 10)):
if hasattr(model, 'module'):
model = model.module
img = model.show_result(img, result, score_thr=score_thr, show=False)
return img
def main():
# config文件
config_file = '/home/cv/mmdetection/work_dirs/lwx/7-12-cascade_rcnn_x101_32x4d_fpn_20e_coco/cascade_rcnn_x101_32x4d_fpn_20e_coco.py'
# 训练好的模型
checkpoint_file = '/home/cv/mmdetection/work_dirs/lwx/7-12-cascade_rcnn_x101_32x4d_fpn_20e_coco/latest.pth'
# model = init_detector(config_file, checkpoint_file)
model = init_detector(config_file, checkpoint_file, device='cuda:0')
# 测试图片路径
img_dir = '/home/cv/test_data/test_11/ps/ps_sp/'
#img_dir = '/home/cv/test_data/test_11/test/'
# 测试的分类
test_class_id = 'ps_sp'
# 测试置信度
test_thr = 0.6
# 高于置信度的图片
bad_dir = '/home/cv/mmdetection/outputs/test/ps_sp/bad'
# 低于置信度的图片
good_dir = '/home/cv/mmdetection/outputs/test/ps_sp/good'
# class_name
class_name = ['yw','jb','ecl','ps_sp','ps_bx','ps_pd','ps_tq','pu']
if not os.path.exists(bad_dir):
os.makedirs(bad_dir)
if not os.path.exists(good_dir):
os.makedirs(good_dir)
name_list = []
write_list = []
for test in os.listdir(img_dir):
start = time.time()
#test = test.replace('\n', '')
name = img_dir + test
name_list.append(test)
# print('model is processing the {}/{} images.'.format(count, len(test_list)))
# result = inference_detector(model, name)
# model = init_detector(config_file, checkpoint_file, device='cuda:0')
result = inference_detector(model, name)
#print(name)
img = show_result_pyplot(model, name, result, score_thr=0.1)
thr_list = []
i = 0
for class_result in result:
#print('result:',class_result)
if test_class_id == class_name[i]:
if len(class_result) != 0:
#print(class_result," ",len(class_result))
for class_list in class_result:
#print('class_re',class_list)
if class_list[4] >= test_thr:
print('{} {} is {}'.format(test,class_name[i], class_list[4]))
cv2.imwrite("{}/{}".format(bad_dir,test), img)
break
else:
cv2.imwrite("{}/{}".format(good_dir,test), img)
else:
cv2.imwrite("{}/{}".format(good_dir,test), img)
i += 1
if __name__ == '__main__':
main()