想复现Oriented R-CNN的代码并做一些改动,下载下来发现是使用的mmdetection框架,于是边学习记录一些遇到的问题。
1、训练步骤
(1)在configs中选自己想要用的模型,可以直接修改,我习惯新建一个.py文件,将model、dataset等config文件全部复制过来,方便改动,即开头的_base_中所有文件都拷贝到新.py中。
(2)修改后命令行内运行
python tools/train.py configs/xxx.py
2、无法下载预训练模型,报错"HTTP Error 403: Forbidden"
换成国内源的网址即可正常下载
选取想要的模型,将地址复制下来,在configs/default_runtime.py中修改load_from为新地址即可。
3、使用作者给的模型用于预训练时报“xxx.pth is a zip archive(did you mean to use torch.jit.load()?)“
在configs/obb/oriented_rcnn/README.md中作者给出了很多训练好了模型参数文件,还很贴心地给了百度云的下载链接。下载下来之后将路径复制在configs/default_runtime.py中load_from处即可。
运行时可能会报以上错误,这是由于该模型在训练时使用的是高版本的torch,1.6之后的pytorch默认使用zip文件格式来保存权重文件,导致这些权重文件无法直接被1.5及以下的pytorch加载。解决方法是新建一个torch==1.6的环境,然后运行:
import torch
state_dict = torch.load("faster_rcnn_orpn_r101_fpn_1x_mssplit_rr_dota10_epoch12.pth", map_location='cpu')
torch.save(state_dict, "faster_rcnn_orpn_r101_fpn_1x_mssplit_rr_dota10_epoch12.pth",
_use_new_zipfile_serialization=False)
把“xxxx.pth”文件换成自己的文件路径即可。这样就将原文件重新保存为非zip格式的了。这里load时加上了map_location='cpu'是由于我的CUDA版本太低,torch.cuda.is_available()=FALSE,但又必须使用torch1.6,不过不影响结果。用同样的方法使用更新好的权重文件就行了。
4、用mmdetection中现成模型训练自己的数据集
如训练COCO格式的自制数据集,类别与COCO不一致
(1)在data文件夹中放入分好train/val/test的图片和标签。
(2)自己的config文件中,将num_classes修改为自己的类别数(不加背景),如model ——> roi_head ——> bbox_head ——> num_classes,如果有mask分支则mask_head中也有一项num_classes。
(3)mmdet/datasets/coco.py 、 mmdet/core/evaluation/class_names.py 都将原CLASSES注释掉,换成新的类别名称(无BG项)。
5、ValueError: need at least one array to concatenate
一般都是数据集的标签里出了问题,查看标签文件里是否没有生成bbox坐标或label等
6、AttributeError: module 'torch' has no attribute 'square
mmdet/core/bbox/coder/obb/midpoint_offset_coder.py文件中第125行:
diag_len = torch.sqrt(torch.square(center_polys[..., 0::2]) + torch.square(center_polys[..., 1::2]))
1.10版本以下的torch没有设置torch.square()这个函数,手动改成两个相同元素做torch.mul()就可以正常运行,即:
diag_len = torch.sqrt(
torch.mul(center_polys[..., 0::2], center_polys[..., 0::2]) +
torch.mul(center_polys[..., 1::2], center_polys[..., 1::2]))
7、其他问题
还遇到很多其他的小问题,虽然也比较恼火,可能一改改一天,但回顾下来,确实都是版本不匹配造成的,很浪费时间也令人emo,所以一定要check清楚自己的CUDA版本、适合的torch、torchvision,mmdet版本与mmcv版本的匹配问题等等,确定不是这些问题之后再去针对error改错。