因公司需要,需要将mmdetection2.5版本训练的模型迁移到mmdetection1.0,两者因环境(torch版本,mmcv版本,核心代码实现上的区别),直接更改config存在问题,本文记录迁移过程
1、mmdetection2.5 模型保存
两个版本mmdet的torch 版本不相同(其中2.5版本为1.7,1.0版本为1.1),因此在mmdetection2.5上保存的模型,在低版本的torch中无法正确读取,因此首先要需要解决这个问题,解决方案是只保存state_dict,重点是需要设定_use_new_zipfile_serialization为False,否则低版本torch无法读取
import os
import torch
from mmdet.apis import init_detector
config_path = 'config_v2_5.py'
model_path = 'model_v2_5.pth'
save_path = 'mmdet_demo/transfer.pth'
model = init_detector(config_path, model_path)
state_dict = model.state_dict()
torch.save(state_dict,,_use_new_zipfile_serialization=False
如上述脚本可实现高版本模型保存state_dict, 这样把weights迁移到对应的低版本
2、制作config
为了简单实现需求,对比了两个版本config的异同,在2.5版本的基础上,修改得到1.0版本的config, 主
要更改的地方有四个方面,如下所示:
-
RPN相关
因2.x版本相较1.x版本进行了较大改进,尤其在RPN的组成上,核心代码一致,组合方式发生了变化。
-
bbox_roi_extractor相关
第二个变化为ROI相关的组合方式
-
bbox_head相关
-
test_cfg相关
主要注意使用nms中的参数名称
3、实现转化
实现转化的流程:
1、利用1.0的config进行模型初始化,得到对应的结构state_dict
2、获取2.5版本保存的state_dict
3、遍历1.0获取的state_dict,保存有差异的key值,存入两个版本的txt文件,如下图所示:
4、遍历1.0的state_dict,对于不同情况进行分别赋值
具体代码如下所示:
import torch
import mmcv
from mmdet.models import build_detector
from mmcv.runner import load_checkpoint
import warnings
from mmdet.core import get_classes
from mmcv import Config
from mmdet.apis import init_detector
def read_txt(file):
with open(file, 'r') as f:
content = f.read().splitlines()
return content
if __name__ == '__main__':
import numpy as np
import os
num_classes = 3 # 1.0版本中的类别数量
txt_2 ='key_2.5.txt' ## 2.5版本的key值
txt_1 ='key_1.0.txt' ## 1.0版本的key值
key1 = read_txt(txt_1)
key2 = read_txt(txt_2)
config_path ='config_v1.0.py' # 1.0版本的config
model_path ='tranfer.pth' # 2.5版本的state_dict
cfg = Config.fromfile(config_path)
model = build_detector(cfg.model, test_cfg=cfg.test_cfg)
state_dict = torch.load(model_path)
model_state = model.state_dict()
temp_dict = {}
for k, v in model_state.items():
if k in state_dict.keys() and np.shape(v) == np.shape(state_dict[k]):
temp_dict[k] = state_dict[k]
for k1, k2 in zip(key1, key2):
if np.shape(model_state[k1]) == np.shape(state_dict[k2]):
if k1.endswith('fc_cls.weight') or k1.endswith('fc_cls.bias'):
old_val = state_dict[k2]
new_val = torch.cat([old_val[-1:], old_val[:-1]], dim=0)
temp_dict[k1] = new_val
else:
temp_dict[k1] = state_dict[k2]
else:
old_val = state_dict[k2]
if k1.endswith('fc_reg.weight'):
out_channels, in_channels = old_val.shape[:2]
tmp_val = old_val.reshape(num_classes - 1, -1, in_channels,*old_val.shape[2:])
zero_mode = torch.zeros((1, tmp_val.shape[1],
tmp_val.shape[2])).type_as(tmp_val) #初始化0进行填充
new_val = torch.cat([zero_mode, tmp_val], dim=0)
new_val = new_val.reshape(-1, *old_val.shape[1:])
elif k1.endswith('fc_reg.bias'):
tmp_val = old_val.reshape(num_classes - 1, -1)
zero_mode = torch.zeros((1, tmp_val.shape[1])).type_as(tmp_val)
#初始化0进行填充
new_val = torch.cat([tmp_val, zero_mode], dim=0)
new_val = new_val.reshape(-1)
temp_dict[k1] = new_val
model_state.update(temp_dict)
model.load_state_dict(model_state, strict=True)
model.eval()
meta = {'CLASSES': ['1','2','3','4','5','6']}
checkpoint = {
'meta': meta,
'state_dict': model.state_dict()
}
torch.save(checkpoint,'transfer_model.pth')
上述实现参考mmdetection2.5中模型升级的代码:官方升级代码示例,由于v2.5转换到v1.0存在类别数目变少的情况,本文以全0进行填充
以上就是是mmdetection2.5迁移到mmdetection1.0的示例
如有问题,敬请指正!
–END–