微调ChatGLM2-6b模型解决文本二分类任务代码解读
1. 导入包
import pandas as pd
导入pandas,用于后续读取和处理CSV数据。pandas为处理表格和时间序列数据提供了高效的数据结构。
2. 加载数据
train_df = pd.read_csv('./csv_data/train.csv')
test_df = pd.read_csv('./csv_data/test.csv')
print(train_df.info())
- pd.read_csv()读取CSV文件,返回DataFrame格式的数据。
- train_df和test_df分别存储训练集和测试集。
- print(train_df.info()) 打印数据概览,包括列名、类型、非空值个数等信息,对数据有个直观了解。
3. 制作数据集
res = []
for i in range(len(train_df)):
# 构造每一项
tmp = {
"instruction": "判断是否医学论文",
"input": "标题:"+title+" 摘要:"+abstract,
"output": label
}
res.append(tmp)
# 保存到json文件
import json
with open('paper_label.json', 'w') as f:
json.dump(res, f)
- 利用循环,遍历训练集每一行,根据标题和摘要构造prompt格式的训练样本。
- prompt包含三项:instruction(提示词),input(输入文本),output(目标输出)。
- 将训练样本保存在JSON文件中,方便后续训练。
4. 微调模型
from peft import PeftModel
from transformers import AutoTokenizer, AutoModel
model_path = "./chatglm2-6b"
model = AutoModel.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = PeftModel.from_pretrained(model, 'output_dir')
- 从huggingface加载预训练语言模型ChatGLM2-6B和相应的tokenizer。
- 使用peft工具进行微调,生成适用于当前任务的fine-tuned模型。
- 微调需要准备prompt格式的数据。
5. 定义预测函数
def predict(text):
input = f"判断是否医学论文:{text}"
response = model.chat(tokenizer, input)
return response
- 封装预测逻辑在一个函数中。
- 根据输入文本构造prompt输入,调用模型生成响应。
- 将模型输出返回。
6. 预测测试集
predictions = []
for i in range(len(test_df)):
text = 获取测试集论文信息
pred = predict(text)
predictions.append(pred)
- 循环遍历测试集每一项。
- 构造输入文本,调用预测函数生成结果。
- 将预测结果保存到predictions列表中。
7. 生成提交文件
submit = test_df[['id', 'keywords', 'predictions']]
submit.to_csv('submit.csv', index=False)
- 从测试集中取出需要的列,和预测一起构造提交文件格式。
- to_csv()可将DataFrame写入csv文件。
总结
ChatGLM2-6B是一个基于Transformer架构预训练的巨大语言模型,它通过在大规模文本语料上进行自监督学习,获得了强大的语言理解和生成能力。
要利用ChatGLM2-6B进行文本二分类,主要分以下几个步骤:
- 数据准备
收集文本分类任务需要的训练集和验证集/测试集,并转换成Prompt格式。prompt包含分类说明、文本输入和期望输出。 - 模型微调
在预训练ChatGLM2-6B模型基础上,使用prompt格式的数据进一步训练,使模型适应文本二分类任务。这一步会稍微调节模型参数。 - Prompt设计
为模型设计合适的prompt模板,指导模型对新输入的文本进行判断和分类。 - 模型推理
对新输入的文本,先转换为prompt输入,然后输入到fine-tuned的ChatGLM2-6B中,从生成的响应中解析出分类结果。 - 模型输出解析
解析模型生成的响应,提取分类结果,转换为标准化的0/1输出。
通过微调+Prompt设计,ChatGLM2-6B可以高效地进行文本二分类,完全利用了其强大的语言理解能力,并优于传统的分类模型。