mxnet模型转pytorch模型

转换基本流程:

1)创建pytorch的网络结构模型;

2)利用mxnet来读取其存储的预训练模型,用于读取模型的参数;

3)遍历mxnet加载的模型参数;

4)对一些指定的key值,需要进行相应的处理和转换;

5)对修改后的层名(key值),利用numpy之间的转换来实现加载;

6)对相应层进行参数(feature)进行比较;

流程基本是与caffe模型转pytorch模型这篇文章一致,唯一需要注意的一点就是:

mxnet中解析的参数有三个:

sym, arg_params, aux_params = mx.model.load_checkpoint(model_path, epoch)

arg_params是主要参数如weights;

aux_params是辅助参数主要是bias或者是batchnorm中的一些参数;

不是所有参数都在arg_params中,batchnorm层的权值和偏置就是保存在aux_params

以Resnet20为例:

1)(略)

2)加载mxnet模型

#加载符号图与模型参数
def get_model(model_path, epoch):
    sym, arg_params, aux_params = mx.model.load_checkpoint(model_path, epoch)
    return sym, arg_params,aux_params

3)4)5)

    def init_model(self, model,param_dict,aux_params):
        # print(model)
        for n, m in model.named_modules():
            print(n)
            if isinstance(m, BatchNorm2d):
                self.bn_init(n, m, param_dict,aux_params)
            elif isinstance(m, Conv2d):
                self.conv_init(n, m, param_dict)
            elif isinstance(m, Linear):
                self.fc_init(n, m, param_dict)
            elif isinstance(m, PReLU):
                self.prelu_init(n, m, param_dict)


        return model

    def bn_init(self, n, m, param_dict,aux_params):
        if not (m.weight is None):
            m.weight.data.copy_(torch.FloatTensor(param_dict[n+'_gamma'].asnumpy()))
            m.bias.data.copy_(torch.FloatTensor(param_dict[n+'_beta'].asnumpy()))
        m.running_mean.copy_(torch.FloatTensor(aux_params[n+'_moving_mean'].asnumpy()))
        m.running_var.copy_(torch.FloatTensor(aux_params[n+'_moving_var'].asnumpy()))

    def conv_init(self, n, m, param_dict):
        # print('n = ', n)
        m.weight.data.copy_(torch.FloatTensor(param_dict[n+'_weight'].asnumpy()))
        if n in ['conv1_1', 'conv4_1', 'conv3_1', 'conv2_1']:
            m.bias.data.copy_(torch.FloatTensor(param_dict[n + '_bias'].asnumpy()))

    def fc_init(self, n, m, param_dict):
        m.weight.data.copy_(torch.FloatTensor(param_dict[n+'_weight'].asnumpy()))
        m.bias.data.copy_(torch.FloatTensor(param_dict[n+'_bias'].asnumpy()))

    def prelu_init(self, n, m, net):
        m.weight.data.copy_(torch.FloatTensor(param_dict[n + '_gamma'].asnumpy()))

6)(略,我自己也没写,直接使用测试集测试了一下转换后的模型,小数点后四位没有精度偏差,就没有做了)

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

猫猫与橙子

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值