目标检测任务是指给定一张图片,网络预测出图片中所包括的所有物体类别和对应的边界框
以我们提供的猫 cat 小数据集为例,带大家 15 分钟轻松上手 MMYOLO 目标检测。整个流程包含如下步骤:
文章目录
一、环境安装
假设你已经提前安装好了 Conda,接下来安装 PyTorch
conda create -n mmyolo python=3.8 -y
conda activate mmyolo
# 如果你有 GPU
conda install pytorch torchvision -c pytorch
# 如果你是 CPU
# conda install pytorch torchvision cpuonly -c pytorch
安装 MMYOLO 和依赖库:
git clone https://github.com/open-mmlab/mmyolo.git
cd mmyolo
pip install -U openmim
mim install -r requirements/mminstall.txt
# Install albumentations
mim install -r requirements/albu.txt
# Install MMYOLO
mim install -v -e .
# "-v" 指详细说明,或更多的输出
# "-e" 表示在可编辑模式下安装项目,因此对代码所做的任何本地修改都会生效,从而无需重新安装。
温馨提醒:由于本仓库采用的是 OpenMMLab 2.0,请最好新建一个 conda 虚拟环境,防止和 OpenMMLab 1.0 已经安装的仓库冲突。
二、安装和验证
步骤 0. 使用 MIM 安装 MMEngine、 MMCV 和 MMDetection 。
pip install -U openmim
mim install "mmengine>=0.6.0"
mim install "mmcv>=2.0.0rc4,<2.1.0"
mim install "mmdet>=3.0.0,<4.0.0"
如果你当前已经处于 mmyolo 工程目录下,则可以采用如下简化写法:
cd mmyolo
pip install -U openmim
mim install -r requirements/mminstall.txt
注意:
a. 在 MMCV-v2.x 中,mmcv-full 改名为 mmcv,如果你想安装不包含 CUDA 算子精简版,可以通过 mim install mmcv-lite>=2.0.0rc1 来安装。
b. 如果使用 albumentations,我们建议使用 pip install -r requirements/albu.txt 或者 pip install -U albumentations --no-binary qudida,albumentations 进行安装。 如果简单地使用 pip install albumentations==1.0.1 进行安装,则会同时安装 opencv-python-headless(即便已经安装了 opencv-python 也会再次安装)。我们建议在安装 albumentations 后检查环境,以确保没有同时安装 opencv-python 和 opencv-python-headless,因为同时安装可能会导致一些问题。更多细节请参考 官方文档 。
步骤 1. 安装 MMYOLO
方案 1. 如果你基于 MMYOLO 框架开发自己的任务,建议从源码安装:
git clone https://github.com/open-mmlab/mmyolo.git
cd mmyolo
# Install albumentations
mim install -r requirements/albu.txt
# Install MMYOLO
mim install -v -e .
# "-v" 指详细说明,或更多的输出
# "-e" 表示在可编辑模式下安装项目,因此对代码所做的任何本地修改都会生效,从而无需重新安装。
方案 2. 如果你将 MMYOLO 作为依赖或第三方 Python 包,使用 MIM 安装
mim install "mmyolo"
为了验证 MMYOLO 是否安装正确,我们提供了一些示例代码来执行模型推理。
步骤 1. 我们需要下载配置文件和模型权重文件。
mim download mmyolo --config yolov5_s-v61_syncbn_fast_8xb16-300e_coco --dest .
下载将需要几秒钟或更长时间,这取决于你的网络环境。完成后,你会在当前文件夹中发现两个文件 yolov5_s-v61_syncbn_fast_8xb16-300e_coco.py 和 yolov5_s-v61_syncbn_fast_8xb16-300e_coco_20220918_084700-86e02187.pth。
步骤 2. 推理验证
方案 1. 如果你通过源码安装的 MMYOLO,那么直接运行以下命令进行验证:
python demo/image_demo.py demo/demo.jpg \
yolov5_s-v61_syncbn_fast_8xb16-300e_coco.py \
yolov5_s-v61_syncbn_fast_8xb16-300e_coco_20220918_084700-86e02187.pth
# 可选参数
# --out-dir ./output *检测结果输出到指定目录下,默认为./output, 当--show参数存在时,不保存检测结果
# --device cuda:0 *使用的计算资源,包括cuda, cpu等,默认为cuda:0
# --show *使用该参数表示在屏幕上显示检测结果,默认为False
# --score-thr 0.3 *置信度阈值,默认为0.3
运行结束后,在 output 文件夹中可以看到检测结果图像,图像中包含有网络预测的检测框。
支持输入类型包括
单张图片, 支持 jpg, jpeg, png, ppm, bmp, pgm, tif, tiff, webp。
文件目录,会遍历文件目录下所有图片文件,并输出对应结果。
网址,会自动从对应网址下载图片,并输出结果。
方案 2. 如果你通过 MIM 安装的 MMYOLO, 那么可以打开你的 Python 解析器,复制并粘贴以下代码:
from mmdet.apis import init_detector, inference_detector
config_file = 'yolov5_s-v61_syncbn_fast_8xb16-300e_coco.py'
checkpoint_file = 'yolov5_s-v61_syncbn_fast_8xb16-300e_coco_20220918_084700-86e02187.pth'
model = init_detector(config_file, checkpoint_file, device='cpu') # or device='cuda:0'
inference_detector(model, 'demo/demo.jpg')
你将会看到一个包含 DetDataSample 的列表,预测结果在 pred_instance 里,包含有预测框、预测分数 和 预测类别。
三、数据集准备
Cat 数据集是一个包括 144 张图片的单类别数据集(本 cat 数据集由 @RangeKing 提供原始图片,由 @PeterH0323 进行数据清洗), 包括了训练所需的标注信息。 样例图片如下所示:
你只需执行如下命令即可下载并且直接用起来:
python tools/misc/download_dataset.py --dataset-name cat --save-dir ./data/cat --unzip --delete
数据集组织格式如下所示:
data 位于 mmyolo 工程目录下, data/cat/annotations 中存放的是 COCO 格式的标注,data/cat/images 中存放的是所有图片。
四、配置准备
以 YOLOv5 算法为例,考虑到用户显存和内存有限,我们需要修改一些默认训练参数来让大家愉快的跑起来,核心需要修改的参数如下:
- YOLOv5 是 Anchor-Based 类算法,不同的数据集需要自适应计算合适的 Anchor
- 默认配置是 8 卡,每张卡 batch size 为 16,现将其改成单卡,每张卡 batch size 为 12
- 默认训练 epoch 是 300,将其改成 40 epoch
- 由于数据集太小,我们选择固定 backbone 网络权重
- 原则上 batch size 改变后,学习率也需要进行线性缩放,但是实测发现不需要
具体操作为在 configs/yolov5 文件夹下新建 yolov5_s-v61_fast_1xb12-40e_cat.py 配置文件(为了方便大家直接使用,我们已经提供了该配置),并把以下内容复制配置文件中。
# 基于该配置进行继承并重写部分配置
_base_ = 'yolov5_s-v61_syncbn_fast_8xb16-300e_coco.py'
data_root = './data/cat/' # 数据集根路径
class_name = ('cat', ) # 数据集类别名称
num_classes = len(class_name) # 数据集类别数
# metainfo 必须要传给后面的 dataloader 配置,否则无效
# palette 是可视化时候对应类别的显示颜色
# palette 长度必须大于或等于 classes 长度
metainfo = dict(classes=class_name, palette=[(20, 220, 60)])
# 基于 tools/analysis_tools/optimize_anchors.py 自适应计算的 anchor
anchors = [
[(68, 69), (154, 91)