0 前言
由于torch的简单易用和灵活性,很多研究工作都是基于torch实现的。
但在实际应用或者其他原因需要迁移到mindspore时, 我们都希望能直接复用torch已经预训练好的权重。
当然mindspore转torch 也是类似的。
1 方法
简单讲, 其实就是一个映射的过程。 权重保存的其实就是一个dict, 包含权重的名字和具体的数值, 只不过不同框架的组织形式会有所不同。
不同体现在2点: 一是名字的命名组织不同, 如torch中batchnorm
的权重后缀一般是.weight
和.bias
,而mindspore则是beta
和gamma
;二是存储格式不同。
针对第一点不同, 我们需要分别列出他们的权重,根据规则把他们的名字匹配上。
针对第二点不同, 我们要用torch和mindspore各自提供的加载或保存权重的接口,把权重处理成dict这个中间结果, 然后用各自的接口处理就可以了。
2 实现代码
import torch
from mindspore.train import save_checkpoint, load_checkpoint
from mindspore import Tensor
def convert_torch_ms(pth_file, ms_model):
torch_para_dict = torch.load(pth_file, map_location=torch.device('cpu'))
print("*" * 10, "torch name list:")
for k, v in torch_para_dict.items():
print(k)
ms_name_list = []
print("#" * 10, "ms name list:")
for name in ms_model.parameters_dict():
print(name)
ms_name_list.append(name)
ms_params_list = []
for ms_name in ms_name_list:
param_dict = {}
param_dict['name'] = ms_name
torch_name = convert_ms_name_to_torch(ms_name) # TODO
data = torch_para_dict[torch_name].numpy()
param_dict['data'] = Tensor(data)
ms_params_list.append(param_dict)
save_checkpoint(ms_params_list, "ms_weight.ckpt")
print("save ms weight sucess!")
torch的参数名可以通过加载的权重直接得到, mindspore的参数名列表可以通过model.parameters_dict()
获取。 然后通过一定的规则把他们匹配上,这就是convert_ms_name_to_torch
要实现的内容。 最后在把得到的权重以mindspore的方式存储为文件。
3 mindspore可加载权重文件
mindspore加载权重可以通过load_checkpoint
实现:
load_checkpoint(model, ckpt_file)
值得注意的是, torch和tf在默认情况下会完整匹配权重, 也就是说只要缺少一个权重, 则加载就会失败。但是mindspore不同,只有部分权重也是可以加载的,简单举例来说, 假设模型有100个权重, 给定的权重文件中只有20个权重, 那么依然可以加载成功, 只是剩余的80个没有加载权重而已。
这一点在转换权重时需要格外注意, 可能部分权重名称转的不对, 导致权重文件中缺少个别权重, 但程序不会报错。 一个技巧是可以打开mindspore info级别的日志(export GLOG_v=1), 日志中会显示那些权重未被加载。