Converting a MindSpore checkpoint to a PyTorch model file
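Input: a MindSpore ckpt file, using ResNet-18 as the example. The MindSpore network structure must match the PyTorch one, so that the converted file can be loaded into the PyTorch network directly. Only BatchNorm and Conv2d parameters appear here; if other layers are named differently in MindSpore and PyTorch, rename them in the same way.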
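The renaming rule described above can be sketched as a small table-driven helper. This is only an illustration: `MS_TO_PT_SUFFIX` and `ms_name_to_pt` are hypothetical names, and the table covers only the BatchNorm suffixes handled in this note, on the assumption that all other parameter names (such as Conv2d weights) are already identical in both frameworks.

```python
# Hypothetical helper: rename MindSpore BatchNorm parameter suffixes to the
# names PyTorch uses. Names without a listed suffix are returned unchanged.
MS_TO_PT_SUFFIX = {
    '.gamma': '.weight',
    '.beta': '.bias',
    '.moving_mean': '.running_mean',
    '.moving_variance': '.running_var',
}


def ms_name_to_pt(name):
    """Return the PyTorch-style name for a MindSpore parameter name."""
    for ms_suffix, pt_suffix in MS_TO_PT_SUFFIX.items():
        if name.endswith(ms_suffix):
            return name[:-len(ms_suffix)] + pt_suffix
    return name  # assumed to be named the same in both frameworks


# Illustrative names only; real names depend on how the network is defined.
for example in ['conv1.weight', 'bn1.gamma', 'layer1.0.bn1.moving_variance']:
    print(example, '->', ms_name_to_pt(example))
```

Keeping the mapping in one dictionary makes it easier to extend when another layer type turns out to be named differently.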
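Purpose:
- Makes it easy to compare against PyTorch training results.
- If you are unsure whether your training results are correct (the MindSpore forward pass may currently have issues), you can move the weights to PyTorch and run the forward pass there to check.
Conversion script: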
from mindspore.train.serialization import load_checkpoint
import torch


def mindspore2pytorch(ckpt_name='res18_ms.ckpt'):
    par_dict = load_checkpoint(ckpt_name)
    state_dict = {}
    # Optimizer state and training bookkeeping; not part of the model weights.
    state_dict_useless = ['global_step', 'learning_rate', 'beta1_power', 'beta2_power']

    for name in par_dict:
        parameter = par_dict[name].data

        # Skip optimizer entries (including the moment1/moment2 buffers).
        if name in state_dict_useless or name.startswith('moment1.') or name.startswith('moment2.'):
            continue

        print('========================ms_name:', name)
        # Rename BatchNorm parameters from MindSpore to PyTorch conventions.
        if name.endswith('.beta'):
            name = name[:name.rfind('.beta')] + '.bias'
        elif name.endswith('.gamma'):
            name = name[:name.rfind('.gamma')] + '.weight'
        elif name.endswith('.moving_mean'):
            name = name[:name.rfind('.moving_mean')] + '.running_mean'
        elif name.endswith('.moving_variance'):
            name = name[:name.rfind('.moving_variance')] + '.running_var'
        print('========================py_name:', name)

        state_dict[name] = torch.from_numpy(parameter.asnumpy())

    # The weights are stored under the 'state_dict' key of the saved file.
    torch.save({'state_dict': state_dict}, 'res18_py.pth')
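Before loading the converted file into the PyTorch network, it can help to compare the converted names with the names the network expects. The sketch below uses plain Python sets; `compare_keys` is a hypothetical helper, and the key lists are illustrative stand-ins for `state_dict.keys()` of the converted file and of the PyTorch model. PyTorch BatchNorm layers also carry a `num_batches_tracked` buffer that the converted file does not contain, so the sketch ignores it by default.

```python
def compare_keys(converted_keys, expected_keys,
                 ignore_suffixes=('.num_batches_tracked',)):
    """Return (missing, unexpected) parameter names, both sorted.

    missing:    expected by the PyTorch model but absent from the converted file
    unexpected: present in the converted file but unknown to the PyTorch model
    """
    converted = set(converted_keys)
    expected = {k for k in expected_keys if not k.endswith(ignore_suffixes)}
    return sorted(expected - converted), sorted(converted - expected)


# Illustrative data: 'bn1.gamma' was left unrenamed on purpose.
converted = ['conv1.weight', 'bn1.gamma', 'bn1.bias']
expected = ['conv1.weight', 'bn1.weight', 'bn1.bias', 'bn1.num_batches_tracked']

missing, unexpected = compare_keys(converted, expected)
print('missing:', missing)
print('unexpected:', unexpected)
```

An empty result in both lists suggests every renaming rule was applied; a name in `unexpected` usually points at a layer type whose suffix still needs a rule.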
Converting a PyTorch model file to a MindSpore checkpoint
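Input: a PyTorch pth file, using ResNet-18 as the example. The MindSpore network structure must match the PyTorch one, so that the converted file can be loaded into the MindSpore network directly. Only BatchNorm and Conv2d parameters appear here; if other layers are named differently in MindSpore and PyTorch, rename them in the same way.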
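In this direction a suffix alone is ambiguous, because `.weight` and `.bias` are used by convolution layers as well as BatchNorm layers in PyTorch. One possible way to tell them apart, sketched below, is to treat every layer that has a `running_mean` entry as a BatchNorm layer. The helper names are hypothetical, and the approach assumes the BatchNorm layers track running statistics; matching on a known layer name is an equally valid alternative.

```python
# Hypothetical helpers: rename only the parameters of layers detected as
# BatchNorm; everything else keeps its PyTorch name.
PT_TO_MS_BN = {
    'weight': 'gamma',
    'bias': 'beta',
    'running_mean': 'moving_mean',
    'running_var': 'moving_variance',
}


def bn_prefixes(pt_names):
    """Collect the layer prefixes that own a '.running_mean' entry."""
    suffix = '.running_mean'
    return {n[:-len(suffix)] for n in pt_names if n.endswith(suffix)}


def pt_name_to_ms(name, bn_layers):
    """Return the MindSpore-style name for a PyTorch parameter name."""
    prefix, _, leaf = name.rpartition('.')
    if prefix in bn_layers and leaf in PT_TO_MS_BN:
        return prefix + '.' + PT_TO_MS_BN[leaf]
    return name


# Illustrative names only.
names = ['conv1.weight', 'bn1.weight', 'bn1.bias',
         'bn1.running_mean', 'bn1.running_var', 'fc.weight', 'fc.bias']
layers = bn_prefixes(names)
for n in names:
    print(n, '->', pt_name_to_ms(n, layers))
```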
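Purpose:
Some networks must load pretrained weights before training; this converts PyTorch pretrained weights into a format that MindSpore supports.
Conversion script: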
from mindspore.train.serialization import save_checkpoint
from mindspore import Tensor
import torch


def pytorch2mindspore(ckpt_name='res18_py.pth'):
    par_dict = torch.load(ckpt_name, map_location=torch.device('cpu'))
    # Files written by mindspore2pytorch wrap the weights in {'state_dict': ...}.
    if 'state_dict' in par_dict:
        par_dict = par_dict['state_dict']
    new_params_list = []

    for name in par_dict:
        # PyTorch BatchNorm bookkeeping buffer; MindSpore has no counterpart.
        if name.endswith('.num_batches_tracked'):
            continue

        param_dict = {}
        parameter = par_dict[name]

        print('========================py_name', name)
        # weight/bias are renamed only for layers whose name ends in 'normalize'
        # (the BatchNorm layers here); adjust the pattern to match your network.
        if name.endswith('normalize.bias'):
            name = name[:name.rfind('normalize.bias')] + 'normalize.beta'
        elif name.endswith('normalize.weight'):
            name = name[:name.rfind('normalize.weight')] + 'normalize.gamma'
        elif name.endswith('.running_mean'):
            name = name[:name.rfind('.running_mean')] + '.moving_mean'
        elif name.endswith('.running_var'):
            name = name[:name.rfind('.running_var')] + '.moving_variance'
        print('========================ms_name', name)

        # save_checkpoint accepts a list of {'name': ..., 'data': ...} dicts.
        param_dict['name'] = name
        param_dict['data'] = Tensor(parameter.numpy())
        new_params_list.append(param_dict)

    save_checkpoint(new_params_list, 'res18_ms.ckpt')
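Since the script hands `save_checkpoint` a list of `{'name': ..., 'data': ...}` dictionaries, a quick sanity check on that list before saving can catch renaming mistakes, for example two PyTorch names collapsing into the same MindSpore name. The following is a minimal sketch in plain Python; `check_param_list` is a hypothetical helper and the sample data uses integers in place of tensors.

```python
from collections import Counter


def check_param_list(params):
    """Return a list of human-readable problems found in the parameter list."""
    problems = []
    for i, item in enumerate(params):
        if not isinstance(item, dict) or 'name' not in item or 'data' not in item:
            problems.append(f'entry {i}: expected a dict with "name" and "data"')
    names = [p['name'] for p in params if isinstance(p, dict) and 'name' in p]
    for name, count in Counter(names).items():
        if count > 1:
            problems.append(f'duplicate name: {name} ({count} times)')
    return problems


# Illustrative data: one duplicated name and one entry without 'data'.
sample = [
    {'name': 'bn1.gamma', 'data': 1},
    {'name': 'bn1.gamma', 'data': 2},
    {'name': 'conv1.weight'},
]
for problem in check_param_list(sample):
    print(problem)
```

An empty list means the structure looks consistent; it does not verify that shapes or values match the MindSpore network.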