【MindSpore】Converting a PyTorch checkpoint file to MindSpore

The code below converts a PyTorch checkpoint file into a MindSpore checkpoint file.

Note: only simple networks made up exclusively of Convolution and BatchNorm operators are supported, for example HRNet for image semantic segmentation. The operator names must match those of the PyTorch version of the network being migrated.
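
For reference, the script below renames the BatchNorm parameters roughly as follows (a sketch of the mapping applied in `from_torch_to_mindspore`; `BN_NAME_MAP` is an illustrative name and not part of the original script, and Convolution weights keep their PyTorch names):

# Sketch of the PyTorch -> MindSpore BatchNorm parameter-name mapping used below.
BN_NAME_MAP = {
    "weight": "gamma",                # BatchNorm scale
    "bias": "beta",                   # BatchNorm shift
    "running_mean": "moving_mean",
    "running_var": "moving_variance",
    # "num_batches_tracked" has no MindSpore counterpart and is dropped.
}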

import os

import torch

from mindspore import Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net, save_checkpoint

from src.seg_hrnet import get_seg_model
from src.config.config import config as cfg


def show_params(ckpt_file_path, frame="torch", key=True, value=False):
    """ Show the contents of a checkpoint file.
        `frame` selects the framework that wrote the file: "torch" or "mindspore".
    """
    if frame == "torch":
        params = torch.load(ckpt_file_path, map_location=torch.device('cpu'))
    elif frame == "mindspore":
        params = load_checkpoint(ckpt_file_path)
    else:
        raise ValueError("Attribute `frame` must be one of [`torch`, `mindspore`]! ")
    if key and value:
        for k, v in params.items():
            print(k, v)
    elif key and not value:
        for k in params.keys():
            print(k)
    elif value and not key:
        for v in params.values():
            print(v)


def compare_model_names(torch_params_dict, mindspore_params_dict):
    """ Compare the params' names between the torch and mindspore nets.
        Prints the names side by side; relies on both dicts preserving insertion order.
    """
    t_params_dict = torch_params_dict.copy()
    m_params_dict = mindspore_params_dict.copy()
    # `num_batches_tracked` has no MindSpore counterpart, so drop it before comparing.
    for key in torch_params_dict.keys():
        if "num_batches_tracked" in key:
            t_params_dict.pop(key)

    for t, m in zip(t_params_dict.keys(), m_params_dict.keys()):
        print(t)
        print(m)
        print("=============================")


def from_torch_to_mindspore(net, ckpt_file_path, save_path):
    """ Transform a torch checkpoint file into a mindspore checkpoint.
        Rename the params first, then convert the tensors.
        `net` is kept for interface consistency but is not used here.
    """
    if not os.path.isfile(ckpt_file_path):
        raise FileNotFoundError("The file `{}` does not exist! ".format(ckpt_file_path))
    if not save_path.endswith(".ckpt"):
        raise ValueError("Attribute `save_path` should be a checkpoint file ending with `.ckpt`!")

    params = torch.load(ckpt_file_path, map_location=torch.device('cpu'))

    torch_params = list(params.items())
    num_params = len(torch_params)
    params_list = []
    for i in range(num_params):
        key, value = torch_params[i]
        # A PyTorch BatchNorm layer stores, in order: weight, bias, running_mean,
        # running_var, num_batches_tracked. So a `weight` followed two entries later
        # by `running_mean` is a BN scale (gamma), and a `bias` followed one entry
        # later by `running_mean` is a BN shift (beta).
        if "weight" in key and i + 2 < num_params:
            if "running_mean" in torch_params[i + 2][0]:
                key = key.replace("weight", "gamma")
        if "bias" in key and i + 1 < num_params:
            if "running_mean" in torch_params[i + 1][0]:
                key = key.replace("bias", "beta")
        if "running_var" in key:
            key = key.replace("running_var", "moving_variance")
        if "running_mean" in key:
            key = key.replace("running_mean", "moving_mean")
        if "num_batches_tracked" in key:
            # No MindSpore counterpart, drop it.
            continue
        if "incre" in key:
            # `incre*` params only exist in the classification variant of HRNet; stop here.
            break
        params_list.append({"name": key, "data": Tensor(value.numpy())})
    save_checkpoint(params_list, save_path)


if __name__ == "__main__":
    net = get_seg_model(cfg, init="TruncatedNormal", pretrained_ckpt="nothing")
    m_params = net.parameters_dict()
    t_params = torch.load("train_out/hrnet_cs_8090_torch11.pth", map_location=torch.device('cpu'))
    show_params("train_out/hrnet_cs_8090_torch11.pth")

    # Remaining steps: compare names, run the conversion, inspect the result,
    # and finally load the converted params into the MindSpore net.
    # compare_model_names(t_params, m_params)
    # from_torch_to_mindspore(net, "train_out/hrnet_cs_8090_torch11.pth", "train_out/torch_hrnet_result.ckpt")
    # a = load_checkpoint("train_out/torch_hrnet_result.ckpt")
    # for m, n in zip(m_params.keys(), a.keys()):
    #     print(m)
    #     print(n)
    #     print("==================")

    # load_param_into_net(net, a)
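
To sanity-check the result, one can reload the converted file and compare parameter names and shapes against the MindSpore net. A minimal sketch (assuming the conversion above has produced `train_out/torch_hrnet_result.ckpt`; the variable names are illustrative):

from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.seg_hrnet import get_seg_model
from src.config.config import config as cfg

net = get_seg_model(cfg, init="TruncatedNormal", pretrained_ckpt="nothing")
converted = load_checkpoint("train_out/torch_hrnet_result.ckpt")
net_params = net.parameters_dict()
for name, param in converted.items():
    if name not in net_params:
        print("not found in the MindSpore net:", name)
    elif param.shape != net_params[name].shape:
        print("shape mismatch:", name, param.shape, net_params[name].shape)
# If nothing is printed, the converted checkpoint should load cleanly:
load_param_into_net(net, converted)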
