torch权重转mindspore

0 前言

由于torch的简单易用和灵活性,很多研究工作都是基于torch实现的。
但在实际应用或者其他原因需要迁移到mindspore时, 我们都希望能直接复用torch已经预训练好的权重。
当然mindspore转torch 也是类似的。

1 方法

简单讲, 其实就是一个映射的过程。 权重保存的其实就是一个dict, 包含权重的名字和具体的数值, 只不过不同框架的组织形式会有所不同。
不同体现在2点: 一是名字的命名组织不同, 如torch中batchnorm的权重后缀一般是.weight.bias,而mindspore则是betagamma;二是存储格式不同。
针对第一点不同, 我们需要分别列出他们的权重,根据规则把他们的名字匹配上。
针对第二点不同, 我们要用torch和mindspore各自提供的加载或保存权重的接口,把权重处理成dict这个中间结果, 然后用各自的接口处理就可以了。

2 实现代码

import torch
from mindspore.train import save_checkpoint, load_checkpoint
from mindspore import Tensor


def convert_torch_ms(pth_file, ms_model):
    torch_para_dict = torch.load(pth_file, map_location=torch.device('cpu'))
    print("*" * 10, "torch name list:")
    for k, v in torch_para_dict.items():
        print(k)
        
    ms_name_list = []
    print("#" * 10, "ms name list:")
    for name in ms_model.parameters_dict():
    	print(name)
        ms_name_list.append(name)

    ms_params_list = []
    for ms_name in ms_name_list:
        param_dict = {}
        param_dict['name'] = ms_name
        torch_name = convert_ms_name_to_torch(ms_name)  # TODO
        data = torch_para_dict[torch_name].numpy()
        param_dict['data'] = Tensor(data)
        ms_params_list.append(param_dict)
    save_checkpoint(ms_params_list, "ms_weight.ckpt")
    print("save ms weight sucess!")

torch的参数名可以通过加载的权重直接得到, mindspore的参数名列表可以通过model.parameters_dict()获取。 然后通过一定的规则把他们匹配上,这就是convert_ms_name_to_torch要实现的内容。 最后在把得到的权重以mindspore的方式存储为文件。

3 mindspore可加载权重文件

mindspore加载权重可以通过load_checkpoint实现:

load_checkpoint(model, ckpt_file)

值得注意的是, torch和tf在默认情况下会完整匹配权重, 也就是说只要缺少一个权重, 则加载就会失败。但是mindspore不同,只有部分权重也是可以加载的,简单举例来说, 假设模型有100个权重, 给定的权重文件中只有20个权重, 那么依然可以加载成功, 只是剩余的80个没有加载权重而已。
这一点在转换权重时需要格外注意, 可能部分权重名称转的不对, 导致权重文件中缺少个别权重, 但程序不会报错。 一个技巧是可以打开mindspore info级别的日志(export GLOG_v=1), 日志中会显示那些权重未被加载。

  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值