介绍
本阶段任务为适配、并编写LoRA的训练代码,使得模型训练可引入Low-Rand Adaption技术,以大幅减少训练所需内存空间与训练时间。
问题
使用正常的训练代码,即不适用LoRA进行训练,会导致模型爆显存,这就导致无法进行训练。如下图所示:
可以看出实际所需显存是远超80GiB的。
使用LoRA训练
首先和正常训练一样,我们先通过get_model方法导入预训练模型。
peft_config
编写LoRA所需的peft_config
peft_config = LoraConfig(
target_modules=r'.*language_model.*\.query_key_value',
inference_mode=args.inference_mode,
r=args.lora_r,
lora_alpha=args.lora_alpha,
lora_dropout=args.lora_dropout
)
其中target_modules根据选择的语言模型不同而做出变化,下面是参考表:
TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING = {
"t5": ["q", "v"],
"mt5": ["q", "v"],
"bart": ["q_proj", "v_proj"],
"gpt2": ["c_attn"],
"bloom": ["query_key_value"],
"blip-2": ["q", "v", "q_proj", "v_proj"],
"opt": ["q_proj", "v_proj"],
"gptj": ["q_proj", "v_proj"],
"gpt_neox": ["query_key_value"],
"gpt_neo": ["q_proj", "v_proj"],
"bert": ["query", "value"],
"roberta": ["query", "value"],
"xlm-roberta": ["query", "value"],
"electra": ["query", "value"],
"deberta-v2": ["query_proj", "value_proj"],
"deberta": ["in_proj"],
"layoutlm": ["query", "value"],
"llama": ["q_proj", "v_proj"],
"chatglm": ["query_key_value"],
"gpt_bigcode": ["c_attn"],
"mpt": ["Wqkv"],
"RefinedWebModel": ["query_key_value"],
"RefinedWeb": ["query_key_value"],
"falcon": ["query_key_value"],
"btlm": ["c_proj", "c_attn"],
"codegen": ["qkv_proj"],
}
梯度适配
if args.gradient_checkpointing:
def make_inputs_require_grad(module, input, output):
output.requires_grad_(True)
model.language_model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
# model.language_model.apply(
# partial(model.language_model._set_gradient_checkpointig, value=True))
model.gradient_checkpointing_enable()
显存变化
将显存压缩至69GiB进行训练。
训练结果与分析
模型权重
模型训练的结果我们可以在output文件夹下看到,里面包括我们训练得到的权重参数pytorch_model.bin
微调结果分析
下面图片均来自wandb,设置相关wandb代码,训练时将相关数据传至该官网,然后进行官网会根据相关结果产生相对应的图片。
我们采用LoRA分别进行了根据不同的数据格式、学习率和Iteration等超参数进行了多次实验,下图只展示几个实验结果用于分析。
以上是多次调整参数和数据输入格式的不同结果图,Accuracy和Loss
其中Accuracy图中的紫色线条和Loss的绿色线条就是使用方案一的数据格式,可以明显这种方式,导致模型理解电影能力下降,精度较低;同时计算时间过长,损失下降缓慢且不稳定。
下面两张图是在训练mPlug-Owl模型时的显卡使用情况
微调时遇到的问题
由于我们刚开始微调mPlug-Owl模型,对这个不太熟悉,发生了一些问题,最典型的是大模型开始胡言乱语,如下:
通过排查,是因为使用LoRA进行训练时导致一些权重的Key发生了改变,这就导致导入模型是会出现Key不对应,从而导致有部分的模型权重使用的是随机初始化,从而导致模型胡言乱语。
解决方法就是重新配置一边peft_config,然后初始化后的符合LoRA模型的架构,在重新导入权重,即将模型创建后调整为符合LoRA训练的模型架构,再重新载入预训练权重。