由于 mmdetection 封装程度较高,直接修改 builder 过程比较麻烦,所以推荐采用这种预处理的方式:先将模型权重文件处理好。
import torch
def _keep_leading_rows(state_dict, key, rows):
    """Replace ``state_dict[key]`` with a copy of its first ``rows`` entries along dim 0."""
    # clone() detaches the slice from the original storage; without it
    # torch.save would still serialize the full-size storage behind the view.
    state_dict[key] = state_dict[key][:rows].clone()


def main(num_classes=6,
         src_path="./checkpoint/cascade_mask_rcnn_swin_tiny_patch4_window7.pth",
         dst_path=None,
         num_stages=3):
    """Shrink the class-dependent head weights of a pretrained checkpoint.

    Loads the checkpoint at ``src_path``, truncates the per-class output
    layers of every cascade stage to ``num_classes`` by keeping the leading
    rows of each tensor, and saves the result to ``dst_path``.

    Args:
        num_classes: Number of foreground classes of the target dataset.
        src_path: Path of the pretrained checkpoint; it must contain a
            ``"state_dict"`` entry.
        dst_path: Output path. When ``None`` the original naming scheme
            ``./checkpoint/classes_<num_classes>_cascade_mask_rcnn_swin_tiny_patch4_window7.pth``
            is used.
        num_stages: Number of ``roi_head.bbox_head.<i>`` /
            ``roi_head.mask_head.<i>`` stages to process.

    Raises:
        ValueError: If ``num_classes`` is smaller than 1.
        KeyError: If an expected head key is missing from the state dict.
    """
    if num_classes < 1:
        raise ValueError("num_classes must be >= 1, got %r" % (num_classes,))
    if dst_path is None:
        dst_path = "./checkpoint/classes_%d_cascade_mask_rcnn_swin_tiny_patch4_window7.pth" % num_classes

    # Load onto the CPU so the script also works on machines that lack the
    # device the checkpoint was saved from.
    model_coco = torch.load(src_path, map_location="cpu")
    print(model_coco.keys())
    # The AMP state is optional; print None instead of failing when absent.
    print(model_coco.get('amp'))
    print('*******************')

    state_dict = model_coco["state_dict"]
    for key in state_dict:
        print(key)

    for stage in range(num_stages):
        bbox_head = "roi_head.bbox_head.%d" % stage
        # Box classification layer: multi-class, number of classes + background.
        _keep_leading_rows(state_dict, bbox_head + ".fc_cls.weight", num_classes + 1)
        _keep_leading_rows(state_dict, bbox_head + ".fc_cls.bias", num_classes + 1)
        # Box regression layer: 4 values (x1, y1, x2, y2) per class.
        _keep_leading_rows(state_dict, bbox_head + ".fc_reg.weight", 4 * num_classes)
        _keep_leading_rows(state_dict, bbox_head + ".fc_reg.bias", 4 * num_classes)

    for stage in range(num_stages):
        mask_head = "roi_head.mask_head.%d" % stage
        # Mask logits: one channel per class, 1 where the pixel belongs to the
        # class and 0 otherwise, hence a binary cross-entropy mask loss.
        _keep_leading_rows(state_dict, mask_head + ".conv_logits.weight", num_classes)
        _keep_leading_rows(state_dict, mask_head + ".conv_logits.bias", num_classes)

    for stage in range(num_stages):
        print(state_dict["roi_head.mask_head.%d.conv_logits.weight" % stage].shape)

    # Save the new checkpoint.
    torch.save(model_coco, dst_path)
# Run the checkpoint conversion only when this file is executed as a script.
if __name__ == "__main__":
    main()
同时,注意一下自定义数据集(custom dataset)是否标注了 mask。通常我们的自定义数据集只标注 bbox,所以需要把 Mask R-CNN 中对 mask 的使用禁用掉:
1、配置文件 _base_ 目录中的基础模型文件 cascade_mask_rcnn_swin_fpn.py 中,将 use_mask 由 True 改为 False,并注释掉 mask_roi_extractor 和 mask_head 两个变量;
2、入口配置文件 cascade_mask_rcnn_swin_tiny_patch4_window7_mstrain_480-800_giou_4conv1f_adamw_3x_coco.py 中,将 train_pipeline 中的 with_mask 由 True 改为 False;并从 dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']) 中去掉 'gt_masks';
3、配置文件 _base_ 目录中的数据集描述文件 coco_instance.py 中,将 evaluation = dict(metric=['bbox', 'segm']) 中的 'segm' 去掉。