问题描述:
我在用自己的数据做训练时遇到了一个问题:我的数据集加上背景共 9 类,而 COCO 默认加上背景是 81 类。图片、标注 json 等我都逐一确认过,没有问题,但 Cascade R-CNN 三个 bbox head 里的类别数(num_classes)忘了改。这就尴尬了:训练了两天半,结果从一开始就是错的。于是想着能不能硬改权重文件直接拿来用。实测发现,不改 config 文件时还能顶着 warning 用,改了 config 就会报错;又不想重新训练,所以就有了下面的操作。废话不多说,看代码:
#-*-coding:utf-8-*-
import torch
from collections import OrderedDict
# Load onto the CPU: the tensors are only sliced and written back, so the
# conversion should not require a GPU to be present.
sd = torch.load("./epoch_20.pth", map_location="cpu")

# change state_dict
# Keep only the first 9 rows of the classification layer (fc_cls) in each of
# the three bbox heads; 9 is the class count including background.
newsd_state_dict = sd['state_dict']
for stage in range(3):
    for suffix in ('weight', 'bias'):
        key = 'bbox_head.{}.fc_cls.{}'.format(stage, suffix)
        # clone() gives the slice its own storage; a plain slice is a view and
        # torch.save would still serialize the whole original tensor behind it.
        newsd_state_dict[key] = newsd_state_dict[key][:9].clone()
# change meta
# Start from the metadata already stored in the checkpoint instead of pasting
# a hard-coded snapshot, so the epoch / iter / time / version fields stay those
# of this checkpoint and no machine-specific literal has to be maintained.
newsd_meta = dict(sd.get('meta') or {})

# Presumably the embedded config text still records the old head size; patch
# it so it agrees with the 9-output classification layers.
_config_text = newsd_meta.get('config')
if isinstance(_config_text, str):
    newsd_meta['config'] = _config_text.replace(
        'num_classes=81', 'num_classes=9')

# NOTE(review): 'CLASSES' is left exactly as stored. If it still lists the 80
# COCO names, replace it with the dataset's own class names - TODO confirm.
# change optimizer
# Trick: the optimizer state is walked generically, level by level, and every
# per-parameter tensor whose leading dimension equals the old class count (81)
# is cut down to the new class count (9).
# NOTE(review): matching on size alone would also trim an unrelated tensor
# whose first dimension happens to be 81 - confirm the model has none.
newsd_optimizer = sd['optimizer']
for section in newsd_optimizer.values():
    # Only dict-valued sections can hold tensors to trim; list-valued sections
    # were traversed without any modification before, so they are skipped.
    if not isinstance(section, dict):
        continue
    for param_state in section.values():
        if not isinstance(param_state, dict):
            continue
        # Snapshot the items so entries can be reassigned while looping.
        for name, buf in list(param_state.items()):
            # Non-tensor and 0-dim entries have no shape[0]; leave them alone.
            if torch.is_tensor(buf) and buf.dim() > 0 and buf.shape[0] == 81:
                # clone() so the saved file holds 9 rows, not a view of 81.
                param_state[name] = buf[:9].clone()

# Write the edited pieces back and save under a new file name, leaving the
# checkpoint that was loaded untouched on disk.
sd['meta'] = newsd_meta
sd['state_dict'] = newsd_state_dict
sd['optimizer'] = newsd_optimizer
torch.save(sd, "zzz.pth")
注意:如果需要从头构造一个新的权重文件,其中的 state_dict 应该使用 OrderedDict() 来初始化。