前言
mmdetection
作为一个优秀的开源目标检测算法库,在训练模型方面是相当的方便,但是某些时候使用它进行推理时就有点难受,本文就演示如何批量推理图片(多张图片存放在文件夹中),mmdetection
的版本是2.27.0
批量推理图片
在mmdetection
中想测试图片那必须得有对应的标注信息文件,要是没有的话调用官方api
只能一张一张推理,慢的要死,还是自己弄一个文件靠谱。可以在根目录下创建一个batch_infer.py
的文件,这里需要调用推理的api
,我直接贴上代码:
import argparse
import os
from mmdet.apis import inference_detector, init_detector #, show_result_pyplot
import cv2
from pathlib import Path
def parse_args():
parser = argparse.ArgumentParser(
description='MMDet test (and eval) a model')
parser.add_argument('--config', type=str, help='配置文件路径')
parser.add_argument('--checkpoint-file', type=str, help='权重文件路径')
parser.add_argument(
'--img-dir', type=str,
help='待检测图片路径')
parser.add_argument('--out-dir', type=str, help='保存检测图片路径')
parser.add_argument(
'--gpu-ids',
type=int,
nargs='+',
help='(Deprecated, please use --gpu-id) ids of gpus to use '
'(only applicable to non-distributed training)')
parser.add_argument(
'--gpu-id',
type=int,
default=0,
help='id of gpu to use '
'(only applicable to non-distributed testing)')
parser.add_argument(
'--score-thr',
type=float,
default=0.50,
help='score threshold (default: 0.50)')
args = parser.parse_args()
return args
def show_result_pyplot(model, img, result, score_thr=0.3, fig_size=(15, 10)):
"""Visualize the detection results on the image.
Args:
model (nn.Module): The loaded detector.
img (str or np.ndarray): Image filename or loaded image.
result (tuple[list] or list): The detection result, can be either
(bbox, segm) or just bbox.
score_thr (float): The threshold to visualize the bboxes and masks.
fig_size (tuple): Figure size of the pyplot figure.
"""
if hasattr(model, 'module'):
model = model.module
img = model.show_result(img, result, score_thr=score_thr, show=False)
return img
def main():
args = parse_args()
# config文件
config_file = args.config
# 训练好的模型
checkpoint_file = args.checkpoint_file
# checkpoint_file = 'work_dirs/faster_rcnn_r50_fpn_1x_coco/epoch_300.pth'
model = init_detector(config_file, checkpoint_file, device='cuda:0')
# 图片路径
img_dir = args.img_dir
# 检测后存放图片路径
out_dir = args.out_dir
if not os.path.exists(out_dir):
os.mkdir(out_dir)
# 检测阈值
score_thr = args.score_thr
img_list = []
count = 0
path = Path(img_dir)
for p in path.iterdir():
# print('model is processing the {}/{} images.'.format(count, len(img_list)))
model = init_detector(config_file, checkpoint_file, device='cuda:0')
result = inference_detector(model, str(p))
img = show_result_pyplot(model, str(p), result, score_thr=score_thr)
cv2.imwrite("{}/{}.jpg".format(out_dir, p.stem), img)
if __name__ == '__main__':
main()
演示如何推理:
python batch_infer.py \
--config work_dirs/yolox_s_8x8_300e_coco/yolox_s_8x8_300e_coco.py \
--checkpoint-file work_dirs/yolox_s_8x8_300e_coco/bast.pth \
--img-dir data/coco/test2000 --out-dir work_dirs/detect/xs/test2000
注意事项
我使用的是mmdeteciton-2.27.0
版本,2.x
版本应该是通用的;还有就是在推理时是需要用到GPU的,我测试过,如果没有GPU会报错,所以请注意这两点。