微调ChatGLM2-6b模型解决文本二分类任务代码解读

微调ChatGLM2-6b模型解决文本二分类任务代码解读

1. 导入包

import pandas as pd

导入pandas,用于后续读取和处理CSV数据。pandas为处理表格和时间序列数据提供了高效的数据结构。

2. 加载数据

train_df = pd.read_csv('./csv_data/train.csv')
test_df = pd.read_csv('./csv_data/test.csv')

print(train_df.info())
  • pd.read_csv()读取CSV文件,返回DataFrame格式的数据。
  • train_df和test_df分别存储训练集和测试集。
  • print(train_df.info()) 打印数据概览,包括列名、类型、非空值个数等信息,对数据有个直观了解。

3. 制作数据集

res = []

for i in range(len(train_df)):
  # 构造每一项
  tmp = {
      "instruction": "判断是否医学论文",
      "input": "标题:"+title+" 摘要:"+abstract,  
      "output": label
  }
  
  res.append(tmp)

# 保存到json文件
import json
with open('paper_label.json', 'w') as f:
  json.dump(res, f)
  • 利用循环,遍历训练集每一行,根据标题和摘要构造prompt格式的训练样本。
  • prompt包含三项:instruction(提示词),input(输入文本),output(目标输出)。
  • 将训练样本保存在JSON文件中,方便后续训练。

4. 微调模型

from peft import PeftModel
from transformers import AutoTokenizer, AutoModel 

model_path = "./chatglm2-6b"

model = AutoModel.from_pretrained(model_path) 
tokenizer = AutoTokenizer.from_pretrained(model_path)

model = PeftModel.from_pretrained(model, 'output_dir')
  • 从huggingface加载预训练语言模型ChatGLM2-6B和相应的tokenizer。
  • 使用peft工具进行微调,生成适用于当前任务的fine-tuned模型。
  • 微调需要准备prompt格式的数据。

5. 定义预测函数

def predict(text):

  input = f"判断是否医学论文:{text}"
  
  response = model.chat(tokenizer, input)

  return response
  • 封装预测逻辑在一个函数中。
  • 根据输入文本构造prompt输入,调用模型生成响应。
  • 将模型输出返回。

6. 预测测试集

predictions = []

for i in range(len(test_df)):
  
  text = 获取测试集论文信息
  
  pred = predict(text)
  
  predictions.append(pred)
  • 循环遍历测试集每一项。
  • 构造输入文本,调用预测函数生成结果。
  • 将预测结果保存到predictions列表中。

7. 生成提交文件

submit = test_df[['id', 'keywords', 'predictions']]

submit.to_csv('submit.csv', index=False) 
  • 从测试集中取出需要的列,和预测一起构造提交文件格式。
  • to_csv()可将DataFrame写入csv文件。

总结

ChatGLM2-6B是一个基于Transformer架构预训练的巨大语言模型,它通过在大规模文本语料上进行自监督学习,获得了强大的语言理解和生成能力。

要利用ChatGLM2-6B进行文本二分类,主要分以下几个步骤:

  • 数据准备
    收集文本分类任务需要的训练集和验证集/测试集,并转换成Prompt格式。prompt包含分类说明、文本输入和期望输出。
  • 模型微调
    在预训练ChatGLM2-6B模型基础上,使用prompt格式的数据进一步训练,使模型适应文本二分类任务。这一步会稍微调节模型参数。
  • Prompt设计
    为模型设计合适的prompt模板,指导模型对新输入的文本进行判断和分类。
  • 模型推理
    对新输入的文本,先转换为prompt输入,然后输入到fine-tuned的ChatGLM2-6B中,从生成的响应中解析出分类结果。
  • 模型输出解析
    解析模型生成的响应,提取分类结果,转换为标准化的0/1输出。

通过微调+Prompt设计,ChatGLM2-6B可以高效地进行文本二分类,完全利用了其强大的语言理解能力,并优于传统的分类模型。

  • 0
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
对于ChatGLM2-6B的微调参数,你可以按照以下步骤进行操作: 1. 首先,根据你的需求,下载ChatGLM2-6B的代码仓库并安装所需的依赖包。你可以使用以下命令进行克隆和安装: ``` git clone https://github.com/THUDM/ChatGLM2-6B cd ChatGLM2-6B pip install -r requirements.txt cd ptuning/ pip install rouge_chinese nltk jieba datasets ``` 2. 接下来,你可以根据自己的数据集构建训练所需的数据集。 3. 在P-tuning v2的训练过程中,模型只保存了PrefixEncoder部分的参数。因此,在进行推理时,你需要同时加载原始的ChatGLM-6B模型和PrefixEncoder的权重。为了实现这一点,在evaluate.sh脚本中,你需要指定相应的参数。 总结起来,你可以通过自己验证并更换模型路径,使用自定义的数据集来微调ChatGLM2-6B模型微调后的参数可以在指定的checkpoint地址中找到。同时,在推理过程中,需要同时加载ChatGLM-6B的模型和PrefixEncoder的权重。希望对你有帮助!<span class="em">1</span><span class="em">2</span><span class="em">3</span> #### 引用[.reference_title] - *1* *2* *3* [ChatGLM2-6B、ChatGLM-6B 模型介绍及训练自己数据集实战](https://blog.csdn.net/dream_home8407/article/details/130099656)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v92^chatsearchT0_1"}}] [.reference_item style="max-width: 100%"] [ .reference_list ]

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值