在微调大语言模型的时候,loss是我们判断模型训练效果的一大重要指标。loss通常代表着模型效果和预期结果之间的差距,当这个差距收敛或者变化不大时,我们认为模型训练停止,此时通过validation来寻找最佳的模型超参。
而loss为0通常是被认为是异常情况,因为在大语言模型的训练中,loss的计算是每一个token位置对下一个token的预测值和预期值的交叉熵函数。 哪怕在一个样本中,模型的输出和训练集的结果完全一致几乎不可能。所以被判定为异常情况。
两个解决办法:
1. 查看label的设置,通常在dataloader的data_collator参数中会有涉及,将label设为-100的地方去掉(这里说的是自己设置的-100,hugging face中的函数会将特殊token mask掉,这个不用去掉)
2. 将lora config中的参数load_in_4bit设置为True
下面分析可能的原因:
第一种是我们在mask label的过程中,设置错误。这里需要解释的是,什么是mask label。微调大模型时,训练集的形式是(prompt,responce)。由于模型学习的是回答问题,不需要去拟合问题的语言,所以我们只计算responce部分的loss,而不计算prompt的部分的loss。于是,在设置label的时候,我们会将prompt部分的label设置为-100(交叉熵规定对于token中含有-100的loss为0),以此避免其loss的计算。
这里如果不会设置,可以直接看这个codebase:https://github.com/allenai/open-instruct
或者参考下面的代码:
def encode_with_prompt_completion_format(example, tokenizer, max_seq_length, add_bos=False):
'''
Here we assume each example has 'prompt' and 'completion' fields.
We concatenate prompt and completion and tokenize them together because otherwise prompt will be padded/trancated
and it doesn't make sense to follow directly with the completion.
'''
# if prompt doesn't end with space and completion doesn't start with space, add space
if not example['prompt'].endswith((' ', '\n', '\t')) and not example['completion'].startswith((' ', '\n', '\t')):
example_text = example['prompt'] + ' ' + example['completion']
else:
example_text = example['prompt'] + example['completion']
example_text = example_text + tokenizer.eos_token
if add_bos:
example_text = tokenizer.bos_token + example_text
tokenized_example = tokenizer(example_text, return_tensors='pt', max_length=max_seq_length, truncation=True)
input_ids = tokenized_example.input_ids
labels = input_ids.clone()
tokenized_prompt = tokenizer(example['prompt'], return_tensors='pt', max_length=max_seq_length, truncation=True)
# mask the prompt part for avoiding loss
labels[:, :tokenized_prompt.input_ids.shape[1]] = -100
attention_mask = torch.ones_like(input_ids)
return {
'input_ids': input_ids.flatten(),
'labels': labels.flatten(),
'attention_mask': attention_mask.flatten(),
}
第二个可能的原因是lora的config设置出错。这里需要将lora config的load_in_4bit参数设置为True,增强数据的稳定性。
希望对正在学习的小伙伴有所帮助!