lora
-
导入必要的库和模块:
导入各种库,包括 PyTorch、Transformers、DeepSpeed、Modelscope、PEFT 相关模块等。 -
seed_it
函数:
设置随机种子以确保结果的可重复性,使用多个库设置随机种子以确保所有随机数生成器的确定性。 -
torch_gc
函数:
清理 GPU 内存,释放未使用的缓存和 IPC 资源。 -
ModelArguments
数据类:
用于配置模型参数,主要是模型路径。 -
DataArguments
数据类:
用于配置数据参数,包括训练数据路径和评估数据路径。 -
TrainingArguments
数据类:
用于配置训练参数,包括缓存目录、优化器、最大序列长度、是否使用 LoRA 和系统提示词。 -
LoraArguments
数据类:
用于配置 LoRA 微调的参数,包括秩、alpha 参数、dropout 率、目标模块、权重路径、偏置类型和是否使用 QLoRA。 -
rank0_print
函数:
打印函数,仅在local_rank
为 0 时打印输出。 -
preprocess
函数:
数据预处理函数,将输入文本转换为模型可以处理的格式,包括 tokenization 和 padding。 -
SupervisedDataset
类:
监督数据集类,用于微调数据集的预处理。 -
LazySupervisedDataset
类:
延迟预处理数据集类,在每次获取数据时动态预处理数据。 -
make_supervised_data_module
函数:
创建数据模块,包括加载和预处理训练和评估数据集。此处合并了多个数据集。 -
train
函数:
训练函数,加载模型和分词器,配置 LoRA 模型,创建 Trainer 并开始训练。 -
merge_model
函数:
模型合并函数,将微调后的 LoRA 模型与基础模型合并并卸载。 -
test_lora_model
函数:
测试微调后的 LoRA 模型,通过一个示例对话来验证模型效果。 -
__main__
部分:
主程序入口,设置随机种子,合并两个数据集进行微调,微调后测试生成示
例响应。