GLM4 Fine-Tuning Code Walkthrough

# -*- coding: utf-8 -*-
import os
import jieba
import dataclasses as dc
import functools
from collections.abc import Callable, Mapping, Sequence
from pathlib import Path
from typing import Annotated, Any, Union
import numpy as np
import ruamel.yaml as yaml
import torch
import typer
from datasets import Dataset, Split
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from peft import PeftConfig, get_peft_config, get_peft_model
from rouge_chinese import Rouge
from torch import nn
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    EvalPrediction,
    GenerationConfig,
    PreTrainedTokenizer,
    Seq2SeqTrainingArguments,
)
from transformers import DataCollatorForSeq2Seq as _DataCollatorForSeq2Seq
from transformers import Seq2SeqTrainer as _Seq2SeqTrainer
from datasets import load_dataset, DatasetDict, NamedSplit
from typing import Optional

app = typer.Typer(pretty_exceptions_show_locals=False)  # create a typer CLI app instance


class DataCollatorForSeq2Seq(_DataCollatorForSeq2Seq):
    def __call__(self, features, return_tensors=None):
        # If the features contain an 'output_ids' key, collect the output IDs
        output_ids = ([feature['output_ids'] for feature in features] if 'output_ids' in features[0].keys() else None)
        if output_ids is not None:
            # Compute the maximum output length in the batch
            max_output_length = max(len(out) for out in output_ids)
            if self.pad_to_multiple_of is not None:
                # Round the maximum length up to the configured multiple
                max_output_length = (
                        (
                                max_output_length + self.pad_to_multiple_of - 1) //
                        self.pad_to_multiple_of * self.pad_to_multiple_of
                )
            for feature in features:
                # Pad each 'output_ids' sequence to the maximum length
                remainder = [self.tokenizer.pad_token_id] * (
                        max_output_length - len(feature['output_ids'])
                )
                if isinstance(feature['output_ids'], list):
                    feature['output_ids'] = feature['output_ids'] + remainder
                else:
                    feature['output_ids'] = np.concatenate(
                        [feature['output_ids'], remainder]
                    ).astype(np.int64)
        return super().__call__(features, return_tensors)


class Seq2SeqTrainer(_Seq2SeqTrainer):
    # Not supported for apex
    def training_step(self, model: nn.Module, inputs: dict[str, Any]) -> torch.Tensor:

        model.train()  # put the model in training mode
        inputs = self._prepare_inputs(inputs)  # move inputs to the right device/dtype

        with self.compute_loss_context_manager():
            loss = self.compute_loss(model, inputs)  # compute the loss

        if self.args.n_gpu > 1:
            loss = loss.mean()  # average the loss across GPUs
        self.accelerator.backward(loss)  # backpropagate
        detached_loss = loss.detach() / self.args.gradient_accumulation_steps  # detach and scale for gradient accumulation
        del inputs  # free the input tensors
        torch.cuda.empty_cache()  # release cached CUDA memory
        return detached_loss

    def prediction_step(
            self,
            model: nn.Module,
            inputs: dict[str, Any],
            prediction_loss_only: bool,
            ignore_keys=None,
            **gen_kwargs,
    ) -> tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:

        # Disable gradient computation during the prediction step
        with torch.no_grad():
            # 'output_ids' is only present when generation-based evaluation is enabled;
            # guard the pop so this method does not raise a NameError otherwise
            output_ids = inputs.pop('output_ids') if self.args.predict_with_generate else None
            input_ids = inputs['input_ids']

            # Delegate to the parent class's prediction step
            loss, generated_tokens, labels = super().prediction_step(
                model, inputs, prediction_loss_only, ignore_keys, **gen_kwargs
            )

            # Keep only the newly generated tokens (strip the prompt)
            generated_tokens = generated_tokens[:, input_ids.size()[1]:]
            if output_ids is not None:
                labels = output_ids  # use the reference output IDs as labels

            del inputs, input_ids, output_ids  # free temporary tensors
            torch.cuda.empty_cache()  # release cached CUDA memory

        return loss, generated_tokens, labels  # return the loss, generated tokens, and labels


@dc.dataclass
class DataConfig(object):
    train_file: Optional[str] = None  # path to the training data file
    val_file: Optional[str] = None  # path to the validation data file
    test_file: Optional[str] = None  # path to the test data file
    num_proc: Optional[int] = None  # number of processes for data preprocessing

    @property
    def data_format(self) -> str:
        return Path(self.train_file).suffix  # infer the data format from the file extension

    @property
    def data_files(self) -> dict[NamedSplit, str]:
        # Map each dataset split to its file path, skipping files that are not set
        return {
            split: data_file
            for split, data_file in zip(
                [Split.TRAIN, Split.VALIDATION, Split.TEST],
                [self.train_file, self.val_file, self.test_file],
            )
            if data_file is not None
        }


@dc.dataclass
class FinetuningConfig(object):
    data_config: DataConfig  # data configuration

    max_input_length: int  # maximum input length
    max_output_length: int  # maximum output length
    combine: bool  # whether to tokenize each conversation as a whole

    training_args: Seq2SeqTrainingArguments = dc.field(
        default_factory=lambda: Seq2SeqTrainingArguments(output_dir='./output')  # training arguments
    )
    peft_config: Optional[PeftConfig] = None  # optional PEFT configuration

    def __post_init__(self):
        # Disable evaluation when do_eval is off or no validation file is given
        if not self.training_args.do_eval or self.data_config.val_file is None:
            self.training_args.do_eval = False
            self.training_args.evaluation_strategy = 'no'
            self.data_config.val_file = None
        else:
            self.training_args.per_device_eval_batch_size = (
                    self.training_args.per_device_eval_batch_size
                    or self.training_args.per_device_train_batch_size
            )

    @classmethod
    def from_dict(cls, **kwargs) -> 'FinetuningConfig':
        # Build a FinetuningConfig instance from a dict of keyword arguments
        training_args = kwargs.get('training_args', None)
        if training_args is not None and not isinstance(
                training_args, Seq2SeqTrainingArguments
        ):
            gen_config = training_args.get('generation_config')
            if not isinstance(gen_config, GenerationConfig):
                training_args['generation_config'] = GenerationConfig(
                    **gen_config
                )
            kwargs['training_args'] = Seq2SeqTrainingArguments(**training_args)

        data_config = kwargs.get('data_config')
        if not isinstance(data_config, DataConfig):
            kwargs['data_config'] = DataConfig(**data_config)

        peft_config = kwargs.get('peft_config', None)
        if peft_config is not None and not isinstance(peft_config, PeftConfig):
            kwargs['peft_config'] = get_peft_config(config_dict=peft_config)
        return cls(**kwargs)

    @classmethod
    def from_file(cls, path: Union[str, Path]) -> 'FinetuningConfig':
        # Load the configuration from a YAML file
        path = Path(path)
        parser = yaml.YAML(typ='safe', pure=True)
        parser.indent(mapping=2, offset=2, sequence=4)
        parser.default_flow_style = False
        kwargs = parser.load(path)
        return cls.from_dict(**kwargs)


def _load_datasets(
        data_dir: str,
        data_format: str,
        data_files: dict[NamedSplit, str],
        num_proc: Optional[int],
) -> DatasetDict:
    # Load the datasets; only the .jsonl format is supported
    if data_format == '.jsonl':
        dataset_dct = load_dataset(
            data_dir,
            data_files=data_files,
            split=None,
            num_proc=num_proc,
        )
    else:
        raise NotImplementedError(f"Cannot load dataset in the '{data_format}' format.")
    return dataset_dct


class DataManager(object):
    def __init__(self, data_dir: str, data_config: DataConfig):
        self._num_proc = data_config.num_proc

        self._dataset_dct = _load_datasets(
            data_dir,
            data_config.data_format,
            data_config.data_files,
            self._num_proc,
        )

    def _get_dataset(self, split: NamedSplit) -> Optional[Dataset]:
        return self._dataset_dct.get(split, None)  # fetch the dataset for the given split

    def get_dataset(
            self,
            split: NamedSplit,
            process_fn: Callable[[dict[str, Any]], dict[str, Any]],
            batched: bool = True,
            remove_orig_columns: bool = True,
    ) -> Optional[Dataset]:
        orig_dataset = self._get_dataset(split)
        if orig_dataset is None:
            return

        if remove_orig_columns:
            remove_columns = orig_dataset.column_names
        else:
            remove_columns = None
        return orig_dataset.map(
            process_fn,
            batched=batched,
            remove_columns=remove_columns,
            num_proc=self._num_proc,
        )


def process_message(message):
    # For system messages, keep only tool parameter properties that are not None
    if 'tools' in message and message['role'] == 'system':
        for tool in message['tools']:
            parameters = tool['function']['parameters']['properties']
            tool['function']['parameters']['properties'] = \
                {k: v for k, v in parameters.items() if
                 v is not None}
    # Non-system messages must not carry a 'tools' field
    elif 'tools' in message:
        del message['tools']
    return message  # return the cleaned-up message


def process_batch(
        batch: Mapping[str, Sequence],
        tokenizer: PreTrainedTokenizer,
        max_input_length: int,
        max_output_length: int,
        combine: bool,
) -> dict[str, list]:
    batched_conv = batch['messages']
    batched_input_ids = []
    batched_labels = []
    for conv in batched_conv:
        input_ids = [151331, 151333]  # GLM-4 special prefix tokens: [gMASK] and <sop>
        loss_masks = [False, False]  # no loss on the two prefix tokens
        if combine:
            new_input_ids = tokenizer.apply_chat_template(conv, tokenize=True, return_dict=False)
            input_ids = new_input_ids
            loss_masks = [False] * len(input_ids)
            last_assistant_index = len(input_ids) - input_ids[::-1].index(151337) - 1  # 151337 is the <|assistant|> token
            for j in range(last_assistant_index + 1, len(input_ids)):
                loss_masks[j] = True
        else:
            for message in conv:
                message = process_message(message)
                loss_mask_val = False if message['role'] in ('system', 'user', 'observation') else True
                new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[2:]  # drop the [gMASK]<sop> prefix the template re-adds
                input_ids += new_input_ids
                loss_masks += [loss_mask_val] * len(new_input_ids)

        input_ids.append(151336)  # EOS for chat
        loss_masks = [False, *loss_masks]  # re-align masks after the EOS append; the shift lets the EOS contribute to the loss
        labels = []
        for input_id, mask in zip(input_ids, loss_masks):
            if mask:
                labels.append(input_id)
            else:
                labels.append(-100)
        max_length = max_input_length + max_output_length + 1  # +1 accounts for the appended EOS token
        batched_input_ids.append(input_ids[:max_length])
        batched_labels.append(labels[:max_length])

    del batched_conv, conv, input_ids, loss_masks, new_input_ids, labels
    torch.cuda.empty_cache()

    return {'input_ids': batched_input_ids, 'labels': batched_labels}


def process_batch_eval(
        batch: Mapping[str, Sequence],
        tokenizer: PreTrainedTokenizer,
        max_input_length: int,
        max_output_length: int,
        combine: bool,
) -> dict[str, list]:
    batched_conv = batch['messages']
    batched_input_ids = []
    batched_output_ids = []

    for conv in batched_conv:
        if combine:
            new_input_ids = tokenizer.apply_chat_template(conv, tokenize=True, return_dict=False)
            input_ids = new_input_ids
            last_assistant_index = len(input_ids) - input_ids[::-1].index(151337) - 1
            output_prompt, output_ids = (
                input_ids[:1],
                input_ids[last_assistant_index:],
            )
            output_ids.append(151336)
            batched_input_ids.append(
                input_ids[:max_input_length] + output_prompt[:1]
            )
            batched_output_ids.append(output_ids[:max_output_length])
        else:
            input_ids = [151331, 151333]  # [gMASK] and <sop> prefix tokens
            for message in conv:
                if len(input_ids) >= max_input_length:
                    break
                else:
                    message = process_message(message)
                    new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[2:]
                    if message['role'] == 'assistant':
                        output_prompt, output_ids = (
                            new_input_ids[:1],
                            new_input_ids[1:],
                        )
                        output_ids.append(151336)
                        batched_input_ids.append(
                            input_ids[:max_input_length] + output_prompt[:1]
                        )
                        batched_output_ids.append(output_ids[:max_output_length])
                    input_ids += new_input_ids

    del batched_conv, conv, input_ids, new_input_ids, output_prompt, output_ids
    torch.cuda.empty_cache()

    return {'input_ids': batched_input_ids, 'output_ids': batched_output_ids}


def load_tokenizer_and_model(
        model_dir: str,
        peft_config: Optional[PeftConfig] = None,
):
    tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
    if peft_config is not None:
        model = AutoModelForCausalLM.from_pretrained(
            model_dir,
            trust_remote_code=True,
            empty_init=False,
            use_cache=False,
            torch_dtype=torch.bfloat16  # Must use BFloat 16
        )
        model = get_peft_model(model, peft_config)
        model.print_trainable_parameters()
    else:
        model = AutoModelForCausalLM.from_pretrained(
            model_dir,
            trust_remote_code=True,
            empty_init=False,
            use_cache=False,
            torch_dtype=torch.bfloat16
        )
    return tokenizer, model


def compute_metrics(eval_preds: EvalPrediction, tokenizer):
    batched_pred_ids, batched_label_ids = eval_preds
    metrics_dct = {'rouge-1': [], 'rouge-2': [], 'rouge-l': [], 'bleu-4': []}
    for pred_ids, label_ids in zip(batched_pred_ids, batched_label_ids):
        pred_txt = tokenizer.decode(pred_ids).strip()
        label_txt = tokenizer.decode(label_ids).strip()
        pred_tokens = list(jieba.cut(pred_txt))
        label_tokens = list(jieba.cut(label_txt))
        rouge = Rouge()
        scores = rouge.get_scores(' '.join(pred_tokens), ' '.join(label_tokens))
        for k, v in scores[0].items():
            metrics_dct[k].append(round(v['f'] * 100, 4))
        metrics_dct['bleu-4'].append(
            sentence_bleu([label_tokens], pred_tokens, smoothing_function=SmoothingFunction().method3))
    return {k: np.mean(v) for k, v in metrics_dct.items()}


@app.command()
def main(
        data_dir: Annotated[str, typer.Argument(help='')],
        model_dir: Annotated[
            str,
            typer.Argument(
                help='A string that specifies the model id of a pretrained model configuration hosted on huggingface.co, or a path to a directory containing a model configuration file.'
            ),
        ],
        config_file: Annotated[str, typer.Argument(help='')],
        auto_resume_from_checkpoint: str = typer.Argument(
            default='',
            help='If "yes", automatically resume from the latest saved checkpoint. If a number such as 12 or 15, resume from the checkpoint with that step number. If empty or "no", start training from scratch.'
        ),
):
    ft_config = FinetuningConfig.from_file(config_file)  # load the fine-tuning config from the config file
    tokenizer, model = load_tokenizer_and_model(model_dir, peft_config=ft_config.peft_config)  # load the tokenizer and model
    data_manager = DataManager(data_dir, ft_config.data_config)  # create the data manager

    train_dataset = data_manager.get_dataset(
        Split.TRAIN,
        functools.partial(
            process_batch,
            tokenizer=tokenizer,
            combine=ft_config.combine,
            max_input_length=ft_config.max_input_length,
            max_output_length=ft_config.max_output_length,
        ),
        batched=True,
    )
    print('train_dataset:', train_dataset)  # print the training dataset
    val_dataset = data_manager.get_dataset(
        Split.VALIDATION,
        functools.partial(
            process_batch_eval,
            tokenizer=tokenizer,
            combine=ft_config.combine,
            max_input_length=ft_config.max_input_length,
            max_output_length=ft_config.max_output_length,
        ),
        batched=True,
    )
    if val_dataset is not None:
        print('val_dataset:', val_dataset)  # print the validation dataset
    test_dataset = data_manager.get_dataset(
        Split.TEST,
        functools.partial(
            process_batch_eval,
            tokenizer=tokenizer,
            combine=ft_config.combine,
            max_input_length=ft_config.max_input_length,
            max_output_length=ft_config.max_output_length,
        ),
        batched=True,
    )
    if test_dataset is not None:
        print('test_dataset:', test_dataset)  # print the test dataset

    model.gradient_checkpointing_enable()  # enable gradient checkpointing
    model.enable_input_require_grads()  # make the input embeddings require gradients

    trainer = Seq2SeqTrainer(
        model=model,
        args=ft_config.training_args,
        data_collator=DataCollatorForSeq2Seq(
            tokenizer=tokenizer,
            padding='longest',
            return_tensors='pt',
        ),
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=functools.partial(compute_metrics, tokenizer=tokenizer),
    )

    if auto_resume_from_checkpoint is None or auto_resume_from_checkpoint.upper() == "":
        trainer.train()  # start training from scratch
    else:
        output_dir = ft_config.training_args.output_dir
        dirlist = os.listdir(output_dir)
        checkpoint_sn = 0
        for checkpoint_str in dirlist:
            if checkpoint_str.find("eckpoint") > 0 and checkpoint_str.find("tmp") == -1:
                checkpoint = int(checkpoint_str.replace("checkpoint-", ""))
                if checkpoint > checkpoint_sn:
                    checkpoint_sn = checkpoint
        if auto_resume_from_checkpoint.upper() == "YES":
            if checkpoint_sn > 0:
                model.gradient_checkpointing_enable()
                model.enable_input_require_grads()
                checkpoint_directory = os.path.join(output_dir, "checkpoint-" + str(checkpoint_sn))
                print("resume checkpoint from checkpoint-" + str(checkpoint_sn))
                trainer.train(resume_from_checkpoint=checkpoint_directory)
            else:
                trainer.train()
        else:
            if auto_resume_from_checkpoint.isdigit():
                if int(auto_resume_from_checkpoint) > 0:
                    checkpoint_sn = int(auto_resume_from_checkpoint)
                    model.gradient_checkpointing_enable()
                    model.enable_input_require_grads()
                    checkpoint_directory = os.path.join(output_dir, "checkpoint-" + str(checkpoint_sn))
                    print("resume checkpoint from checkpoint-" + str(checkpoint_sn))
                    trainer.train(resume_from_checkpoint=checkpoint_directory)
            else:
                print(auto_resume_from_checkpoint,
                      "The specified checkpoint (" + auto_resume_from_checkpoint + ") has not been saved. Please look for the correct checkpoint in the model output directory.")

    if test_dataset is not None:
        trainer.predict(test_dataset)  # run prediction on the test set


if __name__ == '__main__':
    app()

This is the fine-tuning script that ships with GLM4, the finetune.py file. It implements a complete model fine-tuning pipeline for machine learning and natural language processing tasks.
You can replace certain fields to fit your own project environment and data.
Here are the main parts you may need to replace or adjust:

1. Data file paths and names:

  • In the DataConfig class, train_file, val_file, and test_file should be changed to the actual locations and names of your data files. These files hold the training, validation, and test datasets; a sample record sketch follows this list.

2. Model directory:

  • The model_dir parameter of main takes either the directory of a pretrained model or a model ID hosted on Hugging Face. If you use a local model, make sure the path is correct.

3. Config file path:

  • The config_file parameter of main points to the fine-tuning configuration file. Change it to the location of the config file you actually use; a configuration sketch also follows this list.

4. Model and training parameters:

  • In the FinetuningConfig class, parameters such as max_input_length, max_output_length, combine, and training_args may need adjusting to your specific task and data. They determine how the inputs are processed and how well the model trains.

5. Implementation of specific features:

  • If you rely on particular pre- or post-processing, such as tokenizer.apply_chat_template in process_batch and process_batch_eval, you may need to adapt it to your actual model and tokenizer.

6. PEFT configuration:

  • If you use PEFT (Parameter-Efficient Fine-Tuning) techniques such as LoRA, specify them in peft_config. If not, you may need to delete or replace the related code blocks.

7. Output directory:

  • The output_dir parameter in Seq2SeqTrainingArguments must point to the directory where you want checkpoints and logs saved.

8. Performance tuning parameters:

  • Parameters such as gradient_accumulation_steps and per_device_train_batch_size may need tuning to your hardware configuration to optimize training performance and resource utilization.
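
For reference, process_batch reads each record's 'messages' list of role/content dicts (roles: system, user, assistant, plus observation for tool results). Below is a minimal sketch of producing one line of a .jsonl training file; the file name and the record contents are illustrative assumptions, not part of the original script:

import json

# One hypothetical single-turn record; 'messages' is the key the
# preprocessing functions read, and roles follow the chat-template convention.
record = {
    "messages": [
        {"role": "system", "content": "你是一个乐于助人的助手。"},
        {"role": "user", "content": "写一句广告词。"},
        {"role": "assistant", "content": "好物不贵,天天见新。"},
    ]
}

# Every line of the .jsonl file holds one JSON object like this.
with open("train.jsonl", "a", encoding="utf-8") as f:
    f.write(json.dumps(record, ensure_ascii=False) + "\n")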

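To see how these pieces fit together, here is a hedged sketch of building the same configuration programmatically through FinetuningConfig.from_dict, mirroring what a lora.yaml would contain. All concrete values (paths, lengths, LoRA rank, batch sizes) are illustrative assumptions, not recommended settings; note that from_dict builds a GenerationConfig from training_args['generation_config'], so that entry must be present:

# Assumes the classes defined in finetune.py above are in scope.
ft_config = FinetuningConfig.from_dict(
    data_config={
        'train_file': 'data/train.jsonl',  # assumed file layout (item 1)
        'val_file': 'data/dev.jsonl',
        'test_file': 'data/test.jsonl',
        'num_proc': 1,
    },
    max_input_length=512,   # adjust to your task (item 4)
    max_output_length=512,
    combine=True,
    training_args={
        'output_dir': './output',              # item 7
        'max_steps': 3000,
        'per_device_train_batch_size': 1,      # item 8: match your hardware
        'gradient_accumulation_steps': 16,
        'learning_rate': 5e-4,
        'do_eval': True,
        'evaluation_strategy': 'steps',
        'eval_steps': 500,
        'predict_with_generate': True,
        # from_dict turns this dict into a GenerationConfig, so it is required
        'generation_config': {'max_new_tokens': 512},
    },
    peft_config={            # item 6: converted to a LoraConfig by get_peft_config
        'peft_type': 'LORA',
        'task_type': 'CAUSAL_LM',
        'r': 8,
        'lora_alpha': 32,
        'lora_dropout': 0.1,
    },
)
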
Launch command (data_dir, model_dir, and config_file are positional typer arguments, so they are passed without flags):
python finetune.py <data_dir> /data/glm4-9b-chat/ config/lora.yaml
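
The fourth positional argument controls checkpoint resumption, so a resumed run might look like this (paths illustrative):
python finetune.py <data_dir> /data/glm4-9b-chat/ config/lora.yaml yes
Passing a step number such as 600 instead of yes resumes from checkpoint-600 in the output directory.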
