MMYOLO、MMROTATE图片批量检测脚本

对官方给的img_demo.py做了一点改动,可以输入图片文件夹,将结果输出到output文件夹中:

# Copyright (c) OpenMMLab. All rights reserved.
from argparse import ArgumentParser

import mmcv
from mmdet.apis import inference_detector, init_detector

from mmrotate.registry import VISUALIZERS
from mmrotate.utils import register_all_modules
import os


def parse_args():
    parser = ArgumentParser()
    # parser.add_argument('img', help='Image file') # 图片
    parser.add_argument('--input_dir', help='Path to input directory') # 输入文件夹路径
    parser.add_argument('--output_dir', help='Path to output directory') # 输出文件夹路径
    parser.add_argument('config', help='Config file') # 配置文件
    parser.add_argument('checkpoint', help='Checkpoint file') # 权重文件路径
    # parser.add_argument('--out-file', default=None, help='Path to output file') # 输出文件路径
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    parser.add_argument(
        '--palette',
        default='dota',
        choices=['dota', 'sar', 'hrsc', 'random'],
        help='Color palette used for visualization')
    parser.add_argument(
        '--score-thr', type=float, default=0.3, help='bbox score threshold')
    

    parser.add_argument('--sufix_input',default='.png',help='Suffix of input files') # 输入文件后缀
    parser.add_argument('--sufix_output',default='.png',help='Suffix of output files') # 输出文件后缀


    args = parser.parse_args()
    return args


def main(args):

    # register all modules in mmrotate into the registries
    register_all_modules()

    # build the model from a config file and a checkpoint file
    model = init_detector(
        args.config, args.checkpoint, palette=args.palette, device=args.device)

    # init visualizer
    visualizer = VISUALIZERS.build(model.cfg.visualizer)
    # the dataset_meta is loaded from the checkpoint and
    # then pass to the model in init_detector
    visualizer.dataset_meta = model.dataset_meta

    for file in os.listdir(args.input_dir):
        if not file.endswith(args.sufix_input):
            continue
        file_path = os.path.join(args.input_dir, file) # 获取图片路径
        img = mmcv.imread(file_path)
        img = mmcv.imconvert(img, 'bgr', 'rgb')
        # img_name = file.split('/')[-1]
        out_img_name = file.replace(args.sufix_input, args.sufix_output) # 获取输出的图片名字
        out_img_path = os.path.join(args.output_dir, out_img_name) # 得到输出文件路径
        result = inference_detector(model, img)
        visualizer.add_datasample(
            'result',
            img,
            data_sample=result,
            draw_gt=False,
            show=False,
            wait_time=0,
            out_file=out_img_path,
            pred_score_thr=args.score_thr)

    # # test a single image 推理单张图片
    # result = inference_detector(model, args.img)

    # # show the results
    # img = mmcv.imread(args.img)
    # img = mmcv.imconvert(img, 'bgr', 'rgb')
    # visualizer.add_datasample(
    #     'result',
    #     img,
    #     data_sample=result,
    #     draw_gt=False,
    #     show=args.out_file is None,
    #     wait_time=0,
    #     out_file=args.out_file,
    #     pred_score_thr=args.score_thr)


if __name__ == '__main__':
    args = parse_args()
    main(args)

使用方法:

python demo/image_demo.py --input_dir '/root/autodl-tmp/split_ss_dota/test/images' --output_dir 'output' work_dirs/dy889-l/rotated_rtmdet_l-3x-dota.py  work_dirs/dy889-l/epoch_36.pth
python demo/image_demo.py --input_dir 'path to your input dir' --output_dir 'output' <配置文件路径>  <权重文件路径>

在这里插入图片描述
这两个一个是输入图片的后缀,一个是输出图片的后缀,默认都是png,也有bmp格式的图片如HRSC数据集,可以手动改一下或者在命令行输入参数。

在这里插入图片描述

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Goafan

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值