ssbuild大佬的chatglm_finetuning项目---train.py代码解读

# -*- coding: utf-8 -*-
# 查看日子tensorboard --logdir=. --bind_all   在当前目录下查找TensorFlow事件文件,启动TensorBoard服务器,并将其绑定到所有可用的网络接口,以便在本地网络上查看TensorBoard可视化结果。
import logging

import torch
from deep_training.data_helper import ModelArguments, DataArguments, TrainingArguments
from deep_training.nlp.models.chatglm import ChatGLMConfig, setup_model_profile
from deep_training.nlp.models.lora import LoraArguments
from deep_training.utils.trainer import ModelCheckpoint, SimpleModelCheckpoint
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
from pytorch_lightning.strategies import DeepSpeedStrategy
from transformers import HfArgumentParser

from data_utils import NN_DataHelper, train_info_args, get_deepspeed_config
from models import MyTransformer,ChatGLMTokenizer


class MySimpleModelCheckpoint(SimpleModelCheckpoint):
    def __init__(self, *args, **kwargs):
        super(MySimpleModelCheckpoint, self).__init__(*args, **kwargs)
        lora_args: LoraArguments = self.external_kwargs['lora_args']
        if lora_args.with_lora:
            self.weight_file = './best_ckpt'
            self.last_weight_file = './last_ckpt'

    # 加载模型权重的方法,首先根据传入的参数初始化一个MyTransformer模型,然后使用加载的权重文件,将模型的参数设置为加载的权重,最后返回模型对象。
    def load_model_from_ckpt(self):
        model_args = self.external_kwargs['model_args']
        training_args = self.external_kwargs['training_args']
        lora_args = LoraArguments.from_pretrained(self.last_weight_file)
        pl_module = MyTransformer(lora_args=lora_args,
                              config=config,
                              model_args=model_args,
                              training_args=training_args)


        pl_module.backbone.from_pretrained(pl_module.backbone.model,self.last_weight_file)
        return pl_module


    def on_save_model(
            self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"
    ) -> None:

        lora_args : LoraArguments =  self.external_kwargs['lora_args']
        # 保存权重
        # 保存模型的方法,如果LoraArguments中不包含有with_lora字段,则直接调用父类的on_save_model()方法,保存模型权重。
        if not lora_args.with_lora:
            super(MySimpleModelCheckpoint, self).on_save_model(trainer, pl_module)
        else:
            # 如果包含有with_lora字段,则首先获取监控指标和验证集的性能指标,然后根据这些指标,更新当前最好的模型权重和最新的模型权重。
            # 如果当前最好的模型权重更新了,则保存权重到weight_file中。最后,无论最好的模型权重是否更新,都将当前模型权重保存到last_weight_file中。
            monitor_candidates = self._monitor_candidates(trainer)
            monitor_candidates.update(self.on_get_metric(trainer, pl_module))
            val = monitor_candidates.get(self.monitor, None)

            #保存loss最小权重
            if self.update_best(val):
                logging.info('epoch {} ,step {} , save best {}, {}\n'.format(monitor_candidates['epoch'],
                                                                             monitor_candidates['step'],
                                                                             self.best[self.monitor],
                                                                             self.weight_file))
                pl_module.backbone.save_pretrained(self.weight_file)
            #保存最新权重
            pl_module.backbone.save_pretrained(self.last_weight_file)
            # # 从最新权重加载模型
            # pl_module = self.load_model_from_ckpt()



            
if __name__ == '__main__':
    # 使用 Hugging Face 提供的 HfArgumentParser 工具解析命令行参数,并将其存储在四个变量中:model_args、training_args、data_args 和 lora_args。
    parser = HfArgumentParser((ModelArguments, TrainingArguments, DataArguments, LoraArguments))
    model_args, training_args, data_args, lora_args = parser.parse_dict(train_info_args)


    # 调用 setup_model_profile 函数来设置 PyTorch 模型,以便进行训练。
    setup_model_profile()
    # 调用 get_deepspeed_config 函数来获取一个 DeepSpeed 配置,这是一个分布式训练框架。
    deepspeed_config = get_deepspeed_config()

    # 保存最小loss模型
    # 根据是否开启 Lora 模式来选择使用 MySimpleModelCheckpoint 或 ModelCheckpoint 作为检查点回调函数,前者用于 Lora 模式,后者用于普通模式。
    if lora_args.with_lora:
        assert deepspeed_config is None,ValueError('lora mode does not support deepspeed')
        checkpoint_callback = MySimpleModelCheckpoint(
                              # monitor="loss",
                              every_n_epochs = 1,
                              every_n_train_steps=2000 // training_args.gradient_accumulation_steps,
                              #模型参数
                              model_args=model_args,
                              training_args=training_args,
                              lora_args=lora_args,)
    else:
        checkpoint_callback = ModelCheckpoint('./best_ckpt',
                                              # monitor='loss',
                                              save_weights_only=False,
                                              save_last=True,
                                              save_top_k=1,
                                              # every_n_train_steps=1000,
                                              every_n_epochs=1)

    # 根据当前机器的 GPU 数量,决定使用 ddp 还是 auto 策略进行分布式训练。如果 DeepSpeed 配置不为空,则使用 DeepSpeedStrategy 进行训练。
    strategy = 'ddp' if torch.cuda.device_count() > 1 else 'auto'
    if deepspeed_config is not None and len(deepspeed_config):
        strategy = DeepSpeedStrategy(config=deepspeed_config,)

    # 创建一个Trainer对象,并传入相应的参数
    trainer = Trainer(
        callbacks=[checkpoint_callback,LearningRateMonitor(logging_interval='step')],
        max_epochs=training_args.max_epochs,
        max_steps=training_args.max_steps,
        accelerator="gpu",
        devices=data_args.devices,
        enable_progress_bar=True,
        default_root_dir=data_args.output_dir,
        gradient_clip_val=training_args.max_grad_norm,
        accumulate_grad_batches=training_args.gradient_accumulation_steps,
        num_sanity_val_steps=0,
        strategy=strategy
        # precision=16,#半精度
    )

    # 创建一个NN_DataHelper对象,用于加载tokenizer和config。
    dataHelper = NN_DataHelper(model_args, training_args, data_args)

    # 调用dataHelper的load_tokenizer_and_config方法,加载tokenizer和config。
    tokenizer, config, _,_ = dataHelper.load_tokenizer_and_config(tokenizer_class_name=ChatGLMTokenizer,config_class_name=ChatGLMConfig)
    config.eos_token_id = 130005

    # 根据config的属性值,判断是否开启了ptuning v2或量化模型,并在必要时抛出异常。
    if config.pre_seq_len is not None and lora_args.with_lora:
        raise ValueError('with lora and ptuning v2 cannot open at the same time')

    if config.pre_seq_len is not None:
        if config.quantization_bit:
            raise Exception('量化模型不支持微调训练')

    # 额外参数
    # 将tokenizer和data_args保存到checkpoint_callback对象中。
    checkpoint_callback.tokenizer = tokenizer
    checkpoint_callback.data_args = data_args

    # 调用config的save_pretrained方法,保存config。
    config.save_pretrained('best_ckpt')

    # 根据data_args中的设置,缓存数据集。
    if data_args.do_train:
        dataHelper.make_dataset_with_args(data_args.train_file,mixed_data=False,shuffle=True,mode='train')
    if data_args.do_eval:
        dataHelper.make_dataset_with_args(data_args.eval_file, mode='eval')
    if data_args.do_test:
        dataHelper.make_dataset_with_args(data_args.test_file,mode='test')


    # 创建一个MyTransformer对象,用于实现模型的训练和预测。其中,config、model_args、training_args和lora_args均为该对象的初始化参数。
    pl_model = MyTransformer(config=config, model_args=model_args, training_args=training_args,lora_args=lora_args)


    # 代码定义了一个变量 ckpt_path,用于保存或加载训练过程中的最佳权重路径。
    ckpt_path = './best_ckpt/best.pt'
    # 代码使用了一个 if 语句,用于判断是否需要将模型转换为 ONNX 格式,如果不需要转换,则执行以下步骤:
    if not data_args.convert_onnx:
        #  只恢复权重 , 不恢复步数和优化器 ,
        #  如果想恢复步数, 修改 trainer.fit(pl_model, train_dataloaders=train_datasets,ckpt=ckpt_path)  注lora 当前不支持恢复步数。
        # if os.path.exists(ckpt_path):
        #     if not lora_args.with_lora:
        #         # 加载权重继续训练
        #         pl_model = MyTransformer.load_from_checkpoint(ckpt_path, config=config,model_args=model_args,training_args=training_args,lora_args=lora_args,strict=False)
        #     else:
        #         # 加载lora权重 继续训练  0.0.20版本支持lora 继续训练
        #         pl_model.backbone.from_pretrained(pl_model.backbone.model, pretrained_model_name_or_path=ckpt_path,lora_config=lora_args,strict=False)

        # 定义了一个函数 dataset_loader_filter_fn,这个函数的作用是打印出当前数据集的总数,并将数据集返回。
        def dataset_loader_filter_fn(dataset):
            print('*' * 30,'total',len(dataset))
            return dataset
        # 调用 dataHelper.load_distributed_random_sampler() 函数,这个函数用于加载数据集,其中包含了许多参数,
        # 例如训练数据的文件列表,是否需要加载数据到内存中等等。最终,函数返回一个数据集,这个数据集可以用于模型的训练。
        train_datasets = dataHelper.load_distributed_random_sampler(
            dataHelper.train_files,
            with_load_memory=True,
            collate_fn=dataHelper.collate_fn,
            batch_size=training_args.train_batch_size,
            drop_last=True,#多卡建议扔掉
            num_processes=trainer.world_size, process_index=trainer.global_rank,
            dataset_loader_filter_fn=dataset_loader_filter_fn
        )

        # 如果数据集不为空,则调用 trainer.fit() 函数,用于训练模型,其中的参数包括模型、数据集等。
        if train_datasets is not None:
            trainer.fit(pl_model, train_dataloaders=train_datasets)

    else:
        if not lora_args.with_lora:
            # 调用 MyTransformer.load_from_checkpoint() 函数,用于加载模型权重。
            pl_model = MyTransformer.load_from_checkpoint(ckpt_path, config=config,
                                                       model_args=model_args,
                                                       training_args=training_args,
                                                       lora_args=lora_args,strict=False)
            # input_sample = (
            #     ("input_ids", torch.ones(size=(1, 128), dtype=torch.int32)),
            #     ("attention_mask", torch.ones(size=(1, 1,128,128), dtype=torch.int32)),
            #     ("position_ids", torch.ones(size=(1, 2, 128), dtype=torch.int32)),
            # )
            # input_names = ("input_ids",'attention_mask','position_ids')
            # output_names = ("pred_ids",)
            # dynamic_axes = None or {"input_ids": [0, 1],
            #                         "attention_mask": [0, 0,1,1],
            #                         "position_ids": [0, 0,1],
            #                         "pred_ids": [0, 1]}
            # pl_module.convert_to_onnx('./best_ckpt/best.onnx',
            #                       input_sample=input_sample,
            #                       input_names=input_names,
            #                       output_names=output_names,
            #                       dynamic_axes=dynamic_axes)

            model = pl_model.get_glm_model()
            #调用 model.save_pretrained() 函数,将模型保存为 Hugging Face 模型。
            model.save_pretrained('huggingface_model',max_shard_size='10GB')
        else:
            # 如果需要进行 Lora 训练,则再次加载权重,并调用 pl_model.get_glm_model() 函数,得到训练好的模型。
            # 加载权重
            lora_args = LoraArguments.from_pretrained('./best_ckpt')
            pl_module = MyTransformer(lora_args=lora_args,
                                      config=config,
                                      model_args=model_args,
                                      training_args=training_args)
            # 二次加载权重
            pl_module.backbone.from_pretrained(pl_module.backbone.model, pretrained_model_name_or_path='./best_ckpt',lora_config=lora_args)

            model = pl_model.get_glm_model()

  • 4
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
ChatGLM-6B源码是基于GLM的2D位置编码实现的。该位置编码的详细原理可以在原文《GLM: General Language Model Pretraining with Autoregressive Blank Infilling》中找到。在GitHub上,有一个微调ChatGLM-6B项目代码库,作者是mymusise。该项目使用Stanford Alpaca的52K数据集,并通过LoRA(低秩适应)的方式进行微调。在评测时,使用中文Rouge分数和BLEU-4指标,并将生成的结果保存在"./output/adgen-chatglm-6b-pt-8-1e-2/generated_predictions.txt"文件中。 以上是关于ChatGLM-6B源码的一些解读。<span class="em">1</span><span class="em">2</span><span class="em">3</span> #### 引用[.reference_title] - *1* [ChatGLM-6B模型结构组件源码阅读](https://blog.csdn.net/yjh_SE007/article/details/130728164)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v92^chatsearchT0_1"}}] [.reference_item style="max-width: 50%"] - *2* *3* [ChatGLM-6B的基座/部署/微调/实现:从GLM到6B的LoRA/P-Tuning微调、及6B源码解读](https://blog.csdn.net/v_JULY_v/article/details/129880836)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v92^chatsearchT0_1"}}] [.reference_item style="max-width: 50%"] [ .reference_list ]

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值