进阶 baseline2【微调方向】 + 知识点讲解

#AI夏令营 #Datawhale #夏令营#这个夏令营不简单

1. 数据集制作

1.1 环境配置

先对原始群聊数据做初步抽取,设置讯飞3.5的api环境配置。和baseline1的配置一样。

!pip uninstall websocket-client
!pip install --upgrade spark_ai_python websocket-client


from sparkai.llm.llm import ChatSparkLLM, ChunkPrintHandler
from sparkai.core.messages import ChatMessage
import numpy as np
from tqdm import tqdm


def chatbot(prompt):
    #星火认知大模型Spark3.5 Max的URL值,其他版本大模型URL值请前往文档(https://www.xfyun.cn/doc/spark/Web.html)查看
    SPARKAI_URL = 'wss://spark-api.xf-yun.com/v3.5/chat'
    #星火认知大模型调用秘钥信息,请前往讯飞开放平台控制台(https://console.xfyun.cn/services/bm35)查看
    SPARKAI_APP_ID = ''
    SPARKAI_API_SECRET = ''
    SPARKAI_API_KEY = ''
    #星火认知大模型Spark3.5 Max的domain值,其他版本大模型domain值请前往文档(https://www.xfyun.cn/doc/spark/Web.html)查看
    SPARKAI_DOMAIN = 'generalv3.5'
    spark = ChatSparkLLM(
        spark_api_url=SPARKAI_URL,
        spark_app_id=SPARKAI_APP_ID,
        spark_api_key=SPARKAI_API_KEY,
        spark_api_secret=SPARKAI_API_SECRET,
        spark_llm_domain=SPARKAI_DOMAIN,
        streaming=False,
    )
    messages = [ChatMessage(
        role="user",
        content=prompt
    )]
    handler = ChunkPrintHandler()
    a = spark.generate([messages], callbacks=[handler])
    return a.generations[0][0].message.content

与baseline1一样,这里不过多解释

1.2 数据处理Prompt

这里我们对原群聊对话设计了一个总结Prompt,目的是将原始对话内容进行精简。方便做微调数据。

一方面直接将群聊对话作为数据集的话,会导致上下文过长,超过限制。还有上下文太长会导致抽取效果变差。

过长的上下文也会导致训练时长和费用倍增。

这个prompt相较于baseline01区别比较明显,对需要抽取的任务做了一次总结。总结了四个方面:

客户基本信息:需要从中区分出客户角色,并得到客户基本信息,其中包括姓名、手机号码、邮箱、地区、详细地址、性别、年龄和生日 客户意向与预算信息: 客户意向与预算信息包括咨询类型、意向产品、购买异议点、预算是否充足、总体预算金额以及预算明细 客户购买准备情况:户购买准备情况包括竞品信息、客户是否有意向、客户是否有卡点以及客户购买阶段 跟进计划信息: 跟进计划信息包括参与人、时间点和具体事项,这些信息用于指导销售团队在未来的跟进工作中与客户互动

通过总结后的数据一方面节约了微调的运算资源,一方面也让数据被清洗后更容易被模型理解,达到更好的抽取效果。

content = ''
prompt = f'''
你是一个数据分析大师,你需要从群聊对话中进行分析,里面对话的角色中大部分是客服角色,你需要从中区分出有需求的客户,并得到以下四类数据。

****群聊对话****
{content}

****分析数据****
客户基本信息:需要从中区分出客户角色,并得到客户基本信息,其中包括姓名、手机号码、邮箱、地区、详细地址、性别、年龄和生日
客户意向与预算信息: 客户意向与预算信息包括咨询类型、意向产品、购买异议点、预算是否充足、总体预算金额以及预算明细
客户购买准备情况:户购买准备情况包括竞品信息、客户是否有意向、客户是否有卡点以及客户购买阶段
跟进计划信息: 跟进计划信息包括参与人、时间点和具体事项,这些信息用于指导销售团队在未来的跟进工作中与客户互动

****注意****
1.只输出客户基本信息、客户意向与预算信息、客户购买准备情况、跟进计划信息对应的信息,不要输出无关内容
2.不要输出分析内容
3.输出内容格式为md格式
'''

1.3 训练数据集制作

jsonl_data 是用来训练的规范单行数据,需要由训练数据组成一个jsonl文件(每行是一个json数据的文件),格式如下:

jsonl_data = {"instruction":"假设你是一个智能交互助手,基于用户的输入文本,解析其中语义,抽取关键信息,以json格式生成结构化的语义内容。","input":"请调小空气净化器的湿度到1","output":"{\"intent\":\"CONTROL\",\"slots\":[{\"name\":\"device\",\"normValue\":\"airCleaner\",\"value\":\"空气净化器\"},{\"name\":\"insType\",\"normValue\":\"set\",\"value\":\"调小\"},{\"name\":\"attr\",\"normValue\":\"humidity\",\"value\":\"湿度\"},{\"name\":\"attrValue\",\"normValue\":\"1\",\"value\":\"1\"}],\"sample\":\"请调小空气净化器的湿度到1\"}"}
print(jsonl_data)

结果:

{'instruction': '假设你是一个智能交互助手,基于用户的输入文本,解析其中语义,抽取关键信息,以json格式生成结构化的语义内容。', 'input': '请调小空气净化器的湿度到1', 'output': '{"intent":"CONTROL","slots":[{"name":"device","normValue":"airCleaner","value":"空气净化器"},{"name":"insType","normValue":"set","value":"调小"},{"name":"attr","normValue":"humidity","value":"湿度"},{"name":"attrValue","normValue":"1","value":"1"}],"sample":"请调小空气净化器的湿度到1"}'}

print(jsonl_data["instruction"])
print(jsonl_data["input"])
print(jsonl_data["output"])

结果:

假设你是一个智能交互助手,基于用户的输入文本,解析其中语义,抽取关键信息,以json格式生成结构化的语义内容。
请调小空气净化器的湿度到1
{"intent":"CONTROL","slots":[{"name":"device","normValue":"airCleaner","value":"空气净化器"},{"name":"insType","normValue":"set","value":"调小"},{"name":"attr","normValue":"humidity","value":"湿度"},{"name":"attrValue","normValue":"1","value":"1"}],"sample":"请调小空气净化器的湿度到1"}

通过星火3.5api清洗原来的数据,总结后按照刚才看到得单行jsonl存储格式将数据存入traindata.jsonl中。

import json

# 打开并读取JSON文件
with open('train.json', 'r', encoding='utf-8') as file:
    data = json.load(file)
# 训练集制作

# 打开一个文件用于写入,如果文件已存在则会被覆盖
with open('traindata.jsonl', 'w', encoding='utf-8') as file:
    # 训练集行数(130)不符合要求,范围:1500~90000000
    # 遍历数据列表,并将每一行写入文件
    # 这里为了满足微调需求我们重复12次数据集 130*12=1560
    
    for line_data in tqdm(data):
        line_input = line_data["chat_text"] 
        line_output = line_data["infos"]
        content = line_input
        
        prompt = f'''
                你是一个数据分析大师,你需要从群聊对话中进行分析,里面对话的角色中大部分是客服角色,你需要从中区分出有需求的客户,并得到以下四类数据。

                ****群聊对话****
                {content}

                ****分析数据****
                客户基本信息:需要从中区分出客户角色,并得到客户基本信息,其中包括姓名、手机号码、邮箱、地区、详细地址、性别、年龄和生日
                客户意向与预算信息: 客户意向与预算信息包括咨询类型、意向产品、购买异议点、预算是否充足、总体预算金额以及预算明细
                客户购买准备情况:户购买准备情况包括竞品信息、客户是否有意向、客户是否有卡点以及客户购买阶段
                跟进计划信息: 跟进计划信息包括参与人、时间点和具体事项,这些信息用于指导销售团队在未来的跟进工作中与客户互动

                ****注意****
                1.只输出客户基本信息、客户意向与预算信息、客户购买准备情况、跟进计划信息对应的信息,不要输出无关内容
                2.不要输出分析内容
                3.输出内容格式为md格式
                '''
        res = chatbot(prompt=prompt)
        # print(res)
        line_write = {
            "instruction":jsonl_data["instruction"],
            "input":json.dumps(res, ensure_ascii=False),
            "output":json.dumps(line_output, ensure_ascii=False)
        }
        # 因为数据共有130行,为了能满足训练需要的1500条及以上,我们将正常训练数据扩充12倍。
        for time in range(12):
            file.write(json.dumps(line_write, ensure_ascii=False) + '\n')  # '\n' 用于在每行末尾添加换行符

这段代码用于生成自定义训练数据集的Python脚本片段,目的是为了训练一个能够从群聊对话中提取特定信息的AI模型。

  • 目标定义:创建一个训练数据集文件traindata.jsonl,用于微调模型,使其能分析群聊记录并提取关于客户的基本信息、意向预算、购买准备情况及跟进计划等四类关键数据。
  • 数据准备
    • 使用with open...语句以写入模式打开traindata.jsonl的文件,设置编码为utf-8以支持中文等多字节字符。
  • 数据增强
    • 通过tqdm库来显示进度条,遍历data列表中的每一条记录。
    • 对于每条记录,构造一个复杂的prompt(提示),它是一个Markdown格式的字符串,包含了固定的指令说明和动态插入的群聊对话内容。
    • 调用chatbot(prompt=prompt)的函数,根据构造的prompt生成分析结果。生成的响应res应包含所需的信息提取结果。
    • 将原始数据、生成的响应以及期望的输出整合成一个新的字典line_write,结构遵循特定的格式,便于模型训练使用。
  • 数据扩充
    • 由于原始数据集大小(130行)不足以满足训练需求(至少1500行),通过一个循环(for time in range(12)),将每条处理后的数据重复写入文件12次,从而达到扩容目的。
  • 文件写入
    • 使用json.dumps()将字典转换为JSON格式的字符串,并通过file.write()将这个字符串写入文件,每条数据后添加换行符\n以便于区分不同的记录。

构建一个针对特定任务(从群聊中提取客户信息)的训练数据集,通过人工构造提示和利用模型生成响应,结合数据重复策略以满足训练集的最小规模要求。

1.4 测试集数据制作

测试数据和训练数据相似,都是通过api清洗后存储。

# 验证集制作(提交版本)
# input,target

import json

# 打开并读取JSON文件
with open('test_data.json', 'r', encoding='utf-8') as file:
    data_test = json.load(file)

这里的验证数据我们以csv文件存储,有input和target两列

import csv

# 打开一个文件用于写入CSV数据
with open('test.csv', 'w', newline='', encoding='utf-8') as csvfile:
    # 创建一个csv writer对象
    csvwriter = csv.writer(csvfile)
    csvwriter.writerow(["input","target"])
    # 遍历数据列表,并将每一行写入CSV文件
    for line_data in tqdm(data_test):
        content = line_data["chat_text"]
        prompt = f'''
                你是一个数据分析大师,你需要从群聊对话中进行分析,里面对话的角色中大部分是客服角色,你需要从中区分出有需求的客户,并得到以下四类数据。

                ****群聊对话****
                {content}

                ****分析数据****
                客户基本信息:需要从中区分出客户角色,并得到客户基本信息,其中包括姓名、手机号码、邮箱、地区、详细地址、性别、年龄和生日
                客户意向与预算信息: 客户意向与预算信息包括咨询类型、意向产品、购买异议点、预算是否充足、总体预算金额以及预算明细
                客户购买准备情况:户购买准备情况包括竞品信息、客户是否有意向、客户是否有卡点以及客户购买阶段
                跟进计划信息: 跟进计划信息包括参与人、时间点和具体事项,这些信息用于指导销售团队在未来的跟进工作中与客户互动

                ****注意****
                1.只输出客户基本信息、客户意向与预算信息、客户购买准备情况、跟进计划信息对应的信息,不要输出无关内容
                2.不要输出分析内容
                3.输出内容格式为md格式
                '''
        res = chatbot(prompt=prompt)
        
        # print(line_data["chat_text"])
        ## 文件内容校验失败: test.jsonl(不含表头起算)第1行的内容不符合规则,限制每组input和target字符数量总和上限为8000,当前行字符数量:10721
        line_list = [res, "-"]   
        csvwriter.writerow(line_list)
        # break
  1. 初始化CSV文件:使用open函数以写入模式('w')打开(或新建)文件test.csv,同时设置newline=''以避免在不同操作系统间产生额外的空行问题,并指定编码为utf-8确保文本兼容性。接着,创建一个csv.writer对象来写入CSV数据。
  2. 写入表头:通过csvwriter.writerow(["input","target"])写入CSV的第一行作为列标题,分别代表“输入”和“目标”两列。
  3. 处理和写入数据
  • 遍历数据列表data_test,每条数据包含键chat_text,代表一段群聊对话内容。
  • 对于每条数据,构造一个详细的分析提示prompt,内容格式与之前类似,要求从群聊中分析提取特定客户信息。
  • 调用chatbot(prompt=prompt)函数获取分析结果res
  • 将结果和一个占位符"-"(代表目标列数据)作为一个列表line_list,通过csvwriter.writerow(line_list)写入CSV文件的一行。

2. 模型微调

在微调平台中进行基本配置和数据配置,选择Spark Pro模型。训练完成后,发布服务。拿到resourceId、APPID、APIKey、APISecret

3. 微调推理

# 定义写入函数

def write_json(json_file_path, data):
    #"""写入json文件"""
    with open(json_file_path, 'w') as f:
        json.dump(data, f, ensure_ascii=False, indent=4)

在main.ipynb的微调推理部分填入APPID、APIKey、APISecret

在SparkApi.py文件的108行,引号中填入resourceId

import SparkApi
import json
#以下密钥信息从控制台获取
appid = ""     #填写控制台中获取的 APPID 信息
api_secret = ""   #填写控制台中获取的 APISecret 信息
api_key =""    #填写控制台中获取的 APIKey 信息

#调用微调大模型时,设置为“patch”
domain = "patchv3"

#云端环境的服务地址
# Spark_url = "wss://spark-api-n.xf-yun.com/v1.1/chat"  # 微调v1.5环境的地址
Spark_url = "wss://spark-api-n.xf-yun.com/v3.1/chat"  # 微调v3.0环境的地址


text =[]

# length = 0

def getText(role,content):
    jsoncon = {}
    jsoncon["role"] = role
    jsoncon["content"] = content
    text.append(jsoncon)
    return text

def getlength(text):
    length = 0
    for content in text:
        temp = content["content"]
        leng = len(temp)
        length += leng
    return length

def checklen(text):
    while (getlength(text) > 8000):
        del text[0]
    return text
    


def core_run(text,prompt):
    # print('prompt',prompt)
    text.clear
    Input = prompt
    question = checklen(getText("user",Input))
    SparkApi.answer =""
    # print("星火:",end = "")
    SparkApi.main(appid,api_key,api_secret,Spark_url,domain,question)
    getText("assistant",SparkApi.answer)
    # print(text)
    return text[-1]['content']

text = []
res = core_run(text,'你好吗?')

使用WebSocket连接到星火(Spark)平台API,旨在与一个经过微调的大规模语言模型进行交互。

导入与初始化

  • 引入了自定义的SparkApi模块,该模块应包含与星火API交互的方法,如建立WebSocket连接、发送消息、接收响应等。
  • 设置了三个必要的认证参数:appidapi_secretapi_key,这些值需从星火平台控制台获取。
  • 定义了domain变量,其值为"patch",意味着将要使用微调模型的特定版本(patchv3)。
  • 指定了WebSocket服务地址,这里使用的是针对微调v3.0环境的地址。

函数定义

  • getText(role, content):此函数接收角色(如"user"或"assistant")和内容,构建一个包含角色与消息内容的字典,并将其添加到全局列表text中,最后返回整个列表。
  • getlength(text):计算text列表中所有消息内容的总长度。
  • checklen(text):若消息内容总长度超过8000字符,则从列表开头开始删除消息,直至总长度不超过限制。这个过程确保了发送给API的消息不会因为过长而遇到问题。
  • core_run(text, prompt):这是核心功能函数,它首先清空text列表,然后构造一个新的消息(用户提问),检查并确保其长度合规,通过SparkApi.main方法发起请求到星火API,获取回答,并最终返回模型的回复内容。

主流程

  • 初始化全局变量text为空列表。
  • 调用core_run(text, '你好吗?')函数,向模型发送一条简单的问候消息"你好吗?",并打印出模型的回复内容。
import pandas as pd
import re

# 读取Excel文件
df_test = pd.read_csv('test.csv',)
data_dict_empty = {
                "基本信息-姓名": "",
                "基本信息-手机号码": "",
                "基本信息-邮箱": "",
                "基本信息-地区": "",
                "基本信息-详细地址": "",
                "基本信息-性别": "",
                "基本信息-年龄": "",
                "基本信息-生日": "",
                "咨询类型": [],
                "意向产品": [],
                "购买异议点": [],
                "客户预算-预算是否充足": "",
                "客户预算-总体预算金额": "",
                "客户预算-预算明细": "",
                "竞品信息": "",
                "客户是否有意向": "",
                "客户是否有卡点": "",
                "客户购买阶段": "",
                "下一步跟进计划-参与人": [],
                "下一步跟进计划-时间点": "",
                "下一步跟进计划-具体事项": ""
            }
submit_data = []
for id,line_data in tqdm(enumerate(df_test['input'])):
    # print(line_data)
    content = line_data
    text = []
    prompt = json.dumps(content,ensure_ascii=False)
    
    # print(json.dumps(content,ensure_ascii=False))
    res = core_run(text,prompt)
    try:
        data_dict = json.loads(res)
    except json.JSONDecodeError as e:
        data_dict = data_dict_empty
    submit_data.append({"infos":data_dict,"index":id+1})
# 预计执行8min

处理从test.csv文件中读取的输入数据,通过调用core_run函数与AI模型交互,解析模型返回的结果,并将解析后的数据结构化,最后将这些数据收集到submit_data列表中。

  • 初始化:定义一个空列表submit_data,用于存储处理后的每条数据记录。
  • 遍历数据:使用enumerate(df_test['input'])遍历CSV文件中input列的所有数据项,其中id是当前行的索引(从0开始),line_data是该行的输入文本。
  • 准备与调用AI模型
    • line_data的内容转换为JSON字符串(并确保非ASCII字符被正确处理),作为与AI模型交互的输入prompt
    • 调用core_run(text, prompt)函数,其中text被初始化为空列表,这个调用会触发与AI模型的交互,并返回模型的响应字符串res
  • 解析响应
    • 尝试将模型返回的响应res解析为JSON格式,存储在data_dict中。
    • 如果解析过程中出现json.JSONDecodeError异常(即模型返回的不是有效的JSON格式数据),则使用之前定义的空模板字典data_dict_empty作为默认值,以确保数据结构的一致性.
  • 构建提交数据:对于每条处理过的数据,构建一个新的字典,包含解析后的信息字段(data_dict)和原始数据的索引(index),并将这个字典追加到submit_data列表中。
submit_data
write_json("submit.json",submit_data)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值