【MindSpore易点通】如何实现MindSpore与PyTorch模型文件相互转换

MindSpore保存的模型文件转换为PyTorch中的模型文件

输入 MindSpore的ckpt文件,以ResNet-18为例,MindSpore的网络结构和PyTorch保持一致,转完之后可直接加载进网络,这边参数只用到bn和conv2d,若有其他层MindSpore和PyTorch名称不一致,需要同样的修改名称。

作用:

  1. 方便和PyTorch训练结果对比
  2. 如果不确定自己训练的结果是不是正确(MindSpore的前向目前可能存在问题),可转到PyTorch上跑前向看看

转换脚本:

from mindspore.train.serialization import load_checkpoint
import torch

def mindspore2pytorch(ckpt_name='res18_ms.ckpt'):
    par_dict = load_checkpoint(ckpt_name)
    state_dict = {}
    state_dict_useless = ['global_step', 'learning_rate', 'beta1_power', 'beta2_power']

    for name in par_dict:
        parameter = par_dict[name].data

        if name in state_dict_useless or name.startswith('moment1.') or name.startswith('moment2.'):
            pass

        else:
            print('========================ms_name:', name )

            if name.endswith('.beta'):
                name = name[:name.rfind('.beta')]
                name = name + '.bias'

            elif name.endswith('.gamma'):
                name = name[:name.rfind('.gamma')]
                name = name + '.weight'

            elif name.endswith('.moving_mean'):
                name = name[:name.rfind('.moving_mean')]
                name = name + '.running_mean'

            elif name.endswith('.moving_variance'):
                name = name[:name.rfind('.moving_variance')]
                name = name + '.running_var'

            print('========================py_name:', name )
            state_dict[name] = torch.from_numpy(parameter.asnumpy())  ###

torch.save({'state_dict' : state_dict}, 'res18_py.pth')

PyTorch保存的模型文件转换为MindSpore中的模型文件

输入 PyTorch的pth文件,以ResNet-18为例,MindSpore的网络结构和PyTorch保持一致,转完之后可直接加载进网络,这边参数只用到BN和Conv2D,若有其他层MindSpore和PyTorch名称不一致,需要同样的修改名称。

作用

有的网络训练的时候必须加载一些预训练权重,将PyTorch的预训练权重转到MindSpore支持的格式。

转换脚本

from mindspore.train.serialization import save_checkpoint
from mindspore import Tensor
import torch

def pytorch2mindspore(ckpt_name='res18_py.pth'):

    par_dict = torch.load(ckpt_name, map_location=torch.device('cpu'))

    new_params_list = []

    for name in par_dict:
        param_dict = {}
        parameter = par_dict[name]

        print('========================py_name',name)

        if name.endswith('normalize.bias'):
            name = name[:name.rfind('normalize.bias')]
            name = name + 'normalize.beta'

        elif name.endswith('normalize.weight'):
            name = name[:name.rfind('normalize.weight')]
            name = name + 'normalize.gamma'

        elif name.endswith('.running_mean'):
            name = name[:name.rfind('.running_mean')]
            name = name + '.moving_mean'

        elif name.endswith('.running_var'):
            name = name[:name.rfind('.running_var')]
            name = name + '.moving_variance'

        print('========================ms_name',name)

        param_dict['name'] = name
        param_dict['data'] = Tensor(parameter.numpy())
        new_params_list.append(param_dict)


    save_checkpoint(new_params_list,  'res18_ms.ckpt')

评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值