基于星火大模型的群聊对话分角色要素提取挑战赛

#AI夏令营 #Datawhale #夏令营
在这篇文章中,我们将逐步讲解如何使用星火认知大模型(Spark 3.5 Max)进行群聊对话记录的结构化信息提取。本文将涵盖从安装库、配置导入、模型测试、数据读取到最终生成提交文件的完整过程。

步骤1:安装依赖

首先,安装需要的Python库:

!pip install --upgrade -q spark_ai_python

请注意,如果在安装过程中遇到如下警告,可以忽略:

WARNING: Skipping page https://mirror.baidu.com/pypi/simple/tenacity/ because the GET request got Content-Type: application/octet-stream.
WARNING: Skipping page https://mirror.baidu.com/pypi/simple/pip/ because the GET request got Content-Type: application/octet-stream.

步骤2:配置导入

接下来,我们需要配置导入必要的包和设置相关参数。

from sparkai.llm.llm import ChatSparkLLM, ChunkPrintHandler
from sparkai.core.messages import ChatMessage
import json

SPARKAI_URL = 'wss://spark-api.xf-yun.com/v3.5/chat'
SPARKAI_APP_ID = ''
SPARKAI_API_SECRET = ''
SPARKAI_API_KEY = ''
SPARKAI_DOMAIN = 'generalv3.5'

步骤3:模型测试

编写一个函数来测试我们的模型配置是否正确。

def get_completions(text):
    messages = [ChatMessage(role="user", content=text)]
    spark = ChatSparkLLM(
        spark_api_url=SPARKAI_URL,
        spark_app_id=SPARKAI_APP_ID,
        spark_api_key=SPARKAI_API_KEY,
        spark_api_secret=SPARKAI_API_SECRET,
        spark_llm_domain=SPARKAI_DOMAIN,
        streaming=False,
    )
    handler = ChunkPrintHandler()
    a = spark.generate([messages], callbacks=[handler])
    return a.generations[0][0].text

# 测试模型配置是否正确
text = "你好"
print(get_completions(text))

步骤4:数据读取

定义读取和写入JSON文件的函数,并读取训练和测试数据。

def read_json(json_file_path):
    with open(json_file_path, 'r') as f:
        data = json.load(f)
    return data

def write_json(json_file_path, data):
    with open(json_file_path, 'w') as f:
        json.dump(data, f, ensure_ascii=False, indent=4)

# 读取数据
train_data = read_json("dataset/train.json")
test_data = read_json("dataset/test_data.json")

步骤5:Prompt设计

为模型设计Prompt,确保提取信息的准确性和格式一致性。

PROMPT_EXTRACT = """
你将获得一段群聊对话记录。你的任务是根据给定的表单格式从对话记录中提取结构化信息。

表单格式如下:
info: Array<Dict(
    "基本信息-姓名": string | "",
    "基本信息-手机号码": string | "",
    ...
    "下一步跟进计划-具体事项": string | ""
)>

请分析以下群聊对话记录,并根据上述格式提取信息:

**对话记录:**

请将提取的信息以JSON格式输出。

步骤6:主函数启动

编写主函数来处理对话记录并提取信息。

import json

class JsonFormatError(Exception):
    def __init__(self, message):
        self.message = message
        super().__init__(self.message)

def convert_all_json_in_text_to_dict(text):
    dicts, stack = [], []
    for i in range(len(text)):
        if text[i] == '{':
            stack.append(i)
        elif text[i] == '}':
            begin = stack.pop()
            if not stack:
                dicts.append(json.loads(text[begin:i+1]))
    return dicts

def print_json_format(data):
    print(json.dumps(data, indent=4, ensure_ascii=False))

def check_and_complete_json_format(data):
    required_keys = {
        "基本信息-姓名": str,
        "基本信息-手机号码": str,
        ...
        "下一步跟进计划-具体事项": str
    }

    if not isinstance(data, list):
        raise JsonFormatError("Data is not a list")

    for item in data:
        if not isinstance(item, dict):
            raise JsonFormatError("Item is not a dictionary")
        for key, value_type in required_keys.items():
            if key not in item:
                item[key] = [] if value_type == list else ""
            if not isinstance(item[key], value_type):
                raise JsonFormatError(f"Key '{key}' is not of type {value_type.__name__}")
            if value_type == list and not all(isinstance(i, str) for i in item[key]):
                raise JsonFormatError(f"Key '{key}' does not contain all strings in the list")

    return data

步骤7:生成提交文件

使用模型提取测试数据中的信息并生成提交文件。

from tqdm import tqdm

retry_count = 5
result = []
error_data = []

for index, data in tqdm(enumerate(test_data)):
    index += 1
    is_success = False
    for i in range(retry_count):
        try:
            res = get_completions(PROMPT_EXTRACT.format(content=data["chat_text"]))
            infos = convert_all_json_in_text_to_dict(res)
            infos = check_and_complete_json_format(infos)
            result.append({"infos": infos, "index": index})
            is_success = True
            break
        except Exception as e:
            print("index:", index, ", error:", e)
            continue
    if not is_success:
        data["index"] = index
        error_data.append(data)

步骤8:保存输出

将结果保存为JSON文件。

write_json("output.json", result)

下载输出文件

在文件区域下载生成的output.json文件,并回到比赛页面提交结果!

🎉 恭喜你!你已经成功运行了完整的竞赛代码并获得了自己的第一个分数!接下来,继续完成速通学习手册中的步骤,争取更好的成绩吧!

赛事链接:点击跳转

希望这篇教程对你有所帮助,如果有任何疑问,请在评论区留言,我们将竭诚为你解答!

  • 5
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值