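# 根据类别个数修改权重（测试 mmdetection）
# Adapt the box-head weights of a COCO-pretrained checkpoint to the number of
# classes in your own dataset (tested with mmdetection).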
import torch
num_class = 2  # number of (foreground) classes in your own dataset


def resize_bbox_head(checkpoint, num_classes):
    """Truncate the RoI box head of a checkpoint to ``num_classes`` classes.

    Keeps the first ``num_classes + 1`` rows of ``fc_cls`` (weight and bias)
    and the first ``num_classes * 4`` rows of ``fc_reg`` (weight and bias), so
    the checkpoint matches a model configured for ``num_classes`` classes.
    All other entries of the state dict are left untouched.

    NOTE(review): the ``num_classes * 4`` size assumes class-specific box
    regression (``reg_class_agnostic=False``) -- confirm for your config.
    Keeping the *first* ``num_classes + 1`` classifier rows also means the
    extra row is not necessarily the pretrained background row; this only
    affects initialisation before fine-tuning.

    Args:
        checkpoint: dict holding a ``'state_dict'`` mapping parameter names to
            tensors (as returned by ``torch.load`` on the checkpoint file).
        num_classes: number of classes in the target dataset; must be >= 1.

    Returns:
        The same ``checkpoint`` dict; its four box-head entries are replaced
        by truncated copies.

    Raises:
        ValueError: if ``num_classes`` < 1, or if the pretrained head has
            fewer rows than required (growing would need new, uninitialised
            weights). Nothing is modified when this is raised.
    """
    if num_classes < 1:
        raise ValueError(f"num_classes must be >= 1, got {num_classes}")

    state_dict = checkpoint['state_dict']
    rows_by_key = {
        'roi_head.bbox_head.fc_cls.weight': num_classes + 1,
        'roi_head.bbox_head.fc_cls.bias': num_classes + 1,
        'roi_head.bbox_head.fc_reg.weight': num_classes * 4,
        'roi_head.bbox_head.fc_reg.bias': num_classes * 4,
    }

    # Validate every tensor first so a failure leaves the checkpoint intact.
    for key, rows in rows_by_key.items():
        available = state_dict[key].shape[0]
        if rows > available:
            raise ValueError(
                f"{key} has only {available} rows, cannot keep {rows} "
                f"for num_classes={num_classes}"
            )

    for key, rows in rows_by_key.items():
        # Slice + clone instead of resize_(): resize_() fills new elements with
        # uninitialised memory when growing and, when shrinking, keeps the
        # whole original storage, which torch.save would still write to disk.
        state_dict[key] = state_dict[key][:rows].clone()
    return checkpoint


if __name__ == "__main__":
    # map_location='cpu' lets a GPU-saved checkpoint load on CPU-only machines.
    pretrained_weight = torch.load(
        'faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth', map_location='cpu'
    )
    resize_bbox_head(pretrained_weight, num_class)
    torch.save(pretrained_weight, "faster_rcnn_r50_fpn_1x_%d.pth" % num_class)
# 将修改后的权重文件作为预训练权重。
# Use the resulting checkpoint file as the pretrained weights for training.