mmdetection中使用自己的数据测试

mmdetection中进行测试

使用自己的数据进行测试,保存到本地并且能够按照置信度进行区分。
1、mmcv中带有show_result_pyplot的api但是仅仅能够显示图片,不能保存到本地,可以调用他的包进行修改后保存到本地。
2、他的api是将多所有的图片进行保存,在图片中仅显示检出大于设定置信度的框,低于的框不进行显示。
3、在mmdetection可以调用api中的inference_detector信息返回results,result是nmi的数组,其中n是你数据中class类别的数量,m是每个的类比对应的检测框的数量,i代表每个检测框的数组,i是长度为5的数组,i[0-3]是检测框的坐标,i[4]是检测框的置信度,而每个框是按照置信度从高到低排列的。
4、基于以上的分析就可以将每个类按照置信度进行分为检出图和没有检出的图。

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from argparse import ArgumentParser
import os
from mmdet.apis import inference_detector, init_detector  #, show_result_pyplot
import cv2
import time
import numpy as np
 
def show_result_pyplot(model, img, result, score_thr=0.01, fig_size=(15, 10)):
    if hasattr(model, 'module'):
        model = model.module
    img = model.show_result(img, result, score_thr=score_thr, show=False)
    return img
def main():
    # config文件
    config_file = '/home/cv/mmdetection/work_dirs/lwx/7-12-cascade_rcnn_x101_32x4d_fpn_20e_coco/cascade_rcnn_x101_32x4d_fpn_20e_coco.py'
    # 训练好的模型
    checkpoint_file = '/home/cv/mmdetection/work_dirs/lwx/7-12-cascade_rcnn_x101_32x4d_fpn_20e_coco/latest.pth'

    # model = init_detector(config_file, checkpoint_file)
    model = init_detector(config_file, checkpoint_file, device='cuda:0')
 
    # 测试图片路径
    img_dir = '/home/cv/test_data/test_11/ps/ps_sp/'
    #img_dir = '/home/cv/test_data/test_11/test/'
     # 测试的分类
    test_class_id = 'ps_sp'    
    # 测试置信度
    test_thr = 0.6
    # 高于置信度的图片
    bad_dir = '/home/cv/mmdetection/outputs/test/ps_sp/bad'
    # 低于置信度的图片
    good_dir = '/home/cv/mmdetection/outputs/test/ps_sp/good'
    # class_name   
    class_name = ['yw','jb','ecl','ps_sp','ps_bx','ps_pd','ps_tq','pu']

    if not os.path.exists(bad_dir):
        os.makedirs(bad_dir)
    if not os.path.exists(good_dir):
        os.makedirs(good_dir)
    
    name_list = []  
    write_list = []
    for test in os.listdir(img_dir):
        
        start = time.time()
        #test = test.replace('\n', '')
        name = img_dir + test

        name_list.append(test)
        # print('model is processing the {}/{} images.'.format(count, len(test_list)))
        # result = inference_detector(model, name)
        # model = init_detector(config_file, checkpoint_file, device='cuda:0')
        result = inference_detector(model, name)
        #print(name)

        img = show_result_pyplot(model, name, result, score_thr=0.1)
        thr_list = []
        i = 0
        for class_result in result:
            #print('result:',class_result)
            if test_class_id == class_name[i]:
                if len(class_result) != 0:
                    #print(class_result,"   ",len(class_result))
                    for class_list in class_result:
                        #print('class_re',class_list)    
                        if class_list[4] >= test_thr:
                            print('{} {} is {}'.format(test,class_name[i], class_list[4]))
                            cv2.imwrite("{}/{}".format(bad_dir,test), img)
                            break
                        else:
                            cv2.imwrite("{}/{}".format(good_dir,test), img) 
                else:
                    cv2.imwrite("{}/{}".format(good_dir,test), img) 
            i += 1

if __name__ == '__main__':
    main()
  • 1
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值