高版本mmdetection2.5迁移到低版本mmdetection1.0方法

本文详细记录了如何将基于MMDetection2.5训练的模型迁移到1.0版本的过程,包括模型状态保存、配置文件调整以及不同版本state_dict的转换步骤,涉及关键代码示例,适用于处理模型版本升级问题。
摘要由CSDN通过智能技术生成

因公司需要,需要将mmdetection2.5版本训练的模型迁移到mmdetection1.0,两者因环境(torch版本,mmcv版本,核心代码实现上的区别),直接更改config存在问题,本文记录迁移过程

1、mmdetection2.5 模型保存

两个版本mmdet的torch 版本不相同(其中2.5版本为1.7,1.0版本为1.1),因此在mmdetection2.5上保存的模型,在低版本的torch中无法正确读取,因此首先要需要解决这个问题,解决方案是只保存state_dict,重点是需要设定_use_new_zipfile_serialization为False,否则低版本torch无法读取

import os
import torch
from mmdet.apis import init_detector
config_path = 'config_v2_5.py'
model_path = 'model_v2_5.pth'
save_path = 'mmdet_demo/transfer.pth'
model = init_detector(config_path, model_path)
state_dict = model.state_dict()
torch.save(state_dict,,_use_new_zipfile_serialization=False

如上述脚本可实现高版本模型保存state_dict, 这样把weights迁移到对应的低版本

2、制作config

为了简单实现需求,对比了两个版本config的异同,在2.5版本的基础上,修改得到1.0版本的config, 主
要更改的地方有四个方面,如下所示:

  • RPN相关
    因2.x版本相较1.x版本进行了较大改进,尤其在RPN的组成上,核心代码一致,组合方式发生了变化。
    在这里插入图片描述

  • bbox_roi_extractor相关
    第二个变化为ROI相关的组合方式
    在这里插入图片描述

  • bbox_head相关
    在这里插入图片描述

  • test_cfg相关
    主要注意使用nms中的参数名称
    在这里插入图片描述

3、实现转化

实现转化的流程:
1、利用1.0的config进行模型初始化,得到对应的结构state_dict
2、获取2.5版本保存的state_dict
3、遍历1.0获取的state_dict,保存有差异的key值,存入两个版本的txt文件,如下图所示:
在这里插入图片描述

4、遍历1.0的state_dict,对于不同情况进行分别赋值
具体代码如下所示:

import torch
import mmcv
from mmdet.models import build_detector
from mmcv.runner import load_checkpoint
import warnings
from mmdet.core import get_classes
from mmcv import Config
from mmdet.apis import init_detector

def read_txt(file):
	with open(file, 'r') as f:
	content = f.read().splitlines()
	return content

if __name__ == '__main__':
	import numpy as np
	import os
	num_classes = 3 # 1.0版本中的类别数量
	txt_2 ='key_2.5.txt' ## 2.5版本的key值
	txt_1 ='key_1.0.txt' ## 1.0版本的key值
	key1 = read_txt(txt_1)
	key2 = read_txt(txt_2)
	config_path ='config_v1.0.py' # 1.0版本的config
	model_path ='tranfer.pth' # 2.5版本的state_dict
	cfg = Config.fromfile(config_path)
	model = build_detector(cfg.model, test_cfg=cfg.test_cfg)
	state_dict = torch.load(model_path)
	model_state = model.state_dict()
	temp_dict = {}
	for k, v in model_state.items():
		if k in state_dict.keys() and np.shape(v) == np.shape(state_dict[k]):
			temp_dict[k] = state_dict[k]
	for k1, k2 in zip(key1, key2):
		if np.shape(model_state[k1]) == np.shape(state_dict[k2]):
			if k1.endswith('fc_cls.weight') or k1.endswith('fc_cls.bias'):
				old_val = state_dict[k2]
				new_val = torch.cat([old_val[-1:], old_val[:-1]], dim=0)
				temp_dict[k1] = new_val
			else:
				temp_dict[k1] = state_dict[k2]
		else:
			old_val = state_dict[k2]
				if k1.endswith('fc_reg.weight'):
					out_channels, in_channels = old_val.shape[:2]
					tmp_val = old_val.reshape(num_classes - 1, -1, in_channels,*old_val.shape[2:])
					zero_mode = torch.zeros((1, tmp_val.shape[1],
					tmp_val.shape[2])).type_as(tmp_val) #初始化0进行填充
					new_val = torch.cat([zero_mode, tmp_val], dim=0)
					new_val = new_val.reshape(-1, *old_val.shape[1:])
				elif k1.endswith('fc_reg.bias'):
					tmp_val = old_val.reshape(num_classes - 1, -1)
					zero_mode = torch.zeros((1, tmp_val.shape[1])).type_as(tmp_val)
					#初始化0进行填充
					new_val = torch.cat([tmp_val, zero_mode], dim=0)
					new_val = new_val.reshape(-1)
				temp_dict[k1] = new_val
					
	model_state.update(temp_dict)
	model.load_state_dict(model_state, strict=True)
	model.eval()
	meta = {'CLASSES': ['1','2','3','4','5','6']}
	checkpoint = {
		'meta': meta,
		'state_dict': model.state_dict()
	}
	torch.save(checkpoint,'transfer_model.pth')

上述实现参考mmdetection2.5中模型升级的代码:官方升级代码示例,由于v2.5转换到v1.0存在类别数目变少的情况,本文以全0进行填充

以上就是是mmdetection2.5迁移到mmdetection1.0的示例

如有问题,敬请指正!

–END–

评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

AI小花猫

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值