将模型参数从Pytorch迁移到Mindspore

最近在做一些pytorch代码迁移到mindspore的项目,由于训练数据非常大不好重新训练,所以想把作者提供的预训练模型参数迁移一下看看效果,这篇文章相当于做个笔记。

这里以deeplab的Resnet101为例:

需要的东西

1. pytorch版本的权重文件(.pth文件)

2. mindspore的模型代码

        mindspore的模型代码需要根据pytorch版代码来自己迁移,有些比较常用的预训练模型代码可以在mindspore的modelzoo中找到。

首先导入需要的包
import mindspore
import torch
import pandas as pd
import csv
import numpy as np
from mindspore import load_checkpoint,load_param_into_net,save_checkpoint
from graphs.models import deeplab_multi # 导入你的mindspore版本网络代码
读取双方的参数并把键值(key)保存为csv文件
# mindspore模型
ms_model = deeplab_multi.DeeplabMulti(num_classes=21,pretrained=False)
keys = pd.DataFrame(ms_model.parameters_dict().keys())
keys.to_csv('mindspore_keys.csv')

# Pytorch参数
torch_model = torch.load('GTA5_source.pth') # 替换成自己的.pth文件路径
model_key = pd.DataFrame(model.keys())
model_key.to_csv('pytorch_res101.csv')
比较一下两边的区别

双方一对一的key可以很容易辨别,可以查看这篇博客实用干货:如何把Pytorch模型参数加载到MindSpore模型?-CSDN博客

此外,在pytorch中多了一个 num_batches_tracked 参数,它记录了自从模型训练开始以来处理的批次数量。这个参数不是用于直接参与前向或反向传播的计算,而是作为一个状态量,通常用于指导 BatchNorm 层的行为。

在迁移 PyTorch 模型到 MindSpore 时,可以忽略该参数。

示例模型只采用了conv2d层和BN层,如果有采用其他层需要再甄别一下

制作双方key值一一对应的字典
# 把双方的key值读进来
torch_keys = pd.read_csv('pytorch_res101.csv')
mindspore_keys = pd.read_csv('mindspore_keys.csv')

# 把'num_batches_tracked'参数踢掉
torch_list = []
for index,value in torch_keys.iterrows():
    name = value[-1]
    if not 'num_batches_tracked' in name:
        torch_list.append(name)
# 制作字典
ms_list = []
for index,value in mindspore_keys.iterrows():
    name = value[-1]
    ms_list.append(name)
key_dict = {}
if len(torch_list) == len(ms_list):
    for p in range(len(torch_list)):
        key_dict[ms_list[p]] = torch_list[p]
 然后把Pytorch参数写入mindspore模型
for k,v in ms_model.parameters_dict().items():
    torch_k = key_dict[k]
    torch_v = torch_model[torch_k]
    if not isinstance(torch_v,np.ndarray):
        torch_v = torch_v.cpu().numpy()
    ms_v = mindspore.Parameter(torch_v,name=k)
    ms_model.parameters_dict()[k] = ms_v
最后把模型保存好就完成了
save_checkpoint(ms_model,'deeplab_multi_ms.ckpt')
另外还有一点:

在pytorch中直接查看某层的参数字典,running_mean 和 running_var 两个参数是不作显示的:

# 单独创建一个bn层
import torch.nn as nn
bn1 = nn.BatchNorm2d(64, momentum=0.1)
params = bn1.named_parameters()
for i in params:
    print(i[0])

>>>
weight
bias

但可以通过加后缀的方式来查看:

bn1.running_mean

>>>
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

  • 11
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值