mxnet模型转pytorch

核心:模型序列字典中键名的匹配

1 读取mxnet模型

import mxnet as mx
import torch
from resnet50_gcn import resnet50
from torch.nn import BatchNorm2d
from torch.nn import Conv2d
from torch.nn import Linear
import pandas as pd
import numpy as np

def get_model(model_path, epoch):
    """Load the symbol graph and the model parameters of an MXNet checkpoint.

    Args:
        model_path: Checkpoint prefix handed to ``mx.model.load_checkpoint``.
        epoch: Epoch number handed to ``mx.model.load_checkpoint``.

    Returns:
        The ``(sym, arg_params, aux_params)`` triple produced by the loader.
    """
    checkpoint = mx.model.load_checkpoint(model_path, epoch)
    symbol, args, auxs = checkpoint
    return symbol, args, auxs

# Read the MXNet model: checkpoint prefix 'resnet50_w-glore_0-3_', epoch 0.
sym, arg_params, aux_params = get_model('resnet50_w-glore_0-3_', 0)

2 读取pytorch模型

# Build the PyTorch network (pretrained=False; presumably no pretrained
# weights are loaded -- confirm against resnet50_gcn).
pytorch_model = resnet50(pretrained=False)
# NOTE(review): init_model and new_arg_params are only defined further down
# in this file, so this call fails if the snippets are executed top to
# bottom; the same two statements are repeated before the model is saved.
pytorch_model = init_model(pytorch_model, new_arg_params)

3 键名匹配

def dict_change_keys(input_dict, change_map, new_dict):
    """Copy entries of ``input_dict`` into ``new_dict`` under renamed keys.

    ``change_map`` is read with pandas as a headerless, space-separated table
    whose column 0 holds the old key and column 1 the new key. Keys of
    ``input_dict`` that do not occur in column 0 are dropped. When an old key
    occurs on several rows, the first row wins.

    Args:
        input_dict: Mapping from old key to value.
        change_map: Path (or anything ``pandas.read_csv`` accepts) of the
            rename table.
        new_dict: Dictionary that receives the renamed entries. It is updated
            in place so that several calls can accumulate into one dictionary.

    Returns:
        The same ``new_dict`` object that was passed in.
    """
    rename_table = pd.read_csv(change_map, header=None, sep=' ')

    # Build the old -> new lookup once instead of re-scanning the whole table
    # for every key; setdefault keeps the first row of a duplicated old key.
    renames = {}
    for old_name, new_name in zip(rename_table[0], rename_table[1]):
        renames.setdefault(old_name, new_name)

    # Snapshot the keys first so the loop stays valid even if new_dict and
    # input_dict happen to be the same object.
    for old_name in list(input_dict):
        if old_name in renames:
            new_dict[renames[old_name]] = input_dict[old_name]
    return new_dict

# Collect the renamed entries of arg_params and then of aux_params into one
# dictionary; both passes use the rename table in 'mxnet.txt'.
new_arg_params = {}
new_arg_params = dict_change_keys(arg_params, 'mxnet.txt', new_arg_params)
new_arg_params = dict_change_keys(aux_params, 'mxnet.txt', new_arg_params)

4 键值复制

def init_model(model, param_dict):
    """Fill the batch-norm and convolution layers of ``model`` from ``param_dict``.

    Every sub-module reported by ``model.named_modules()`` has its name
    printed; ``BatchNorm2d`` modules are passed to ``bn_init`` and ``Conv2d``
    modules to ``conv_init``. All other modules (including ``Linear``) are
    left as they are.

    Args:
        model: Module whose sub-modules are initialised in place.
        param_dict: Parameter dictionary forwarded to the per-layer helpers.

    Returns:
        The same ``model`` object.
    """
    for name, module in model.named_modules():
        print(name)
        if isinstance(module, BatchNorm2d):
            initializer = bn_init
        elif isinstance(module, Conv2d):
            initializer = conv_init
        else:
            # Linear layers are not initialised here (fc_init is not wired in).
            continue
        initializer(name, module, param_dict)
    return model


def bn_init(n, m, param_dict):
    """Copy converted batch-norm parameters into a ``BatchNorm2d`` module.

    The values are looked up under ``n + '_bn_gamma'``, ``n + '_bn_beta'``,
    ``n + '_bn_moving_mean'`` and ``n + '_bn_moving_var'`` and copied into the
    module's weight, bias, running mean and running variance respectively.
    Each value must offer ``.shape`` and ``.asnumpy()``.

    Nothing is copied when the module has no affine weight, or when its
    channel count differs from the stored gamma.

    Args:
        n: Module name used as the key prefix.
        m: Batch-norm module that is updated in place.
        param_dict: Mapping from parameter name to array.

    Raises:
        KeyError: If the module has a weight but a required entry is missing.
    """
    print(n)
    # Guard before touching m.weight.shape: modules built with affine=False
    # have weight set to None and have nothing to receive gamma/beta.
    if m.weight is None:
        return

    channels = m.weight.shape[0]
    print(channels)
    gamma = param_dict[n + '_bn_gamma']
    print(gamma.shape[0])

    # Compare against gamma, the array that actually lands in m.weight.
    if channels != gamma.shape[0]:
        print('skip ' + n + ': channel count does not match stored gamma')
        return

    m.weight.data.copy_(torch.FloatTensor(gamma.asnumpy()))
    m.bias.data.copy_(torch.FloatTensor(param_dict[n + '_bn_beta'].asnumpy()))
    m.running_mean.copy_(torch.FloatTensor(param_dict[n + '_bn_moving_mean'].asnumpy()))
    m.running_var.copy_(torch.FloatTensor(param_dict[n + '_bn_moving_var'].asnumpy()))


def conv_init(n, m, param_dict):
    """Copy a converted convolution weight into a ``Conv2d`` module.

    The weight is looked up under ``n + '_conv_weight'`` and copied only when
    its first dimension equals the first dimension of ``m.weight``; otherwise
    the module is left unchanged. The bias is never touched.

    Args:
        n: Module name used as the key prefix.
        m: Convolution module that is updated in place.
        param_dict: Mapping from parameter name to an array offering
            ``.shape`` and ``.asnumpy()``.
    """
    print(n)
    out_channels = m.weight.shape[0]
    stored = param_dict[n + '_conv_weight']
    if out_channels != stored.shape[0]:
        return
    converted = torch.FloatTensor(stored.asnumpy())
    m.weight.data.copy_(converted)


def fc_init(n, m, param_dict):
    """Copy a converted weight and bias into a fully connected module.

    The arrays are looked up under ``n + '.weight'`` and ``n + '.bias'`` and
    copied, in that order, into ``m.weight`` and ``m.bias``.

    Args:
        n: Module name used as the key prefix.
        m: Module that is updated in place.
        param_dict: Mapping from parameter name to an array offering
            ``.asnumpy()``.
    """
    print(n)
    targets = ((m.weight, n + '.weight'), (m.bias, n + '.bias'))
    for parameter, key in targets:
        parameter.data.copy_(torch.FloatTensor(param_dict[key].asnumpy()))

5 测试且保存模型

# Build the PyTorch model and fill it with the converted MXNet parameters.
pytorch_model = resnet50(pretrained=False)
pytorch_model = init_model(pytorch_model, new_arg_params)

# Smoke-test the converted network on one random 224x224 RGB input.
data = torch.randn(1, 3, 224, 224)

# Run the check in eval mode and without autograd: in training mode the
# forward pass would update the batch-norm running statistics and so alter
# the values that were just copied in, before they are saved.
pytorch_model.eval()
with torch.no_grad():
    # The forward pass goes through the model itself; init_model takes
    # (model, param_dict) and cannot be called with the input tensor.
    output = pytorch_model(data)

# save model
torch.save(pytorch_model.state_dict(), 'resnet50_gcn_init.pth')

6 问题

mxnet中bn的输入输出和pytorch模型中的不一致
  • mxnet中conv维度是(out,in,kernel size, kernel size),bn与in维度一致
  • pytorch中conv维度是(in,out,kernel size,kernel size),bn与out维度一致
    导致bn参数无法读取,全部舍弃
使用model.named_modules()函数

输出n,m分别为对应键名和其对应所有层,因加入图卷积后模型名不便于自动处理,纯手动太蠢了!

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值