本教程一共包括如下流程:
1. 数据集准备和可视化
2.自定义配置文件
3.训练前可视化验证
4.模型训练
5.模型测试和推理
6.可视化分析
下面我们将以 MMDetection 团队自研的 RTMDet 算法为例,结合一个简单的 cat 数据集来描述整个训练推理可视化过程。
0、环境检测及安装
# 安装 mmengine 和 mmcv 依赖
# 为了防止后续版本变更导致的代码无法运行,我们暂时锁死版本
!pwd
%pip install -U "openmim==0.3.7"
!mim install "mmengine==0.7.1"
!mim install "mmcv==2.0.0"
# Install mmdetection
!rm -rf mmdetection
# 为了防止后续更新导致的可能无法运行,特意新建了 tutorials 分支
!git clone -b tutorials https://github.com/open-mmlab/mmdetection.git
%cd mmdetection
%pip install -e .
from mmengine.utils import get_git_hash
from mmengine.utils.dl_utils import collect_env as collect_base_env
import mmdet
# 环境信息收集和打印
def collect_env():
"""Collect the information of the running environments."""
env_info = collect_base_env()
env_info['MMDetection'] = f'{mmdet.__version__}+{get_git_hash()[:7]}'
return env_info
if __name__ == '__main__':
for name, val in collect_env().items():
print(f'{name}: {val}')
1、数据集准备及可视化
我们提供了一个简单的 cat 猫数据集,该数据集来自社区用户,总共包括 144 张图片,并且已经提前划分为了训练集和测试集。
# 数据集可视化
import os
import matplotlib.pyplot as plt
from PIL import Image
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
original_images = []
images = []
texts = []
plt.figure(figsize=(16, 5))
image_paths= [filename for filename in os.listdir('cat_dataset/images')][:8]
for i,filename in enumerate(image_paths):
name = os.path.splitext(filename)[0]
image = Image.open('cat_dataset/images/'+filename).convert("RGB")
plt.subplot(2, 4, i+1)
plt.imshow(image)
plt.title(f"{filename}")
plt.xticks([])
plt.yticks([])
plt.tight_layout()
2、自定义配置文件
本教程采用 RTMDet 进行演示,在开始自定义配置文件前,先来了解下 RTMDet 算法。
RTMDet 是一个高性能低延时的检测算法,目前已经实现了目标检测、实例分割和旋转框检测任务。其简要描述为:**为了获得更高效的模型架构,MMDetection 探索了一种具有骨干和 Neck 兼容容量的架构,由一个基本的构建块构成,其中包含大核深度卷积。MMDetection 进一步在动态标签分配中计算匹配成本时引入软标签,以提高准确性。结合更好的训练技巧,得到的目标检测器名为 RTMDet,在 NVIDIA 3090 GPU 上以超过 300 FPS 的速度实现了 52.8% 的 COCO AP,优于当前主流的工业检测器。RTMDet 在小/中/大/特大型模型尺寸中实现了最佳的参数-准确度权衡,适用于各种应用场景,并在实时实例分割和旋转对象检测方面取得了新的最先进性能。
3、训练前可视化验证
我们可以采用 mmdet 提供的 tools/analysis_tools/browse_dataset.py
脚本来对训练前的 dataloader 输出进行可视化,确保数据部分没有问题。
考虑到我们仅仅想可视化前几张图片,因此下面基于 browse_dataset.py 实现一个简单版本即可。
4、模型训练
python tools/train.py rtmdet_tiny_1xb12-40e_cat.py
5、模型测试和推理
python tools/test.py rtmdet_tiny_1xb12-40e_cat.py work_dirs/rtmdet_tiny_1xb12-40e_cat/best_coco/bbox_mAP_epoch_30.pth
6、可视化分析
可视化分析包括特征图可视化以及类似 Grad CAM 等可视化分析手段。不过由于 MMDetection 中还没有实现,我们可以直接采用 MMYOLO 中提供的功能和脚本。MMYOLO 是基于 MMDetection 开发,并且此案有了统一的代码组织形式,因此 MMDetection 配置可以直接在 MMYOLO 中使用,无需更改配置。