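I'm migrating a network to MindSpore. A PyTorch version of the network already exists, along with pretrained PyTorch weights. Training the MindSpore network from scratch would take too long, so I'd like to load the PyTorch weights directly and fine-tune from there. Is there a way to do this?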
You need to convert the PyTorch checkpoint into a MindSpore-format checkpoint, then load that checkpoint in MindSpore and run the fine-tuning task.
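Loading matches checkpoint entries to network parameters by name. `mindspore.load_checkpoint` returns a dict keyed by parameter name, and `mindspore.load_param_into_net` only fills the network parameters whose names appear in that dict. Any name the conversion gets wrong is therefore not loaded, so it is worth diffing the two name sets before fine-tuning. Below is a minimal pure-Python sketch. `diff_param_names` is a hypothetical helper and the example names are made up. In practice the inputs would be the keys of `load_checkpoint(...)` and of `net.parameters_dict()`.

```python
def diff_param_names(ckpt_names, net_names):
    """Compare checkpoint parameter names with network parameter names.

    Returns (missing, unused), both sorted:
      missing: names the network expects but the checkpoint does not contain
      unused:  names in the checkpoint that match no network parameter
    """
    ckpt_set = set(ckpt_names)
    net_set = set(net_names)
    missing = sorted(net_set - ckpt_set)
    unused = sorted(ckpt_set - net_set)
    return missing, unused


# Hypothetical names, for illustration only.
ckpt_names = ["conv1.weight", "normalize.gamma", "normalize.beta",
              "normalize.moving_mean", "normalize.moving_variance",
              "normalize.num_batches_tracked"]
net_names = ["conv1.weight", "normalize.gamma", "normalize.beta",
             "normalize.moving_mean", "normalize.moving_variance",
             "fc.weight", "fc.bias"]
missing, unused = diff_param_names(ckpt_names, net_names)
print("missing:", missing)  # ['fc.bias', 'fc.weight']
print("unused:", unused)    # ['normalize.num_batches_tracked']
```

Names in `missing` keep their initial values during fine-tuning. For a new classification head this may be intended. Names in `unused` usually point to a renaming rule that is still missing.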
The checkpoint can be converted with code like the following:
```python
import torch
from mindspore import Tensor
from mindspore.train.serialization import save_checkpoint


def pytorch2mindspore(pth_file='res18_py.pth', ckpt_file='res18_ms.ckpt'):
    # Load on CPU so .numpy() works even if the weights were saved on GPU.
    # Assumes the .pth file is a dict holding the weights under 'state_dict'.
    par_dict = torch.load(pth_file, map_location='cpu')['state_dict']
    new_params_list = []
    for name in par_dict:
        param_dict = {}
        parameter = par_dict[name]
        print('========================py_name', name)
        # Rename BatchNorm parameters to their MindSpore equivalents.
        # 'normalize' is the BatchNorm layer name in this particular network.
        if name.endswith('normalize.bias'):
            name = name[:name.rfind('normalize.bias')]
            name = name + 'normalize.beta'
        elif name.endswith('normalize.weight'):
            name = name[:name.rfind('normalize.weight')]
            name = name + 'normalize.gamma'
        elif name.endswith('.running_mean'):
            name = name[:name.rfind('.running_mean')]
            name = name + '.moving_mean'
        elif name.endswith('.running_var'):
            name = name[:name.rfind('.running_var')]
            name = name + '.moving_variance'
        print('========================ms_name', name)
        param_dict['name'] = name
        param_dict['data'] = Tensor(parameter.numpy())
        new_params_list.append(param_dict)
    # save_checkpoint accepts a list of {'name': ..., 'data': ...} dicts.
    save_checkpoint(new_params_list, ckpt_file)


pytorch2mindspore()
```
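The if/elif chain in the conversion loop can also be written as an ordered rule table. That form is easier to extend when other layer types need renaming. Below is a minimal pure-Python sketch. `TORCH_TO_MS_SUFFIX`, `SKIP_SUFFIXES` and `torch_name_to_ms` are hypothetical names. Dropping `num_batches_tracked` is an added assumption: PyTorch's BatchNorm keeps this buffer, but MindSpore's BatchNorm has no matching parameter.

```python
# Ordered (torch_suffix, mindspore_suffix) rules; the first match wins.
# The "normalize." prefix follows the BatchNorm layer naming in the converter
# above (an assumption about that network): a bare ".weight" -> ".gamma" rule
# would also wrongly rename convolution weights.
TORCH_TO_MS_SUFFIX = [
    ("normalize.weight", "normalize.gamma"),
    ("normalize.bias", "normalize.beta"),
    (".running_mean", ".moving_mean"),
    (".running_var", ".moving_variance"),
]

# Assumed: PyTorch-only buffers with no MindSpore counterpart are dropped.
SKIP_SUFFIXES = (".num_batches_tracked",)


def torch_name_to_ms(name):
    """Map a PyTorch state_dict key to a MindSpore parameter name.

    Returns None for keys that should be left out of the checkpoint.
    """
    if name.endswith(SKIP_SUFFIXES):
        return None
    for torch_suffix, ms_suffix in TORCH_TO_MS_SUFFIX:
        if name.endswith(torch_suffix):
            return name[:-len(torch_suffix)] + ms_suffix
    return name  # e.g. convolution weights keep their names


for key in ["layer1.0.conv1.weight", "layer1.0.normalize.weight",
            "layer1.0.normalize.running_var",
            "layer1.0.normalize.num_batches_tracked"]:
    print(key, "->", torch_name_to_ms(key))
```

In the conversion loop, `name = torch_name_to_ms(name)` followed by `if name is None: continue` would replace the if/elif chain.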