参考文档:Datawhale
大语言模型微调
大语言模型微调(Fine-tuning Large Language Models)是指在一个大规模预训练语言模型的基础上,进一步在特定数据集上进行训练,以适应特定任务或应用场景。以下是对大语言模型微调的详细介绍:
大语言模型预训练
大语言模型(如GPT-3、BERT等)通常先在大量文本数据上进行预训练。预训练的目标是让模型学习语言的基本结构、词汇、语法、语义等知识。预训练任务通常包括:
- 自回归任务(如GPT):根据前面的词预测下一个词。
- 掩码语言模型任务(如BERT):随机掩盖输入中的一些词,并让模型预测这些掩盖的词。
预训练阶段使模型获得了广泛的语言知识,但它们并未针对特定的任务进行优化。
微调的目的
微调的主要目的是将预训练语言模型的广泛语言能力转化为特定任务的性能。通过在特定任务数据集上进一步训练模型,可以让模型更好地适应该任务的需求。
微调的过程
微调通常包括以下几个步骤:
- 选择预训练模型:选择一个已经预训练的大语言模型作为微调的基础模型。
- 准备数据集:收集并准备用于微调的特定任务数据集。这些数据通常已被标注并分为训练集、验证集和测试集。
- 模型架构调整:如果需要,可以对模型架构进行调整。例如,增加特定任务的输出层。
- 微调训练:在特定任务数据集上进一步训练模型。训练过程中可能会使用不同的损失函数和优化器,以适应特定任务。
- 验证和测试:在验证集和测试集上评估微调后的模型性能。
微调的应用
微调可以应用于多种自然语言处理任务,包括但不限于:
- 文本分类:如情感分析、垃圾邮件检测等。
- 序列标注:如命名实体识别、词性标注等。
- 问答系统:如阅读理解、FAQ问答等。
- 文本生成:如摘要生成、对话生成等。
微调的优势
- 高效利用预训练知识:通过微调,可以充分利用预训练模型在大规模数据上学习到的语言知识。
- 适应特定任务:微调可以让模型在特定任务上表现出色,而不仅仅是一般的语言能力。
- 减少训练资源需求:相比从头训练一个模型,微调所需的计算资源和时间大大减少。
代码运行
想要跑通本Task的代码,可以直接参考如下视频:Datawhale AI夏令营第三期逻辑推理赛道baseline02跑通指南_哔哩哔哩_bilibili
在此感谢b站@ClaYyuy大佬提供的手把手教学!
接下来介绍一下本Task所使用的微调技术。
LoRA微调
LoRA介绍
LoRA(Low-Rank Adaptation)微调是一种高效的模型微调技术,特别适用于大型预训练语言模型的适应性调整。LoRA的核心思想是通过引入低秩矩阵来调整模型的权重,从而在不显著增加模型参数数量的情况下,实现对模型的微调。
LoRA的优势
- 可以针对不同的下游任务构建小型 LoRA 模块,从而在共享预训练模型参数基础上有效地切换下游任务。
- LoRA 使用自适应优化器(Adaptive Optimizer),不需要计算梯度或维护大多数参数的优化器状态,训练更有效、硬件门槛更低。
- LoRA 使用简单的线性设计,在部署时将可训练矩阵与冻结权重合并,不存在推理延迟。
- LoRA 与其他方法正交,可以组合。
LoRA的原理
详细介绍可参考如下文档:
- https://github.com/microsoft/LoRA?tab=readme-ov-file
- https://arxiv.org/pdf/2106.09685
- https://huggingface.co/docs/peft/quicktour
多路LLM投票
多路LLM投票是一种集成学习的方法,通过同时使用多个预训练的语言模型(LLM),对每个模型的输出进行投票,以得到更稳健和准确的最终结果。这个方法的基本思想是,通过结合多个模型的预测结果,可以减少单个模型的错误和偏差,从而提高整体性能。
实现步骤
- 加载多个LLM模型:加载多个不同的预训练语言模型。
- 生成预测:对每个输入,使用每个模型生成预测结果。
- 集成预测:对所有模型的预测结果进行集成,通常通过投票机制。
- 输出最终结果:根据投票结果确定最终的输出。
具体代码分析
查看Baseline代码,可以看到以下部分:
def most_frequent_char(char1, char2, char3):
# 创建一个字典来存储每个字符的出现次数
frequency = {char1: 0, char2: 0, char3: 0}
# 增加每个字符的出现次数
frequency[char1] += 1
frequency[char2] += 1
frequency[char3] += 1
# 找到出现次数最多的字符
most_frequent = max(frequency, key=frequency.get)
return most_frequent
这段代码的作用是找到出现次数最多的字符。
def process_datas(datas,MODEL_NAME):
results = []
# 送入多线程任务
for data in tqdm(datas, desc="Submitting tasks", total=len(datas)):
problem = data['problem']
for id,question in enumerate(data['questions']):
prompt = get_prompt(problem,
question['question'],
question['options'],
)
res,res1,res2 = api_retry(MODEL_NAME, prompt),api_retry(MODEL_NAME, prompt),api_retry(MODEL_NAME, prompt)
extract_response,extract_response1,extract_response2 = extract(res),extract(res1),extract(res2)
ans = most_frequent_char(extract_response,extract_response1,extract_response2)
data['questions'][id]['answer'] = ans
results.append(data)
return results
这段代码三次调用llm,做出现次数统计,最终返回投票数最多的结果。
当代码运行完毕后,可以下载upload.jsonl文件并提交,分数大概在0.7-0.75之间