🐲赛题地址
🐲引入模块
from sparkai.llm.llm import ChatSparkLLM, ChunkPrintHandler
from sparkai.core.messages import ChatMessage
import json
from tqdm import tqdm
引入必要的模块,其中 ChatSparkLLM
和 ChunkPrintHandler
用于与星火认知大模型进行交互,ChatMessage
用于构建消息,json
模块用于处理JSON数据,tqdm
用于显示进度条。
🐲配置星火认知大模型
定义星火认知大模型的URL、App ID、APIKey和APISecret。配置信息在:控制台-讯飞开放平台
SPARKAI_URL = 'wss://spark-api.xf-yun.com/v3.5/chat'
SPARKAI_APP_ID = ''
SPARKAI_API_SECRET = ''
SPARKAI_API_KEY = ''
SPARKAI_DOMAIN = 'generalv3.5'
🐲提取信息的提示模板
PROMPT_EXTRACT = """
你将获得一段群聊对话记录。你的任务是根据给定的表单格式从对话记录中提取结构化信息。在提取信息时,请确保它与类型信息完全匹配,不要添加任何没有出现在下面模式中的属性。
...(省略部分内容)
"""
🐲读取和写入JSON文件的函数
def read_json(json_file_path):
with open(json_file_path, 'r') as f:
data = json.load(f)
return data
def write_json(json_file_path, data):
with open(json_file_path, 'w') as f:
json.dump(data, f, ensure_ascii=False, indent=4)
🐲获取大模型的输出
def get_completions(text):
messages = [ChatMessage(
role="user",
content=text
)]
spark = ChatSparkLLM(
spark_api_url=SPARKAI_URL,
spark_app_id=SPARKAI_APP_ID,
spark_api_key=SPARKAI_API_KEY,
spark_api_secret=SPARKAI_API_SECRET,
spark_llm_domain=SPARKAI_DOMAIN,
streaming=False,
)
handler = ChunkPrintHandler()
a = spark.generate([messages], callbacks=[handler])
return a.generations[0][0].text
定义一个函数,用于向星火认知大模型发送消息并获取回复。api使用教程在:星火认知大模型Web API文档 | 讯飞开放平台文档中心
🐲提取JSON字符串并转换为字典
def convert_all_json_in_text_to_dict(text):
dicts, stack = [], []
for i in range(len(text)):
if text[i] == '{':
stack.append(i)
elif text[i] == '}':
begin = stack.pop()
if not stack:
dicts.append(json.loads(text[begin:i+1]))
return dicts
定义一个函数,用于从大模型输出的字符串中提取json。
🐲检查和补全JSON格式
class JsonFormatError(Exception):
def __init__(self, message):
self.message = message
super().__init__(self.message)
def check_and_complete_json_format(data):
required_keys = {
"基本信息-姓名": str,
"基本信息-手机号码": str,
"基本信息-邮箱": str,
"基本信息-地区": str,
"基本信息-详细地址": str,
"基本信息-性别": str,
"基本信息-年龄": str,
"基本信息-生日": str,
"咨询类型": list,
"意向产品": list,
"购买异议点": list,
"客户预算-预算是否充足": str,
"客户预算-总体预算金额": str,
"客户预算-预算明细": str,
"竞品信息": str,
"客户是否有意向": str,
"客户是否有卡点": str,
"客户购买阶段": str,
"下一步跟进计划-参与人": list,
"下一步跟进计划-时间点": str,
"下一步跟进计划-具体事项": str
}
if not isinstance(data, list):
raise JsonFormatError("Data is not a list")
for item in data:
if not isinstance(item, dict):
raise JsonFormatError("Item is not a dictionary")
for key, value_type in required_keys.items():
if key not in item:
item[key] = [] if value_type == list else ""
if not isinstance(item[key], value_type):
raise JsonFormatError(f"Key '{key}' is not of type {value_type.__name__}")
if value_type == list and not all(isinstance(i, str) for i in item[key]):
raise JsonFormatError(f"Key '{key}' does not contain all strings in the list")
定义一个函数和异常类,用于检查和补全JSON格式,确保所有必要字段都存在且类型正确。
🐲主程序
if __name__ == "__main__":
retry_count = 5 # 重试次数
result = []
error_data = []
# 读取数据
train_data = read_json("dataset/train.json")
test_data = read_json("dataset/test_data.json")
for index, data in tqdm(enumerate(test_data)):
index += 1
is_success = False
for i in range(retry_count):
try:
res = get_completions(PROMPT_EXTRACT.format(content=data["chat_text"]))
infos = convert_all_json_in_text_to_dict(res)
infos = check_and_complete_json_format(infos)
result.append({
"infos": infos,
"index": index
})
is_success = True
break
except Exception as e:
print("index:", index, ", error:", e)
continue
if not is_success:
data["index"] = index
error_data.append(data)
write_json("output.json", result)
主程序部分首先读取训练和测试数据,然后遍历测试数据并调用大模型获取提取的信息,检查和补全JSON格式,最终将结果写入输出文件中。
#ai夏令营#datawhale#夏令营#ai