修改mmdetection预训练模型权重的类别数,以Pascal VOC为例

mmdetection官方提供的模型预训练权重文件都是基于MS COCO数据集训练的,当我们使用上述预训练权重文件对自己的数据集进行微调时,由于自己数据集的类别数往往与MS COCO不一致,因此需要做一些修改。完整代码在最后。

以 faster_rcnn_r50_fpn_1x_coco的权重文件为例:

faster_rcnn_r50_fpn_1x_coco的权重文件的下载地址:

mmdetection/configs/faster_rcnn at master · open-mmlab/mmdetection · GitHub

 1. 加载下载好的预训练权重

import torch

pretrain_pth = 'checkpoints/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth'
# print(pretrain_pth)
model = torch.load(pretrain_pth)

2.查看一些model的一些基本信息

print('type:', type(model))
print('length:', len(model))
print('_________________________________________________')

 3. 查看model中的key

print('key:')
for k in model.keys():  #查看模型字典里面的key
    print(k)      # meta  state_dict
print('_________________________________________________')

 在model里面有两个key,主要修改state_dict中的。meta里面是一些基础信息。可以看一下meta里面有哪些内容:

for metak in model['meta'].keys():
    print(metak)
print('_________________________________________________')

主要包含mmdet的版本,训练配置文件等。如果想要查看meta里的详细内容,可以使用:

for k in model:         #查看模型字典里面的value
    if k == 'meta':
        print(k, model[k])
print('_________________________________________________')

 4.查看state_dict内的内容:

for key, value in model['state_dict'].items(): # 打印出权重文件中网络结构
    print(key, value.size(), sep=" ")

 主要修改和类别相关的信息,也就是

roi_head.bbox_head.fc_cls.weight torch.Size([81, 1024])
roi_head.bbox_head.fc_cls.bias torch.Size([81])

以Pascal VOC数据集为例,类别数为20,加上背景类别总共为21

voc_num_classes = 21

修改如下:

model['state_dict']['roi_head.bbox_head.fc_cls.weight'] = model['state_dict']['roi_head.bbox_head.fc_cls.weight'][:voc_num_classes, :]
model['state_dict']['roi_head.bbox_head.fc_cls.bias'] = model['state_dict']['roi_head.bbox_head.fc_cls.bias'][:voc_num_classes]
torch.save(model, 'checkpoints/faster_rcnn_r50_fpn_1x_coco_voc.pth')  # 保存权重文件

修改完成后查看state_dict的内容:

类别已经变成21了~~~ 

完整代码:

import torch

pretrain_pth = 'checkpoints/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth'
# print(pretrain_pth)
model = torch.load(pretrain_pth)
voc_num_classes = 21


# print('type:', type(model))
# print('length:', len(model))
# print('_________________________________________________')


# print('key:')
# for k in model.keys():  #查看模型字典里面的key
#     print(k)      # meta  state_dict
# print('_________________________________________________')


# for metak in model['meta'].keys():
#     print(metak)
# print('_________________________________________________')


# for k in model:         #查看模型字典里面的value
#     if k == 'meta':
#         print(k, model[k])
# print('_________________________________________________')


# for state_dict_k in model['state_dict'].keys():
#     print(state_dict_k)
# print('_________________________________________________')


# for key, value in model['state_dict'].items(): # 打印出权重文件中网络结构
#     print(key, value.size(), sep=" ")

model['state_dict']['roi_head.bbox_head.fc_cls.weight'] = model['state_dict']['roi_head.bbox_head.fc_cls.weight'][:voc_num_classes, :]
model['state_dict']['roi_head.bbox_head.fc_cls.bias'] = model['state_dict']['roi_head.bbox_head.fc_cls.bias'][:voc_num_classes]
torch.save(model, 'checkpoints/faster_rcnn_r50_fpn_1x_coco_voc.pth')  # 保存权重文件

for key, value in model['state_dict'].items(): # 打印出权重文件中网络结构
    print(key, value.size(), sep=" ")

参考文献:

1.  MMDetection笔记:修改预训练模型权重类别数 - 知乎

2.  修改mmdetection的权重文件 - 知乎

  • 6
    点赞
  • 18
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值