前言
本系列博客将针对部分模型进行代码微调实战:
模型包括baichuan2、qwen、chatglm、mistral等;
微调的方式包括SFT、DPO、PPO、KTO等;
主要基于transformers、trl、peft等实现;
自从chatgpt问世以来,人工智能领域的发展便进入了大模型时代。这些模型,如 GPT、Gimini、claude、llama、Qwen、baichuan等,已经在多个自然语言处理任务上设立了新的性能标准,并且展现出巨大的潜力。这些模型通常在大型数据集上进行预训练,以捕捉广泛的语言规律和知识,然后在特定任务上进行微调以实现更精确的应用。微调大模型成为了一种高效利用LLM解决领域问题的重要方法。
微调的核心动机在于利用大模型在广泛数据上学到的丰富特征表示。这些大型模型通过在大规模数据集上学习,已经内化了丰富的语言结构和语义信息,使它们能够理解和生成人类语言。通过微调,我们可以将这些高级特征应用于特定的下游任务,如情感分析、文本分类或问答系统,无需从头开始训练模型。这种方法不仅节省了大量的训练时间和资源,还能显著提升任务的执行效果。
目前,github仓库中已经出现了一些LLM微调框架,如llama-factory、llama2-chinese等。使用开源框架虽然可以很方便地对很多模型进行微调,但是作为一个算法工程师,一直使用开源框架而没有属于自己的脚本,会导致无法深刻理解脚本的处理流程。如果缺乏对训练脚本内部细节的把握,当需要训练框架不支持的模型时就会变得无从下手。
此外,编写自己的微调脚本还有助于实现模型的透明度和可解释性。在企业和关键应用中,理解模型的行为和判断其可靠性是至关重要的。通过深入掌握微调的过程,开发者可以更有效地监控模型在特定数据集上的表现,及时发现并调整可能导致偏差或错误的因素。
一、Qwen SFT脚本
1.引入库
from dataclasses import dataclass, field
from typing import Dict, Optional, Tuple
from peft import LoraConfig, get_peft_model
from datasets import Dataset, load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, TrainingArguments
from transformers.generation.utils import GenerationConfig
from trl import SFTTrainer
2.定义参数类
@dataclass
class ScriptArguments:
# traindata parameters
train_data: Optional[str] = field(
default="/data2/xxx/train_dpo/data/hh-rlhf", metadata={"help": "训练数据的位置"})
# training parameters
model_name_or_path: Optional[str] = field(
default="", metadata={"help": "the model name"})
max_length: Optional[int] = field(
default=512, metadata={"help": "max length of each sample"})
max_prompt_length: Optional[int] = field(
default=128, metadata={"help": "max length of each sample's prompt"})
max_target_length: Optional[int] = field(
default=128, metadata={"help": "Only used for encoder decoder model. Max target of each sample's prompt"}
)
label_pad_token_id: Optional[int] = field(
default=-100, metadata={"help": "label for non response tokens"})
# debug argument for distributed training
ignore_bias_buffers: Optional[bool] = field(
default=False,
metadata={
"help": "fix for DDP issues with LM bias/mask buffers - invalid scalar type,`inplace operation. See"
"https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992"
},
)
3.定义数据集加载器
def get_dataset(train_data_path: str, silent: bool = False, cache_dir: str = None) -> Tuple[Dataset, Dataset]:
datasetall = load_dataset(
"json",
data_files={
train_data_path
},
cache_dir=cache_dir,
)
def split_prompt_and_responses(sample) -> Dict[str, str]:
answers = sample["output"]
instruction = sample["instruction"]
return {
"prompt": instruction + answers[0],
}
datasetall = datasetall.map(split_prompt_and_responses)
train_test_split = datasetall["train"].train_test_split(test_size=0.8)
dataset_train = train_test_split['test']
dataset_test = train_test_split['train']
return dataset_train, dataset_test
4.定义主函数
if __name__ == "__main__":
parser = HfArgumentParser((ScriptArguments, TrainingArguments))
script_args, training_args = parser.parse_args_into_dataclasses() # [0]
# load a pretrained model
model = AutoModelForCausalLM.from_pretrained(
script_args.model_name_or_path,
trust_remote_code=True,
torch_dtype='auto',
# device_map='auto'
)
model.generation_config = GenerationConfig.from_pretrained(
script_args.model_name_or_path)
# laod peft model
LORA_R = 32
# LORA_ALPHA = 16
LORA_DROPOUT = 0.05
TARGET_MODULES = ["c_attn", "c_proj", "w1", "w2"]
config = LoraConfig(
r=LORA_R,
# lora_alpha=LORA_ALPHA,
target_modules=TARGET_MODULES,
lora_dropout=LORA_DROPOUT,
bias="none",
task_type="CAUSAL_LM",
)
model = get_peft_model(model, config)
model.print_trainable_parameters()
tokenizer = AutoTokenizer.from_pretrained(
script_args.model_name_or_path, trust_remote_code=True)
tokenizer.pad_token_id = tokenizer.eod_id
with training_args.main_process_first(desc="loading and tokenization"):
# Load train and Load evaluation dataset
train_dataset, eval_dataset = get_dataset(
train_data_path=script_args.train_data)
# initialize the sft trainer
sft_trainer = SFTTrainer(
model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
tokenizer=tokenizer,
max_seq_length=script_args.max_length,
dataset_text_field = "prompt"
)
sft_trainer.train()
5.sh脚本
python train.py \
--model_name_or_path /mnt/data3/models/Qwen-7B-Chat/ \
--train_data /mnt/data3/xxxx/comparison_gpt4_data_zh.json \
--learning_rate 2e-4 \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 1 \
--max_length 1024 \
--report_to tensorboard \
--save_strategy steps \
--save_steps 500 \
--logging_steps 10 \
--save_total_limit 2 \
--output_dir ./test # --max_steps 2000 \