基于星火大模型的群聊对话分角色要素提取挑战-baseline步骤及精读

#AI夏令营 #Datawhale #夏令营

Step1:下载相关库(大概 10s)

安装环境

!pip install --upgrade -q spark_ai_python

Step2:配置导入

from sparkai.llm.llm import ChatSparkLLM, ChunkPrintHandler
from sparkai.core.messages import ChatMessage
import json


#星火认知大模型Spark3.5 Max的URL值,其他版本大模型URL值请前往文档(https://www.xfyun.cn/doc/spark/Web.html)查看
SPARKAI_URL = 'wss://spark-api.xf-yun.com/v3.5/chat'
#星火认知大模型调用秘钥信息,请前往讯飞开放平台控制台(https://console.xfyun.cn/services/bm35)查看
SPARKAI_APP_ID = ''
SPARKAI_API_SECRET = ''
SPARKAI_API_KEY = ''
#星火认知大模型Spark3.5 Max的domain值,其他版本大模型domain值请前往文档(https://www.xfyun.cn/doc/spark/Web.html)查看
SPARKAI_DOMAIN = 'generalv3.5'

添加个人配置,用于初始化和配置星火认知大模型(Spark 3.5 Max)的客户端设置

导入必要的模块

  • ChatSparkLLMChunkPrintHandler 来自 sparkai.llm.llm 模块,它们分别用于创建与星火模型交互的聊天LLM客户端和处理模型响应的分块打印处理器。

  • ChatMessage 用来构造发送给模型的消息对象。

星火认知大模型的配置参数:

  • SPARKAI_URL: 指定了连接到星火模型服务的WebSocket地址。这里使用的是针对Spark 3.5 Max版本的URL。

  • SPARKAI_APP_ID, SPARKAI_API_SECRET, SPARKAI_API_KEY: 这些是访问星火模型服务所需的认证信息,包括应用ID、API密钥和API密钥。在实际使用中,这些值需要从讯飞开放平台的控制台获取并填入。

  • SPARKAI_DOMAIN: 指定使用的星火模型领域或版本,这里是generalv3.5,适用于通用场景的3.5版本模型。

Step3:模型测试

def get_completions(text): 
    messages = [ChatMessage(
        role="user",
        content=text
    )]
    spark = ChatSparkLLM(
        spark_api_url=SPARKAI_URL,
        spark_app_id=SPARKAI_APP_ID,
        spark_api_key=SPARKAI_API_KEY,
        spark_api_secret=SPARKAI_API_SECRET,
        spark_llm_domain=SPARKAI_DOMAIN,
        streaming=False,
    )
    handler = ChunkPrintHandler()
    a = spark.generate([messages], callbacks=[handler])
    return a.generations[0][0].text

# 测试模型配置是否正确
text = "你好"
get_completions(text)

设置文本,与大模型交互,大模型给出回答,以“你好”作为测试

get_completions 的作用是使用星火认知大模型(Spark)生成文本完成或回应。

  1. 函数定义def get_completions(text): 定义了一个接收单个参数 text 的函数,该参数预期是用户想要模型生成回答或完成的文本内容。

  2. 构造消息messages = [ChatMessage(role="user", content=text)] 创建了一个消息列表,其中包含一个元素,这个元素是 ChatMessage 对象,指明了角色为 "user" 并且内容为传入的 text

  3. 初始化 ChatSparkLLM 实例:这一部分实例化了 ChatSparkLLM 类,传入了一系列参数用于配置模型的访问方式,包括:

    • spark_api_url: 模型服务的WebSocket URL。

    • spark_app_id, spark_api_key, spark_api_secret: 分别是应用ID、API密钥和密钥,这些都是从讯飞开放平台获取的认证信息。

    • spark_llm_domain: 指定使用的模型领域版本,这里是 generalv3.5

    • streaming=False:表示是否采用流式传输模式接收模型回复,默认为False,即一次性返回完整回复

  4. 创建回调处理器handler = ChunkPrintHandler() 创建了一个 ChunkPrintHandler 对象,用于处理模型可能的分块输出,虽然在这里 streaming=False,所以该处理器实际上可能不会被触发。

  5. 调用模型生成回答a = spark.generate([messages], callbacks=[handler]) 使用之前设置的信息和用户输入的问题调用模型的 generate 方法生成回答。这里通过 callbacks 参数指定了一个处理器列表,用于处理模型输出过程中的事件,即使在非流式模式下,该机制也可能用于内部处理。

  6. 提取并返回回答return a.generations[0][0].text 从模型的响应中提取第一代(generations[0])的第一个回复([0]),然后返回其文本内容。这里的索引假设每次调用只会生成一个回复,且该回复直接位于列表的第一个位置。

  7. 测试模型配置:最后,通过调用 get_completions 函数并传入简单的测试文本 "你好" 来验证整个流程是否配置正确,并查看模型的回复。

Step4:数据读取

def read_json(json_file_path):
    """读取json文件"""
    with open(json_file_path, 'r') as f:
        data = json.load(f)
    return data

def write_json(json_file_path, data):
    """写入json文件"""
    with open(json_file_path, 'w') as f:
        json.dump(data, f, ensure_ascii=False, indent=4)

# 读取数据
train_data = read_json("dataset/train.json")
test_data = read_json("dataset/test_data.json")

查看训练集数据对应的标签,大部分是空值

read_jsonwrite_json处理JSON文件。

  1. read_json(json_file_path):

    • 功能: 这个函数用于读取一个给定路径的JSON文件。

    • 参数:

      • json_file_path (str): 需要读取的JSON文件的路径。

    • 过程:

      • 使用 with open(json_file_path, 'r') as f: 打开指定路径的文件,这里的 'r' 表示以读取模式打开文件。

      • 使用 json.load(f) 将文件内容读取为Python数据结构(通常是字典或列表)。

      • 最后,关闭文件并返回读取到的数据。

  2. write_json(json_file_path, data):

    • 功能: 这个函数用于将给定的Python数据结构写入到一个JSON文件中。

    • 参数:

      • json_file_path (str): 要写入的JSON文件的路径。

      • data (any): 要写入的数据,可以是字典、列表等能够在JSON中表示的数据结构。

    • 过程:

      • 使用 with open(json_file_path, 'w') as f: 以写入模式打开指定路径的文件,如果文件已存在则会被覆盖,不存在则创建。

      • 使用 json.dump(data, f, ensure_ascii=False, indent=4) 将数据写入文件。其中,ensure_ascii=False 允许写入非ASCII字符(如中文),indent=4 则使得输出的JSON文件具有良好的可读性,每个层级缩进4个空格。

这段代码非常实用,常用于处理机器学习、数据分析等任务中,需要从JSON文件加载数据或将处理后的数据保存回JSON文件的场景。

Step5:Prompt设计

# prompt 设计
PROMPT_EXTRACT = """
你将获得一段群聊对话记录。你的任务是根据给定的表单格式从对话记录中提取结构化信息。在提取信息时,请确保它与类型信息完全匹配,不要添加任何没有出现在下面模式中的属性。

表单格式如下:
info: Array<Dict(
    "基本信息-姓名": string | "",  // 客户的姓名。
    "基本信息-手机号码": string | "",  // 客户的手机号码。
    "基本信息-邮箱": string | "",  // 客户的电子邮箱地址。
    "基本信息-地区": string | "",  // 客户所在的地区或城市。
    "基本信息-详细地址": string | "",  // 客户的详细地址。
    "基本信息-性别": string | "",  // 客户的性别。
    "基本信息-年龄": string | "",  // 客户的年龄。
    "基本信息-生日": string | "",  // 客户的生日。
    "咨询类型": string[] | [],  // 客户的咨询类型,如询价、答疑等。
    "意向产品": string[] | [],  // 客户感兴趣的产品。
    "购买异议点": string[] | [],  // 客户在购买过程中提出的异议或问题。
    "客户预算-预算是否充足": string | "",  // 客户的预算是否充足。示例:充足, 不充足
    "客户预算-总体预算金额": string | "",  // 客户的总体预算金额。
    "客户预算-预算明细": string | "",  // 客户预算的具体明细。
    "竞品信息": string | "",  // 竞争对手的信息。
    "客户是否有意向": string | "",  // 客户是否有购买意向。示例:有意向, 无意向
    "客户是否有卡点": string | "",  // 客户在购买过程中是否遇到阻碍或卡点。示例:有卡点, 无卡点
    "客户购买阶段": string | "",  // 客户当前的购买阶段,如合同中、方案交流等。
    "下一步跟进计划-参与人": string[] | [],  // 下一步跟进计划中涉及的人员(客服人员)。
    "下一步跟进计划-时间点": string | "",  // 下一步跟进的时间点。
    "下一步跟进计划-具体事项": string | ""  // 下一步需要进行的具体事项。
)>

请分析以下群聊对话记录,并根据上述格式提取信息:

**对话记录:**
```
{content}
```

请将提取的信息以JSON格式输出。
不要添加任何澄清信息。
输出必须遵循上面的模式。
不要添加任何没有出现在模式中的附加字段。
不要随意删除字段。

**输出:**
```
[{{
    "基本信息-姓名": "姓名",
    "基本信息-手机号码": "手机号码",
    "基本信息-邮箱": "邮箱",
    "基本信息-地区": "地区",
    "基本信息-详细地址": "详细地址",
    "基本信息-性别": "性别",
    "基本信息-年龄": "年龄",
    "基本信息-生日": "生日",
    "咨询类型": ["咨询类型"],
    "意向产品": ["意向产品"],
    "购买异议点": ["购买异议点"],
    "客户预算-预算是否充足": "充足或不充足",
    "客户预算-总体预算金额": "总体预算金额",
    "客户预算-预算明细": "预算明细",
    "竞品信息": "竞品信息",
    "客户是否有意向": "有意向或无意向",
    "客户是否有卡点": "有卡点或无卡点",
    "客户购买阶段": "购买阶段",
    "下一步跟进计划-参与人": ["跟进计划参与人"],
    "下一步跟进计划-时间点": "跟进计划时间点",
    "下一步跟进计划-具体事项": "跟进计划具体事项"
}}, ...]
```
"""
  1. 表单格式定义:首先,定义了一个详细的JSON结构模板,其中包含了多个类别,如“基本信息”、“咨询类型”、“意向产品”等,每个类别下有具体的属性项,如“姓名”、“手机号码”、“意向产品”等。每个属性后都有类型定义,如string(字符串)或数组类型Array<string>,以及可能的默认值(如""表示空字符串)。

  2. 提取指令:明确指示模型需要精确匹配表单格式中的属性,不得擅自增加、减少或修改属性。强调输出严格遵循规范,确保信息的准确性和一致性。

  3. 对话记录占位符{content}作为对话记录的占位符,意味着实际运行时需要将具体的群聊对话内容填充到这个位置。

  4. 输出示例:提供了一个输出示例,展示了如何将提取的信息按照规定的JSON格式返回。每个信息项都以示例的形式给出,如"姓名": "姓名",实际应用中“姓名”字段应被真实姓名所替换。

Step6:主函数启动

import json

class JsonFormatError(Exception):
    def __init__(self, message):
        self.message = message
        super().__init__(self.message)

def convert_all_json_in_text_to_dict(text):
    """提取LLM输出文本中的json字符串"""
    dicts, stack = [], []
    for i in range(len(text)):
        if text[i] == '{':
            stack.append(i)
        elif text[i] == '}':
            begin = stack.pop()
            if not stack:
                dicts.append(json.loads(text[begin:i+1]))
    return dicts

# 查看对话标签
def print_json_format(data):
    """格式化输出json格式"""
    print(json.dumps(data, indent=4, ensure_ascii=False))



def check_and_complete_json_format(data):
    required_keys = {
        "基本信息-姓名": str,
        "基本信息-手机号码": str,
        "基本信息-邮箱": str,
        "基本信息-地区": str,
        "基本信息-详细地址": str,
        "基本信息-性别": str,
        "基本信息-年龄": str,
        "基本信息-生日": str,
        "咨询类型": list,
        "意向产品": list,
        "购买异议点": list,
        "客户预算-预算是否充足": str,
        "客户预算-总体预算金额": str,
        "客户预算-预算明细": str,
        "竞品信息": str,
        "客户是否有意向": str,
        "客户是否有卡点": str,
        "客户购买阶段": str,
        "下一步跟进计划-参与人": list,
        "下一步跟进计划-时间点": str,
        "下一步跟进计划-具体事项": str
    }

    if not isinstance(data, list):
        raise JsonFormatError("Data is not a list")

    for item in data:
        if not isinstance(item, dict):
            raise JsonFormatError("Item is not a dictionary")
        for key, value_type in required_keys.items():
            if key not in item:
                item[key] = [] if value_type == list else ""
            if not isinstance(item[key], value_type):
                raise JsonFormatError(f"Key '{key}' is not of type {value_type.__name__}")
            if value_type == list and not all(isinstance(i, str) for i in item[key]):
                raise JsonFormatError(f"Key '{key}' does not contain all strings in the list")

    return data

用于处理和验证从自然语言处理模型(如星火大模型)得到的、包含JSON片段的文本数据,确保数据符合预期的结构化格式,并能正确处理和展示这些数据。

  1. 定义异常类 JsonFormatError:当处理的JSON格式不符合预期时,抛出此异常,便于异常处理

  2. 函数 convert_all_json_in_text_to_dict(text):从文本中提取所有的JSON字符串并转换为Python字典列表。遍历文本字符,遇到左花括号 { 时,将其索引压入栈;遇到右花括号 } 时,弹出栈顶索引并检查栈是否为空,若为空则说明一个完整的JSON对象结束,使用 json.loads() 将其转换为字典并加入到结果列表中。

  3. 函数 print_json_format(data):以易读的格式打印出给定的JSON数据。使用 json.dumps() 方法,设定缩进为4,且允许中文字符不被转义输出。

  4. 函数 check_and_complete_json_format(data)

    • 首先,定义了预期的键值对结构,包括键名和对应的预期数据类型。

    • 函数接收一个数据列表作为输入,然后对每个元素进行检查和处理:

      • 确保输入数据是一个列表,否则抛出异常。

      • 遍历列表中的每个字典,检查并确保每个字典包含所有必需的键,对于不存在的键,根据其预期类型赋予默认值(空字符串或空列表)。

      • 验证每个键值对的实际类型是否与预期类型相符,如果类型不匹配或列表中的元素不是字符串,则抛出异常。

    • 最后,返回经过检查和补全后的数据列表。

主要用于处理和校验从文本中提取的结构化信息,确保它们满足特定的格式要求,以便进一步的程序处理或分析。例如,从星火大模型的输出中提取客户信息时,可以确保所有必要的字段都被正确提取且格式合规,从而提高数据处理的可靠性和效率。

from tqdm import tqdm

retry_count = 5 # 重试次数
result = []
error_data = []

for index, data in tqdm(enumerate(test_data)):
    index += 1
    is_success = False
    for i in range(retry_count):
        try:
            res = get_completions(PROMPT_EXTRACT.format(content=data["chat_text"]))
            infos = convert_all_json_in_text_to_dict(res)
            infos = check_and_complete_json_format(infos)
            result.append({
                "infos": infos,
                "index": index
            })
            is_success = True
            break
        except Exception as e:
            print("index:", index, ", error:", e)
            continue
    if not is_success:
        data["index"] = index
        error_data.append(data)

主要目的是从一组测试数据(test_data)中,针对每一项数据尝试调用函数get_completions来获取基于特定模板(PROMPT_EXTRACT)生成的完成信息,之后对这些信息进行处理并收集结果。如果过程中发生错误,会进行重试,并记录未能成功处理的数据。具体解析如下:

  1. 导入tqdm库:用于显示进度条,提高用户体验,让用户直观看到处理进度。

  2. 定义重试次数retry_count = 5表示对于每一个数据项,如果处理失败,最多尝试重新处理5次。

  3. 初始化结果列表和错误数据列表result用于存储处理成功的数据结果,error_data用于存储处理失败的数据,以便后续处理或分析。

  4. 循环处理测试数据:使用enumerate(test_data)遍历test_data,同时获取索引index和对应的数据dataindex += 1是因为原索引是从0开始,但示例中可能是希望显示用户友好的从1开始的索引。

  5. 尝试调用和处理数据:对于每一条数据,使用一个内部循环尝试最多retry_count次调用get_completions函数。该函数接收一个格式化后的模板字符串作为参数,模板字符串由PROMPT_EXTRACT.format(content=data["chat_text"])生成,其中data["chat_text"]是当前数据项中的聊天文本内容。

  6. 处理获取的结果:如果get_completions调用成功,其返回的结果将通过convert_all_json_in_text_to_dict函数转换成字典列表,进一步通过check_and_complete_json_format函数进行格式检查和补充缺失信息,最终将处理后的信息和索引一起存入result列表中,并标记此次处理成功(is_success = True),跳出循环。

  7. 异常处理:如果在尝试过程中遇到任何异常(例如网络错误、数据格式错误等),会捕获异常并打印出错的索引和错误信息,然后继续尝试直到达到最大重试次数。

  8. 记录失败数据:如果所有重试均失败,则将当前的data项的索引更新并添加到error_data列表中,用于记录哪些数据处理未成功。

这是一种常见的数据处理模式,它通过重试机制增强了数据处理的健壮性,并且提供了清晰的进度显示和错误处理逻辑,确保数据处理的可靠性,同时也方便后续对未成功处理的数据进行复查或再次处理。

# 故障数据处理

if error_data:
    retry_count = 10 # 重试次数
    error_data_temp = []
    while True:
        if error_data_temp:
            error_data = error_data_temp
            error_data_temp = []
        for data in tqdm(error_data):
            is_success = False
            for i in range(retry_count):
                try:
                    res = get_completions(PROMPT_EXTRACT.format(content=data["chat_text"]))
                    infos = convert_all_json_in_text_to_dict(res)
                    infos = check_and_complete_json_format(infos)
                    result.append({
                        "infos": infos,
                        "index": data["index"]
                    })
                    is_success = True
                    break
                except Exception as e:
                    print("index:", index, ", error:", e)
                    continue
            if not is_success:
                error_data_temp.append(data)
        if not error_data_temp:
            break
    result = sorted(result, key=lambda x: x["index"])

对之前处理过程中未能成功的数据(error_data)进行额外重试和处理的部分。

  1. 重置重试次数:首先,将重试次数调整为10次,相较于初次处理的5次,这次给予更多机会以处理那些初次尝试中失败的数据。

  2. 初始化临时错误数据列表:创建error_data_temp用于暂存当前轮次中仍然处理失败的数据,避免直接修改原error_data导致循环逻辑混乱。

  3. 循环处理错误数据:进入一个无限循环,直到没有数据再需要重试(即error_data_temp为空)才会跳出循环。在循环内部:

    • 将当前error_data的内容复制给error_data_temp,并清空error_data,这样做的目的是为了在新一轮循环开始前保留上一轮未成功处理的数据,同时允许新一批失败数据临时存放于error_data中。

    • 使用tqdm遍历error_data_temp中的每个数据项,尝试再次调用get_completions函数处理数据,逻辑与之前相同:包括重试、异常捕获、成功处理的数据加入结果列表、以及失败数据暂时存回error_data_temp

    • 如果某条数据在这次循环中被成功处理,is_success会被设为True,从而不会再次进入error_data_temp,否则将继续保留在下次重试的队列中。

  4. 排序结果:当所有数据处理完毕,无更多错误数据需要重试时,使用sorted()函数根据索引值(index)对最终的result列表进行排序,确保结果按照原始数据的顺序排列。

设计了一个二次处理机制,专注于处理初次尝试中未能成功转化的数据,通过增加重试次数并采用循环逻辑,力求最大化数据处理的成功率,最后确保输出结果的顺序正确性。

Step7:生成提交文件

# 保存输出
write_json("output.json", result)

  • 27
    点赞
  • 18
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值