MindFormers Baichuan2 转换为 Hugging Face 模型
// A highlighted block
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, Trainer
import os
import json
import argparse
import mindspore as ms
# Module-level dtype setting (MindSpore float16). None of the code visible in
# this file reads it. NOTE(review): confirm whether it is still needed.
dtype=ms.float16
def read_json(path):
    """Load and return the JSON document stored at ``path``.

    Args:
        path: Filesystem path of the JSON file to read.

    Returns:
        The decoded JSON content, exactly as produced by ``json.load``.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file content is not valid JSON.
    """
    # Decode as UTF-8 explicitly so the result does not depend on the
    # platform's locale default encoding.
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)
def name_de_replace(name: str):
    """Translate a MindSpore-style parameter name into the Hugging Face style.

    Each known fragment of the checkpoint naming scheme is substituted by its
    counterpart, in the order listed below. Fragments that do not occur in
    ``name`` leave it untouched, so unknown names are returned unchanged.

    Args:
        name: Parameter name as it appears in the MindSpore checkpoint.

    Returns:
        The name with every listed fragment replaced.
    """
    # (fragment in the MindSpore name, replacement used by the HF model).
    # Most fragments include their surrounding dots, so a bare top-level name
    # such as 'norm_out.weight' is not matched by the '.norm_out.' entry.
    substitutions = (
        ('tok_embeddings.embedding_weight', 'embed_tokens.weight'),
        ('.attention.wq.', '.self_attn.q_proj.'),
        ('.attention.wk.', '.self_attn.k_proj.'),
        ('.attention.wv.', '.self_attn.v_proj.'),
        ('.attention.wo.', '.self_attn.o_proj.'),
        ('.feed_forward.w1.', '.mlp.gate_proj.'),
        ('.feed_forward.w2.', '.mlp.down_proj.'),
        ('.feed_forward.w3.', '.mlp.up_proj.'),
        ('.attention_norm.', '.input_layernorm.'),
        ('.ffn_norm.', '.post_attention_layernorm.'),
        ('.norm_out.', '.norm.'),
    )
    for ms_fragment, hf_fragment in substitutions:
        name = name.replace(ms_fragment, hf_fragment)
    return name
def main(args):
    """Convert a MindSpore checkpoint into a Hugging Face model directory.

    Steps:
      1. Load the MindSpore checkpoint from ``args.mindspore_ckpt_path``.
      2. Build a torch state dict: every ``attention.wq`` entry is
         concatenated with its sibling ``wk`` and ``wv`` entries along
         dimension 0 and stored under a ``self_attn.W_pack`` name; all other
         entries are renamed through ``name_de_replace``.
      3. Instantiate the model from the config found in
         ``args.torch_config_dir``, load the state dict into it and save the
         result to ``args.torch_ckpt_dir``.

    Args:
        args: Namespace providing ``mindspore_ckpt_path``,
            ``torch_config_dir`` and ``torch_ckpt_dir``.

    Raises:
        KeyError: If a ``wq`` entry has no matching ``wk``/``wv`` entry in
            the checkpoint.
        Whatever the MindSpore / transformers / torch calls raise, e.g. when
        the converted keys do not match the model (``load_state_dict`` is
        called with its default arguments).
    """
    model_mind = ms.load_checkpoint(args.mindspore_ckpt_path)

    state_dict = {}
    for name, value in model_mind.items():
        # wk / wv are consumed together with their wq sibling below, so skip
        # them here before paying for a numpy/torch conversion.
        if 'attention.wk' in name or 'attention.wv' in name:
            continue

        value = torch.from_numpy(value.asnumpy())

        if 'attention.wq' in name:
            wk = torch.from_numpy(
                model_mind[name.replace('.attention.wq', '.attention.wk')].asnumpy())
            wv = torch.from_numpy(
                model_mind[name.replace('.attention.wq', '.attention.wv')].asnumpy())
            # The target model stores q, k and v as a single packed parameter.
            packed_name = name.replace('.attention.wq', '.self_attn.W_pack')
            state_dict[packed_name] = torch.cat([value, wk, wv], 0)
            continue

        # name_de_replace only rewrites '.norm_out.' (with a leading dot), so
        # the bare top-level name is handled explicitly.
        if name == 'norm_out.weight':
            name = 'norm.weight'
        state_dict[name_de_replace(name)] = value

    config = AutoConfig.from_pretrained(args.torch_config_dir, trust_remote_code=True)
    # NOTE(review): device_map is forwarded to from_config unchanged; confirm
    # the installed transformers version / remote model code accepts it.
    model_empty = AutoModelForCausalLM.from_config(config, device_map="auto", trust_remote_code=True)
    model_empty.load_state_dict(state_dict)
    model_empty.save_pretrained(args.torch_ckpt_dir)
    print("Convert Successfully!")
if __name__ == "__main__":
    # Command-line entry point: collect the three paths and run the conversion.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '--torch_config_dir',
        default='/home/Data/LLM/Baichuan2-13B-Chat',
        help='need config file to init model structure',
    )
    cli.add_argument(
        '--mindspore_ckpt_path',
        default='/home/Downloads/model-sft-1124/rank_0/checkpoint_0.ckpt',
    )
    cli.add_argument(
        '--torch_ckpt_dir',
        default='/home/Data/LLM/test_saved_Baichuan2_13B_7',
        help='path to save converted model',
    )
    args = cli.parse_args()
    main(args)