Code example: converting a PyTorch weight file into a MindSpore weight file


1. Obtain the parameter-name mapping table (an example of the generated mapping file is shown after this list)

2. Migrate the weight parameters according to the parameter-name mapping
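
Each line of the generated key-map file pairs a PyTorch parameter name with its MindSpore counterpart, separated by a single space. The lines below are a hypothetical illustration of that format, derived from the renaming rules in the code (the encoder. prefix is stripped, the top-level norm. becomes fc_norm., and LayerNorm weight/bias become gamma/beta); the actual entries depend on your checkpoint:

encoder.blocks.0.norm1.weight blocks.0.norm1.gamma
encoder.blocks.0.norm1.bias blocks.0.norm1.beta
encoder.norm.weight fc_norm.gamma
encoder.norm.bias fc_norm.beta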

The full code example is as follows:

# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Convert checkpoint from torch/facebook"""
import argparse
import mindspore as ms
import torch

def get_keymap_txt(pth_file):
    """Generate a torch-to-mindspore parameter-name mapping file and return its path."""
    map_path = pth_file.rsplit('.', 1)[0] + '_key_map.txt'
    state_dict = torch.load(pth_file, map_location=torch.device('cpu'))
    if 'model_state' in state_dict:
        state_dict = state_dict['model_state']
    elif 'module' in state_dict:
        state_dict = state_dict['module']
    elif 'model' in state_dict:
        state_dict = state_dict['model']

    with open(map_path, 'w') as map_file:
        for k in state_dict:
            new_k = k
            # only encoder parameters are mapped; other keys are skipped
            if new_k.startswith('encoder.'):
                new_k = new_k.split('encoder.')[1]
            else:
                continue
            # the top-level norm layer is called fc_norm on the MindSpore side
            if new_k.startswith('norm.'):
                new_k = new_k.replace('norm.', 'fc_norm.')
            # MindSpore LayerNorm parameters are named gamma/beta instead of weight/bias
            if 'norm' in new_k:
                if 'weight' in new_k:
                    new_k = new_k.replace('weight', 'gamma')
                if 'bias' in new_k:
                    new_k = new_k.replace('bias', 'beta')
            map_file.write(k + ' ' + new_k + '\n')
    return map_path

def convert_weight(pth_file="mae_pretrain_vit_base.pth", key_map_path="key_map.txt"):
    """
    Convert mae_vit_base_p16 weights from PyTorch to MindSpore using the key map file.
    PyTorch is required; the checkpoint is loaded on CPU, so no GPU is needed.
    """
    ms_ckpt_path = pth_file.rsplit('.', 1)[0] + '.ckpt'
    key_map_dict = {}
    with open(key_map_path, 'r') as f:
        key_map_lines = [s.strip() for s in f.readlines()]
        for line in key_map_lines:
            ckpt_key, model_key = line.split(' ')
            key_map_dict[ckpt_key] = model_key

    ms_ckpt = []
    ckpt = torch.load(pth_file, map_location=torch.device('cpu'))
    if 'model_state' in ckpt:
        state_dict = ckpt['model_state']
    elif 'module' in ckpt:
        state_dict = ckpt['module']
    elif 'model' in ckpt:
        state_dict = ckpt['model']
    else:
        state_dict = ckpt

    for k, v in state_dict.items():
        if k == 'decoder_pos_embed':
            # keep only the patch position embeddings (drop the leading token)
            v = v[:, 1:, :]
        if k in key_map_dict:
            k_map = key_map_dict[k]
        else:
            # parameters without a mapping entry are skipped
            continue

        ms_ckpt.append({'name': k_map, 'data': ms.Tensor(v.numpy())})
        print(k_map)

    ms.save_checkpoint(ms_ckpt, ms_ckpt_path)

def change_keys(pth_file="mae_pretrain_vit_base.pth", key_map_path="key_map.txt"):
    """
    convert mae_vit_base_p16 weights from pytorch to mindspore
    pytorch and GPU required.
    """
    key_map_dict = {}
    with open(key_map_path, 'r') as f:
        key_map_lines = [s.strip() for s in f.readlines()]
        for line in key_map_lines:
            ckpt_key, model_key = line.split(' ')
            key_map_dict[ckpt_key] = model_key

    new_state_dict = {}
    ckpt = torch.load(pth_file, map_location=torch.device('cpu'))
    if 'model_state' in ckpt:
        state_dict = ckpt['model_state']
    elif 'module' in ckpt:
        state_dict = ckpt['module']
    elif 'model' in ckpt:
        state_dict = ckpt['model']
    else:
        state_dict = ckpt

    for k, v in state_dict.items():
        if k == 'decoder_pos_embed':
            # keep only the patch position embeddings (drop the leading token)
            v = v[:, 1:, :]
        # keep the original name when no mapping entry exists
        k_map = key_map_dict.get(k, k)
        new_state_dict[k_map] = v
        print(k_map)
    
    ckpt['model_state'] = new_state_dict

    save_file = pth_file.replace('.pth', '_key_changed.pth')
    torch.save(ckpt, save_file)

def show_info(pth_file="mae_pretrain_vit_base.pth"):
    """Dump the name, shape and dtype of every parameter in a .pth or .ckpt file."""
    info_path = pth_file.rsplit('.', 1)[0] + '_info.txt'
    if pth_file.endswith('.ckpt'):
        state_dict = ms.load_checkpoint(pth_file)
    else:
        state_dict = torch.load(pth_file, map_location=torch.device('cpu'))
    if 'model_state' in state_dict:
        state_dict = state_dict['model_state']
    elif 'module' in state_dict:
        state_dict = state_dict['module']
    elif 'model' in state_dict:
        state_dict = state_dict['model']

    with open(info_path, 'w') as info_file:
        for k, v in state_dict.items():
            info = k + ' ' + str(v.shape) + ' ' + str(v.dtype)
            info_file.write(info + '\n')
            print(info)

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="mae vit weight convert script")
    parser.add_argument("--torch_path",
                        type=str,
                        default="/home/zgw-194/zsz/mae_st_ms/mae_result/ViT-B_VideoMAE_ep2400_SSv2.pth",
                        help="The torch checkpoint path.")
    parser.add_argument("--key_map_path",
                        type=str,
                        default="/home/zgw-194/zsz/mae_st_ms/mae_result/ViT-B_VideoMAE_ucf101_wlm_key_map.txt")
    opt = parser.parse_args()

    # convert_weight(opt.torch_path, opt.key_map_path)
    # change_keys(opt.torch_path, opt.key_map_path)
    show_info(opt.torch_path)
    keymap_txt = get_keymap_txt(opt.torch_path)
    convert_weight(opt.torch_path, keymap_txt)
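

After the conversion finishes, the resulting .ckpt file can be loaded back with MindSpore to check it. The snippet below is a minimal sketch of such a check; MaeVitBaseP16 is a hypothetical network class used only for illustration, and the checkpoint path assumes the default naming produced by convert_weight (the .pth name with a .ckpt extension):

import mindspore as ms

# load the converted checkpoint and list what it contains
param_dict = ms.load_checkpoint("ViT-B_VideoMAE_ep2400_SSv2.ckpt")
print(len(param_dict), "parameters in the converted checkpoint")

# to load the parameters into a model, instantiate your MindSpore network
# (MaeVitBaseP16 below is a placeholder name) and call load_param_into_net:
# net = MaeVitBaseP16()
# ms.load_param_into_net(net, param_dict)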


Weight files and code that can be used for testing:

Link: https://pan.baidu.com/s/19nQzKo-ussycKzJ8HOfeug?pwd=0000  Extraction code: 0000
