mmdetection学习日记

想复现Oriented R-CNN的代码并做一些改动,下载下来发现是使用的mmdetection框架,于是边学习记录一些遇到的问题。

1、训练步骤

(1)在configs中选自己想要用的模型,可以直接修改,我习惯新建一个.py文件,将model、dataset等config文件全部复制过来,方便改动,即开头的_base_中所有文件都拷贝到新.py中。

(2)修改后命令行内运行

python tools/train.py configs/xxx.py

2、无法下载预训练模型,报错"HTTP Error 403: Forbidden"

换成国内源的网址即可正常下载

可以查询:mmcv/model_zoo/open_mmlab.json · master · mirrors / open-mmlab / mmcv · CODE CHINA (gitcode.net)https://gitcode.net/mirrors/open-mmlab/mmcv/-/blob/master/mmcv/model_zoo/open_mmlab.json

选取想要的模型,将地址复制下来,在configs/default_runtime.py中修改load_from为新地址即可。

3、使用作者给的模型用于预训练时报“xxx.pth is a zip archive(did you mean to use torch.jit.load()?)“

在configs/obb/oriented_rcnn/README.md中作者给出了很多训练好了模型参数文件,还很贴心地给了百度云的下载链接。下载下来之后将路径复制在configs/default_runtime.py中load_from处即可。

运行时可能会报以上错误,这是由于该模型在训练时使用的是高版本的torch,1.6之后的pytorch默认使用zip文件格式来保存权重文件,导致这些权重文件无法直接被1.5及以下的pytorch加载。解决方法是新建一个torch==1.6的环境,然后运行:

import torch

state_dict = torch.load("faster_rcnn_orpn_r101_fpn_1x_mssplit_rr_dota10_epoch12.pth", map_location='cpu')
torch.save(state_dict, "faster_rcnn_orpn_r101_fpn_1x_mssplit_rr_dota10_epoch12.pth",
           _use_new_zipfile_serialization=False)

把“xxxx.pth”文件换成自己的文件路径即可。这样就将原文件重新保存为非zip格式的了。这里load时加上了map_location='cpu'是由于我的CUDA版本太低,torch.cuda.is_available()=FALSE,但又必须使用torch1.6,不过不影响结果。用同样的方法使用更新好的权重文件就行了。

4、用mmdetection中现成模型训练自己的数据集

如训练COCO格式的自制数据集,类别与COCO不一致

(1)在data文件夹中放入分好train/val/test的图片和标签。

(2)自己的config文件中,将num_classes修改为自己的类别数(不加背景),如model ——> roi_head ——> bbox_head ——> num_classes,如果有mask分支则mask_head中也有一项num_classes。

(3)mmdet/datasets/coco.py 、 mmdet/core/evaluation/class_names.py  都将原CLASSES注释掉,换成新的类别名称(无BG项)

5、ValueError: need at least one array to concatenate

一般都是数据集的标签里出了问题,查看标签文件里是否没有生成bbox坐标或label等

6、AttributeError: module 'torch' has no attribute 'square

mmdet/core/bbox/coder/obb/midpoint_offset_coder.py文件中第125行:

diag_len = torch.sqrt(torch.square(center_polys[..., 0::2]) + 
torch.square(center_polys[..., 1::2]))

1.10版本以下的torch没有设置torch.square()这个函数,手动改成两个相同元素做torch.mul()就可以正常运行,即:

diag_len = torch.sqrt(
        torch.mul(center_polys[..., 0::2], center_polys[..., 0::2]) +
        torch.mul(center_polys[..., 1::2], center_polys[..., 1::2]))

7、其他问题

还遇到很多其他的小问题,虽然也比较恼火,可能一改改一天,但回顾下来,确实都是版本不匹配造成的,很浪费时间也令人emo,所以一定要check清楚自己的CUDA版本、适合的torch、torchvision,mmdet版本与mmcv版本的匹配问题等等,确定不是这些问题之后再去针对error改错。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值