对官方提供的 image_demo.py 做了一点改动:可以输入一个图片文件夹,并将每张图片的推理结果输出到指定的输出文件夹(例如 output)中:
# Copyright (c) OpenMMLab. All rights reserved.
from argparse import ArgumentParser
import mmcv
from mmdet.apis import inference_detector, init_detector
from mmrotate.registry import VISUALIZERS
from mmrotate.utils import register_all_modules
import os
def parse_args():
    """Parse the command-line options of the batch image demo.

    Returns:
        argparse.Namespace: Parsed options with the attributes
        ``input_dir``, ``output_dir``, ``config``, ``checkpoint``,
        ``device``, ``palette``, ``score_thr``, ``sufix_input`` and
        ``sufix_output``.

    Raises:
        SystemExit: Raised by argparse when a required option or positional
            argument is missing or an option value is invalid.
    """
    parser = ArgumentParser()
    # The directory options stand in for the single-image ``img`` positional
    # and the ``--out-file`` option of the upstream demo. Both are mandatory:
    # main() joins file names onto them, so a missing value would otherwise
    # only surface later as a TypeError on ``None``.
    parser.add_argument(
        '--input_dir', required=True, help='Path to input directory')
    parser.add_argument(
        '--output_dir', required=True, help='Path to output directory')
    parser.add_argument('config', help='Config file')
    parser.add_argument('checkpoint', help='Checkpoint file')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    parser.add_argument(
        '--palette',
        default='dota',
        choices=['dota', 'sar', 'hrsc', 'random'],
        help='Color palette used for visualization')
    parser.add_argument(
        '--score-thr', type=float, default=0.3, help='bbox score threshold')
    # Only files whose name ends with ``--sufix_input`` are processed; the
    # result image is written with ``--sufix_output`` in its place.
    parser.add_argument(
        '--sufix_input', default='.png', help='Suffix of input files')
    parser.add_argument(
        '--sufix_output', default='.png', help='Suffix of output files')
    return parser.parse_args()
def main(args):
    """Run the detector on every matching image of a directory.

    Each file in ``args.input_dir`` whose name ends with
    ``args.sufix_input`` is passed through the detector, and the rendered
    prediction is written to ``args.output_dir`` under the same name with
    the trailing suffix replaced by ``args.sufix_output``.

    Args:
        args (argparse.Namespace): Options produced by ``parse_args``. The
            attributes read here are ``config``, ``checkpoint``,
            ``palette``, ``device``, ``input_dir``, ``output_dir``,
            ``sufix_input``, ``sufix_output`` and ``score_thr``.
    """
    # register all modules in mmrotate into the registries
    register_all_modules()

    # build the model from a config file and a checkpoint file
    model = init_detector(
        args.config, args.checkpoint, palette=args.palette, device=args.device)

    # init visualizer
    visualizer = VISUALIZERS.build(model.cfg.visualizer)
    # the dataset_meta is loaded from the checkpoint and
    # then pass to the model in init_detector
    visualizer.dataset_meta = model.dataset_meta

    # os.listdir returns entries in arbitrary order; sort them so the images
    # are processed in a reproducible order.
    for file_name in sorted(os.listdir(args.input_dir)):
        if not file_name.endswith(args.sufix_input):
            continue
        file_path = os.path.join(args.input_dir, file_name)

        # Swap only the trailing suffix. str.replace would also rewrite any
        # earlier occurrence of the suffix inside the file name.
        stem = file_name[:len(file_name) - len(args.sufix_input)]
        out_img_path = os.path.join(args.output_dir,
                                    stem + args.sufix_output)

        # Run inference on the file itself, as the upstream single-image
        # demo does, so the detector receives the image exactly as its own
        # loading pipeline reads it. Only the copy handed to the visualizer
        # is converted from BGR to RGB.
        result = inference_detector(model, file_path)

        img = mmcv.imread(file_path)
        img = mmcv.imconvert(img, 'bgr', 'rgb')
        visualizer.add_datasample(
            'result',
            img,
            data_sample=result,
            draw_gt=False,
            show=False,
            wait_time=0,
            out_file=out_img_path,
            pred_score_thr=args.score_thr)
    # NOTE: the upstream demo handled a single ``args.img`` and displayed the
    # result when no ``--out-file`` was given; the directory loop above
    # replaces that path and always writes to disk (show=False).
if __name__ == '__main__':
    # Script entry point: parse the command line, then run the batch demo.
    main(parse_args())
使用方法:
python demo/image_demo.py --input_dir '/root/autodl-tmp/split_ss_dota/test/images' --output_dir 'output' work_dirs/dy889-l/rotated_rtmdet_l-3x-dota.py work_dirs/dy889-l/epoch_36.pth
python demo/image_demo.py --input_dir 'path to your input dir' --output_dir 'output' <配置文件路径> <权重文件路径>
`--sufix_input` 和 `--sufix_output` 这两个参数分别是输入图片和输出图片的后缀,默认都是 .png。如果数据集是其他格式(例如 HRSC 数据集的 .bmp 图片),可以修改代码中的默认值,或者在命令行中传入参数,例如 `--sufix_input .bmp`。