Converting a Mindformer Baichuan2 checkpoint to a Huggingface model

The complete conversion script:
import argparse
import json
import os

import mindspore as ms
import torch
from transformers import AutoConfig, AutoModelForCausalLM

def read_json(path):
    with open(path, "r") as f:
        return json.load(f)
def name_de_replace(name: str):
    """Map a MindSpore parameter name to its HuggingFace equivalent."""
    name = name.replace('tok_embeddings.embedding_weight', 'embed_tokens.weight')
    name = name.replace('.attention.wq.', '.self_attn.q_proj.')
    name = name.replace('.attention.wk.', '.self_attn.k_proj.')
    name = name.replace('.attention.wv.', '.self_attn.v_proj.')
    name = name.replace('.attention.wo.', '.self_attn.o_proj.')
    name = name.replace('.feed_forward.w1.', '.mlp.gate_proj.')
    name = name.replace('.feed_forward.w2.', '.mlp.down_proj.')
    name = name.replace('.feed_forward.w3.', '.mlp.up_proj.')
    name = name.replace('.attention_norm.', '.input_layernorm.')
    name = name.replace('.ffn_norm.', '.post_attention_layernorm.')
    name = name.replace('.norm_out.', '.norm.')
    return name
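
# Example of the mapping above (illustrative key names): a MindSpore key such as
#   'model.layers.0.feed_forward.w1.weight'
# becomes the HuggingFace key
#   'model.layers.0.mlp.gate_proj.weight'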

# Convert the MindSpore checkpoint into a HuggingFace state dict, then load it
# into a model built from the config alone (structure only, no pretrained weights).
def main(args):
    ckpt_list = []
    model_mind = ms.load_checkpoint(args.mindspore_ckpt_path)
    args_hf = read_json(os.path.join(args.torch_config_dir, "config.json"))

    for name, value in model_mind.items():
        value = torch.from_numpy(value.asnumpy())
        if 'attention.wq' in name:
            # The HF Baichuan implementation fuses q/k/v into a single
            # self_attn.W_pack linear, so concatenate wq/wk/wv along the output dim.
            wq = value
            wk = torch.from_numpy(model_mind[name.replace('.attention.wq', '.attention.wk')].asnumpy())
            wv = torch.from_numpy(model_mind[name.replace('.attention.wq', '.attention.wv')].asnumpy())

            pack = torch.cat([wq, wk, wv], dim=0)
            new_name = name.replace('.attention.wq', '.self_attn.W_pack')
            ckpt_list.append({'name': new_name, 'data': pack})
            continue
        if name == 'norm_out.weight':
            name = 'norm.weight'

        # wk / wv have already been folded into W_pack above, so skip them here.
        if 'attention.wk' not in name and 'attention.wv' not in name:
            name = name_de_replace(name)
            ckpt_list.append({'name': name, 'data': value})

    # torch.save({'state_dict' : ckpt_list}, '/home/wq/Data/LLM/Baichuan2-7B-Base/test.pth')
    # checkpoint = torch.load('/home/wq/Data/LLM/Baichuan2-7B-Base/test.pth')

    # Turn the list of {'name', 'data'} entries into a regular state dict.
    state_dict_fixed = {item['name']: item['data'] for item in ckpt_list}
    # Build the model structure from the HF config only; no pretrained weights
    # are loaded here, the converted MindSpore state dict supplies them.
    config = AutoConfig.from_pretrained(args.torch_config_dir, trust_remote_code=True)
    model_empty = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
    model_empty.load_state_dict(state_dict_fixed)
    model_empty.save_pretrained(args.torch_ckpt_dir)
    print("Converted successfully!")



if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--torch_config_dir', default='/home/Data/LLM/Baichuan2-13B-Chat', help='directory containing the HuggingFace config.json used to build the model structure')
    parser.add_argument('--mindspore_ckpt_path', default='/home/Downloads/model-sft-1124/rank_0/checkpoint_0.ckpt', help='path to the MindSpore (Mindformer) checkpoint to convert')
    parser.add_argument('--torch_ckpt_dir', default='/home/Data/LLM/test_saved_Baichuan2_13B_7', help='directory to save the converted HuggingFace model')

    args = parser.parse_args()
    main(args)
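
After conversion, the output directory can be loaded like any other local HuggingFace checkpoint. Below is a minimal sanity-check sketch, not part of the original script: the paths simply reuse the default arguments above, and because the script only saves the model weights, the tokenizer (and, depending on your transformers version, the custom modeling/tokenization .py files) still come from the original Baichuan2 directory.

from transformers import AutoModelForCausalLM, AutoTokenizer

# Paths reuse the defaults above: --torch_ckpt_dir for the converted model,
# and the original Baichuan2 HF directory for the tokenizer files.
model = AutoModelForCausalLM.from_pretrained('/home/Data/LLM/test_saved_Baichuan2_13B_7', trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained('/home/Data/LLM/Baichuan2-13B-Chat', trust_remote_code=True)

inputs = tokenizer('Hello', return_tensors='pt')
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))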