本文基于ADGEN广告文本数据集,采用P-tuningv2技术微调ChatGLM2的简单案例,深入源码剖析微调细节,把握微调核心。
目录
1.数据集介绍
ADGEN(广告数据集):希望通过一段吸引的多样的措辞来描述该产品,吸引用户购买,是一个广告文本。每条数据包括一个产品的广告文本,文件保存的格式是{“content”:“”,“summary”:“”},“content"包含产品描述的属性值对,即希望生成该产品的哪些描述属性,例如如果是上衣,希望可以给定"材质”,“颜色”,“风格”,“图案”等描述生成包括以上属性描述的一段广告文本,“summary”对应特定产品的限定属性生成的广告示例。
该数据集可以通过 Google Drive 或者 Tsinghua Cloud 下载。 解压后有两个json文件,包括训练文件(train.json)和测试文件(test.json)
2. 模型准备
2.1 ChatGLM2-6b简介
- 清华大学的开发的开源中英双语对话模型 ChatGLM-6B 的第二代版本,在保留了初代模型对话流畅、部署门槛较低等众多优秀特性的基础之上,ChatGLM2-6B 具有新的特性:
- 更强大性能:ChatGLM2升级基座模型,使用GLM的混合目标函数。
- 更长的上下文:基于 FlashAttention 技术,有初代模型的2K------》扩展到32K,并在对话阶段使用 8K 的上下文长度训练,允许更多轮次的对话。
- 缺点:当前版本的 ChatGLM2-6B 对单轮超长文档的理解能力有限
- 更高效的推理:基于 Multi-Query Attention 技术,推理速度提升了 42%
2.2 安装与配置
- 安装ChatGLM-6B模型
https://github.com/THUDM/ChatGLM-6B 下载zip安装包后解压到服务器上 - 在anaconda 创建虚拟环境conda create -n <环境名称> python=3.X,并配置:
- 激活刚创建的虚拟环境 conda activate <环境名称>
- chatglm2模型下载并解压完成后,cd 进入 requirements.txt 所在的文件目录
- 终端执行pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
3. 微调及微调细节
3.1 参数设置说明
跟训练相关的参数设置主要在train.sh文件,推理相关的参数在evaluate.sh文件。
训练参数 总结如下,推理部分参数与训练部分大致相同。
注意:
- arguments.py详细说明模型参数类相关的默认值设置(ModelArguments类)和数据训练参数类的默认值设置(DataTrainingArguments类)
- 请不要在train.sh 和 evaluate.sh 中用“#”设置注释,导致该变量参数失效,系统无法识别。
3.2 LLM输入和标签构建说明
train.sh 和 evaluate.sh 设置的prompt_column 和 response_column变量 最后经过main.py process_function_train函数构成了符合LLM的输入和输出。通过下面红框的这行代码,我们知道了json数据集中键值对是如何产生了LLM所需要的输入和输出。这启发我们如何定制私人数据集相应的LLM输入和输出。
可以看到query:查询问题的文本是由examples[prompt_column][i]得到,answer:回复是由examples[response_column][i]得到。
例如:
1 . 如果是多段输入进行拼接,得到一段输出的情况:
假设 josn数据集是有{“question”,“prompt”,“answer”}三段描述,其中我们希望query是由"question"+“prompt"构成,answer对应"answer”。这样设置的初衷可能是因为不同的question对应了不同的prompt文本。
我们可以首先对:
训练文件:train.sh 对应的部分进行修改:
--prompt_column [["question","prompt"]] \
--response_column answer\
其次对main.py 文件对应的的preprocess_function_train部分:
for i in range(len(examples[prompt_column])):
if examples[prompt_column][i] and examples[response_column][i]:
query = "".join([examples[prompt_column][0][i] for i in range(len(examples[prompt_column][i]))])
answer = examples[response_column][i]
- 如果我们希望对query是有"question"+ 共同的“prompt”构成,answer对应"answer"。这样设置的初衷可能是希望让它更好学习某种特定规则的输出。
训练文件:train.sh 对应的部分进行修改:
--prompt_column question \
--response_column answer \
其次对main.py 文件对应的的preprocess_function_train部分:
for i in range(len(examples[prompt_column])):
if examples[prompt_column][i] and examples[response_column][i]:
query, answer = examples[prompt_column][i]+"回答时,请加上前缀‘你好’。", examples[response_column][i]
基于以上,有几点认识:
- 训练集和测试集分别只有一个json或csv文件。
- 每条数据(以json为例)需要包括输入和输出的键值对,但是输入和输出的键值对并一定只有一个
- 输入和输出键值对名称可以改变,不一定设置为content和summary
- 结合train.sh和preprocess_function_train函数,可以更好根据自定义数据集构建合适的LLM输入和标签。
3.3 微调实现
-
- 首先进入配置了ChatGLM的虚拟环境,conda acitivate <环境名>
-
- 终端执行命令进入文件夹 cd <文件夹目录>
-
- 修改train.sh文件参数后,终端执行:bash train.sh(注意,sh文件里不能用#来标注注释,否则会导致该变量系统无法识别)
-
- 查看输出训练日志
注意:
- 查看输出训练日志
- sh文件不能带注释
- model_name_or_path 如果为“THUDM/chatglm-6b” 类似的带有“THUDM”会每次调用都从网络上下载模型,最好改为自己本地保存模型的路径。
输出打印日志,调用成功。
4. 评估细节
4.1 评估实现
经过p-tuning v2训练后保存的参数只有PrefixEncoder 部分的参数,所以推理时需要加载原始ChatGLM-6B模型参数以及PrefixEncoder的权重。推理部分修改主要在evaluate.sh文件。
修改地方注意两点
- 将 evaluate.sh 中的 CHECKPOINT 更改为训练时保存的 checkpoint 名称,同时如果你想要从本地加载模型,可以将model_name_or_path中的THUDM/chatglm-6b改为你本地的模型路径。
- 其他参数设置和训练一致。
执行成功后会输出 各项指标:
4.2 评估指标说明
BLEU:
- BLEU(Bilingual Evaluation Understudy)根据精确率(precision)来衡量机器翻译质量,BLEU 值越高,机器翻译结果与人工翻译结果越相似。
ROUGE:
- 全称是 (Recall-Oriented Understudy for Gisting Evaluation),根据召回率(recall)来衡量机器翻译质量
- ROUGE-N: 在 N-gram 上计算召回率
- ROUGE-L: 考虑了机器译文和参考译文之间的最长公共子序列
- ROUGE-W: 改进了ROUGE-L,用加权的方法计算最长公共子序列
Q : 什么是n-gram?
- n代表连续的n个词的组合,它的值可以是1,2,3,或者更高
- 当n=1时,我们称之为unigram,即1-gram,指单个词语,如句子“我喜欢学习自然语言处理”,1-gram:[“我”, “喜欢”, “学习”, “自然语言处理”, “。”]
- 当n=2时,我们称之为bigram,即2-gram,指相邻两个词语组合,如句子“我喜欢学习自然语言处理”,2-gram为:[“我喜欢”, “喜欢学习”, “学习自然语言处理”, “自然语言处理。”]
- 当n=3时,我们称之为trigram,即3-gram,指相邻三个词语组合,如句子“我喜欢学习自然语言处理”,3-gram为:[“我喜欢学习”, “喜欢学习自然语言处理”, “学习自然语言处理。”]
- 使用n-gram可以捕捉一定长度的上下文信息,有助于更好地理解文本和评估翻译质量。
BLEU 和ROUGE的计算
- 参考翻译:“今天天气晴朗。”
- 系统生成的翻译: “今天的天气是晴朗的。”
a. BLEU指标计算:
- 首先将 参考翻译 和 系统生成的翻译 拆分成 n-gram序列。
比如:参考翻译的1-gram:[“今天”, “天气”, “晴朗”, “。”] 而 系统生成的翻译的1-gram:[“今天”, “的”, “天气”, “是”, “晴朗”, “的”, “。”] - 接下来,计算两者n-gram的匹配数。例如,1-gram中有3个匹配:[“今天”, “天气”, “晴朗”]
- 计算精确度:系统生成的翻译中n-gram匹配的个数/系统生成的翻译中n-gram的总个数
- 考虑到较长的翻译可能具有较高的n-gram匹配,使用短文本惩罚(brevity penalty)来调整精确度。防止短翻译在BLEU中得分过高。
- 计算BLEU得分:BLEU = 短文本惩罚 * exp(1/n * (log(p1) + log(p2) + … + log(pn)))
其中,p1, p2, …, pn是1-gram, 2-gram, …, n-gram的精确度,n是n-gram的最大长度。
b. ROUGE指标计算:
- 用于评估文本摘要任务,着重于信息完整性和涵盖程度,所以将参考翻译和系统生成的翻译视为两个文本摘要。
- 计算系统生成的翻译(Predict)中包含的n-gram在参考翻译(Label)中出现的次数。
- 计算召回率(recall):将匹配的n-gram总数除以参考翻译中的总n-gram数。
- ROUGE得分可以根据需要使用不同的n-gram大小,通常使用ROUGE-1、ROUGE-2和ROUGE-L。
- ROUGE-1 = 召回率(系统生成的1-gram匹配数 / 参考翻译中的1-gram总数)
- ROUGE-2 = 召回率(系统生成的2-gram匹配数 / 参考翻译中的2-gram总数)
- ROUGE-L = 最长公共子序列(Longest Common Subsequence,LCSS)的长度 / 参考翻译的总长度
BLEU指标 和 ROUGE指标 取值范围是[0,1],可以在main.py源码中看到,指标计算过程中将数值都乘以100。因此原本取值范围:[0,1],扩大后变成[0,100]。
另外需要注意的:
- BLEU-4并不是只看4-gram的情况,而是计算从1-gram到4-gram的累积分数,加权策略为 1-gram, 2-gram, 3-gram和4-gram 的权重各占25%。
- 默认情况下, sentence_bleu()和corpus_bleu()都是计算累积的4-gram BLEU分数的, 也称之为BLEU-4。
- BLEU中的smooth_function有7种方法,不同smooth方法结果差异较大,可以查阅论文:https://aclanthology.org/P02-1040.pdf
其他没有注意到的细节,欢迎评论区讨论。源码提供了很多值得深入的细节。