P2T模型是基于mmdetection框架的,所以训练自己的数据集的步骤是相通的。我主要借鉴了这篇文章Swin Transformer实战实例分割:训练自己的数据集
1.安装环境
python=3.8
pytorch系列适配cuda11.0
mmdetection(按官方文档安装即可)
以及其他要求的库
2.数据集准备
coco格式,制作教程有很多,这里就不写了
3.相关文件修改
类别数量修改:configs\_base_\models\...py(采用的骨干设置文件),将其中的num_classes由80改为1,
修改configs\_base_\default_runtime.py中interval(用于设置log的保存间隔与总训练次数),load_from(可选)
修改权重文件
修改configs/base/datasets/coco_instance.py中数据集路径
修改detection/configs/mask_rcnn_p2t_b_fpn_1x_coco.py中的max_epochs、lr
修改mmdet/core/evalution/class_names.py和mmdet/datasets/coco.py中的标签
def coco_classes():
return ['cow']
class CocoDataset(CustomDataset):
CLASSES = ('cow',)
训练:
cd detection
bash dist_train.sh configs/...py(权重文件) 8(gpus)