ChatGLM多轮对话微调-多轮对话训练数据的自动生成(标注)

        通常使用大模型进行业务数据微调的时候,需要对历史对话数据进行细粒度的整理,比如:1-3轮对话数据的微调,以便模型能够学会多轮对话。以ChatGLM为例,微调对话任务的时候,微调会导致模型的理解能力别削弱(无法理解相似语义的输入),即当输入数据prompt的分布与训练数据分布不一致时,模型不会按照训练集的response进行输出,而是使用模型原有的能力进行输出,模型输出结果出现不可控的情况。这个时候需要对输入的数据进行数据增强,数据的方法很多,但个人认为对于样本比较少的对话,最有效的方式应该是人工进行标注,即人工写出输入数据prompt的各种可能的语义相似的样本来(根据对数据增强方式的理解,如:释义、采样和加噪),有人说数据增强的方式怎么做也无法与人工标注的效果相比,只适合于写论文,这里不做评价和扩展。仅针对多轮对话进行1-3轮的对话数据自动标注说明。

       假定历史对话的格式为:

#test.txt
坐席:Y0
客户:X0
坐席:Y1
客户:X1
.....
坐席:Yn
客户:Xn

        说明:1轮指的是n=0时,坐席和客户说的话作为输入,3轮指的是n=2时,坐席和客户说的话作为输入。

        最多3轮,个人认为1-3轮的叠加能解决大部分场景的多轮对话的问题。

1.读取历史对话文本test.txt

import pandas as pd
data =[]
file_name = 'test'
with open(f'{file_name}.txt') as f:
    data = f.readlines()
print(data)

2.自动生成1-3轮对话标注

#to ChatGLM格式
lines=[]
prompt=''
for i,row in enumerate(data):
    if i>0 and i%2==0:
        temps = data[i-3:i-1]
        if len(temps) == 2:
            history = [[temps[0],temps[1]]]
        else:
            history = [['','']]
        lines.append({"prompt":prompt.replace('\n',''),"response":row.replace('\n',''),"history":history})
        prompt = row
    else:
        prompt = row
    if i==len(data)-1:
        prompt = ''
prompt=''
for i,row in enumerate(data):
    if i>0 and i%2==0:
        temps = data[i-5:i-1]
        if len(temps) == 4:
            history = [[temps[0],temps[1]],[temps[2],temps[3]]]
        else:
            history = [['','']]
        lines.append({"prompt":prompt.replace('\n',''),"response":row.replace('\n',''),"history":history})
        prompt = row
    else:
        prompt = row
    if i==len(data)-1:
        prompt = ''
prompt=''
for i,row in enumerate(data):
    if i>0 and i%2==0:
        temps = data[i-7:i-1]
        if len(temps) == 6:
            history = [[temps[0],temps[1]],[temps[2],temps[3]],[temps[4],temps[5]]]
        else:
            history = [['','']]
        lines.append({"prompt":prompt.replace('\n',''),"response":row.replace('\n',''),"history":history})
        prompt = row
    else:
        prompt = row
    if i==len(data)-1:
        prompt = ''
print(lines)

3.显示自动标注的结果

df = pd.DataFrame(lines)
df

4.保存生成的标注数据

import json
with open ('train.json','w') as f:
    json.dump(lines,f)

  • 4
    点赞
  • 16
    收藏
    觉得还不错? 一键收藏
  • 5
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值