配环境时用pytorch2.0以上好些。一次成功见这个教程
具体布置自己的东西时,可以参考这个教程
数据集准备
数据如下准备。然后test2017用不着,空着
跑实验
1/选择一个模型 使用python tools/train.py configs/...运行,找到work dirs目录下的模型py文件
2/修改配置文件训练测测试相关参数:修改检测数,一般是classes
3/重载数据集的类名。去官网下载权重,把配置文件中load_from = none 修改为load_form = '权重路径'
4/运行修改好后的配置文件 python tools/train.py work dirs/...xxx.py,得到相应结果
8/ 运行python tools/ workdirs/....py 来测试模型的参数大小,flops和fps
常见报错
1.File "<__array_function__ internals>", line 200, in concatenate ValueError: need at least one array to concatenate
解决:1.在配置文件的三个datalodaer里面,data_root='data/coco/',后面加上metainfo
2.确保classes设置的数量对了
如metainfo=dict(
classes=(
'Yagi_antenna',
'Plate_log_antenna',
'Patch_antenna',
),
palette=[
(
220,
20,
60,
),
(
119,
11,
32,
),
(
0,
0,
142,
),
]),
变成这样:
train_dataloader = dict(
batch_sampler=dict(_scope_='mmdet', type='AspectRatioBatchSampler'),
batch_size=2,
dataset=dict(
_scope_='mmdet',
ann_file='annotations/instances_train2017.json',
backend_args=None,
data_prefix=dict(img='train2017/'),
data_root='data/coco/',
filter_cfg=dict(filter_empty_gt=True, min_size=32),
metainfo=dict(
classes=(
'Yagi_antenna',
'Plate_log_antenna',
'Patch_antenna',
),
palette=[
(
220,
20,
60,
),
(
119,
11,
32,
),
(
0,
0,
142,
),
]),
pipeline=[
dict(backend_args=None, type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
dict(
keep_ratio=True,
ratio_range=(
0.1,
2.0,
),
scale=(
1024,
1024,
),
type='RandomResize'),
dict(
allow_negative_crop=True,
crop_size=(
1024,
1024,
),
crop_type='absolute_range',
recompute_bbox=True,
type='RandomCrop'),
dict(min_gt_bbox_wh=(
0.01,
0.01,
), type='FilterAnnotations'),
dict(prob=0.5, type='RandomFlip'),
dict(type='PackDetInputs'),
],
type='CocoDataset'),
num_workers=4,
persistent_workers=True,
sampler=dict(_scope_='mmdet', shuffle=True, type='DefaultSampler'))
2.data'category id']= self.cat ids[label] IndexError: list index out of range
参考这个教程。很有用,我是当时代码中的类名由a-b 变成了a - b,中间多了空格,导致代码和json里面的标注不一样而报错。
3.you should set `PYTHONPATH` to make 'sys.path include the directory which contains your custom module
解决:import sys
sys.path.append("/hy-tmp/mmdetection")