文章目录
核心:模型序列字典中键名的匹配
1 读取mxnet模型
import mxnet as mx
import torch
from resnet50_gcn import resnet50
from torch.nn import BatchNorm2d
from torch.nn import Conv2d
from torch.nn import Linear
import pandas as pd
import numpy as np
def get_model(model_path, epoch):
    """Load the symbol graph and the model parameters of an MXNet checkpoint.

    Args:
        model_path: Checkpoint path/prefix forwarded to
            ``mx.model.load_checkpoint``.
        epoch: Epoch number forwarded to ``mx.model.load_checkpoint``.

    Returns:
        The tuple ``(sym, arg_params, aux_params)`` produced by
        ``mx.model.load_checkpoint``.
    """
    checkpoint = mx.model.load_checkpoint(model_path, epoch)
    # Unpack explicitly so a checkpoint that is not a 3-tuple fails here.
    symbol, arguments, auxiliaries = checkpoint
    return symbol, arguments, auxiliaries
# Read the MXNet model: load the checkpoint 'resnet50_w-glore_0-3_' at epoch 0
# and keep its symbol, argument parameters and auxiliary parameters.
sym, arg_params, aux_params = get_model('resnet50_w-glore_0-3_', 0)
2 读取pytorch模型
# Build the PyTorch network from resnet50_gcn with pretrained=False
# (presumably this skips loading any pretrained weights - confirm in resnet50_gcn).
pytorch_model = resnet50(pretrained=False)
# Fill the network with the renamed MXNet parameters.
# NOTE(review): init_model and new_arg_params are only defined in sections 3
# and 4 further down, so this line works only after those have been executed.
pytorch_model = init_model(pytorch_model, new_arg_params)
3 键名匹配
def dict_change_keys(input_dict, change_map, new_dict=None):
    """Copy entries of ``input_dict`` into ``new_dict`` under renamed keys.

    The renaming table is read from ``change_map``, a space-separated text
    file without a header row: column 0 holds the original key name and
    column 1 the name it should be stored under. Entries of ``input_dict``
    whose key does not appear in column 0 are skipped.

    Args:
        input_dict: Mapping from original key names to values.
        change_map: Path (or file-like object) of the two-column mapping
            file, passed to ``pandas.read_csv``.
        new_dict: Dictionary that receives the renamed entries. It is
            updated in place so several calls can accumulate into the same
            dictionary. A fresh dictionary is created when omitted.

    Returns:
        ``new_dict`` (the same object that was passed in, when given).
    """
    if new_dict is None:
        new_dict = {}

    mapping_table = pd.read_csv(change_map, header=None, sep=' ')

    # Build the old-name -> new-name lookup once instead of rescanning the
    # whole table for every key. setdefault keeps the first occurrence when
    # an old name is listed more than once.
    rename = {}
    for old_name, new_name in zip(mapping_table[0], mapping_table[1]):
        rename.setdefault(old_name, new_name)

    # Iterate over a snapshot so the call stays safe even if input_dict and
    # new_dict happen to be the same object.
    for old_name, value in list(input_dict.items()):
        if old_name in rename:
            new_dict[rename[old_name]] = value
    return new_dict
# Collect the renamed argument parameters and auxiliary parameters in one
# dictionary; both passes use the same 'mxnet.txt' mapping file.
new_arg_params = dict_change_keys(arg_params, 'mxnet.txt', {})
new_arg_params = dict_change_keys(aux_params, 'mxnet.txt', new_arg_params)
4 键值复制
def init_model(model, param_dict):
    """Initialise the layers of ``model`` from ``param_dict``.

    Walks ``model.named_modules()``, prints every module name, and hands
    each BatchNorm2d module to ``bn_init`` and each Conv2d module to
    ``conv_init``. All other module types (including Linear) are left
    untouched.

    Args:
        model: Module whose sub-modules are initialised in place.
        param_dict: Parameter dictionary forwarded to the per-layer helpers.

    Returns:
        The same ``model`` object.
    """
    for module_name, module in model.named_modules():
        print(module_name)
        if isinstance(module, BatchNorm2d):
            bn_init(module_name, module, param_dict)
            continue
        if isinstance(module, Conv2d):
            conv_init(module_name, module, param_dict)
    return model
def bn_init(n, m, param_dict):
    """Copy batch-norm parameters for module ``m`` out of ``param_dict``.

    The values are looked up under ``n + '_bn_gamma'``, ``n + '_bn_beta'``,
    ``n + '_bn_moving_mean'`` and ``n + '_bn_moving_var'``. Each value must
    expose ``.shape`` and ``.asnumpy()``. Nothing is copied when the module
    has no affine weight or when its channel count differs from the stored
    beta.

    Args:
        n: Module name used as the key prefix.
        m: Batch-norm module that is updated in place.
        param_dict: Mapping from key names to the stored arrays.

    Raises:
        KeyError: If a required key is missing from ``param_dict``.
    """
    print(n)
    # Test for a missing weight BEFORE reading m.weight.shape. The original
    # printed the shape first, so its None check could never take effect.
    if m.weight is None:
        return
    print(m.weight.shape[0])
    print(param_dict[n + '_bn_gamma'].shape[0])
    if m.weight.shape[0] != param_dict[n + '_bn_beta'].shape[0]:
        return

    m.weight.data.copy_(torch.FloatTensor(param_dict[n + '_bn_gamma'].asnumpy()))
    m.bias.data.copy_(torch.FloatTensor(param_dict[n + '_bn_beta'].asnumpy()))
    # The running buffers are None when the module does not track running
    # statistics, so only copy into buffers that exist.
    if m.running_mean is not None:
        m.running_mean.copy_(torch.FloatTensor(param_dict[n + '_bn_moving_mean'].asnumpy()))
    if m.running_var is not None:
        m.running_var.copy_(torch.FloatTensor(param_dict[n + '_bn_moving_var'].asnumpy()))
def conv_init(n, m, param_dict):
    """Copy the convolution weight stored under ``n + '_conv_weight'`` into ``m``.

    The copy only happens when the first dimension of the module weight
    equals the first dimension of the stored array; otherwise the module is
    left unchanged. Biases are not touched.

    Args:
        n: Module name used as the key prefix.
        m: Convolution module that is updated in place.
        param_dict: Mapping from key names to arrays exposing ``.shape`` and
            ``.asnumpy()``.
    """
    print(n)
    own_channels = m.weight.shape[0]
    stored = param_dict[n + '_conv_weight']
    if own_channels != stored.shape[0]:
        return
    converted = torch.FloatTensor(stored.asnumpy())
    m.weight.data.copy_(converted)
def fc_init(n, m, param_dict):
    """Copy weight and bias for module ``m`` out of ``param_dict``.

    The values are read from the keys ``n + '.weight'`` and ``n + '.bias'``
    and must expose ``.asnumpy()``.

    Args:
        n: Module name used as the key prefix.
        m: Module whose ``weight`` and ``bias`` are overwritten in place.
        param_dict: Mapping from key names to the stored arrays.
    """
    print(n)

    def _copy_into(target, key):
        # Convert the stored array to a float tensor and copy it in place.
        target.data.copy_(torch.FloatTensor(param_dict[key].asnumpy()))

    _copy_into(m.weight, n + '.weight')
    _copy_into(m.bias, n + '.bias')
5 测试且保存模型
# Build the PyTorch model, fill it with the converted parameters, run one
# forward pass on a random input as a smoke test, then save the weights.
data = torch.autograd.Variable(torch.randn(1, 3, 224, 224))
pytorch_model = resnet50(pretrained=False)
pytorch_model = init_model(pytorch_model, new_arg_params)
# Switch to evaluation mode first: a forward pass in training mode would
# update the BatchNorm running statistics and change what gets saved below.
pytorch_model.eval()
# Run the network itself. The original called init_model(data), which is
# the weight-copy helper and fails because it needs two arguments.
output = pytorch_model(data)
# save model
torch.save(pytorch_model.state_dict(), 'resnet50_gcn_init.pth')
6 问题
mxnet中bn的输入输出和pytorch模型中的不一致
- mxnet中conv维度是(out,in,kernel size,kernel size),bn与in维度一致
- pytorch中conv维度是(in,out,kernel size,kernel size),bn与out维度一致
导致bn参数无法读取,全部舍弃
使用model.named_modules()函数
输出n,m分别为对应键名和其对应所有层,因加入图卷积后模型名不便于自动处理,纯手动太蠢了!