Datawhale | AI+X AI 夏令营2024 Task2 笔记
#AI夏令营 #Datawhale #夏令营
30 分钟体验一站式 baseline!(点击即可跳转)
项目链接:https://aistudio.baidu.com/projectdetail/8095619
baseline1:
环境配置
!pip install --upgrade -q spark_ai_python tqdm
⚠️注意:spark_ai_python要求python版本一定要3.8以上
数据处理
import json
def read_json(json_file_path):
"""读取json文件"""
with open(json_file_path, 'r') as f:
data = json.load(f)
return data
def write_json(json_file_path, data):
"""写入json文件"""
with open(json_file_path, 'w') as f:
json.dump(data, f, ensure_ascii=False, indent=4)
# 读取数据
train_data = read_json("dataset/train.json")
test_data = read_json("dataset/test_data.json")
# 查看对话数据
print(train_data[100]['chat_text'])
样例:可以看到下面的对话数据很像我们的聊天截图,这样的聊天数据其实在日常生活中很常见。那么这个任务也很有必要,就是从这些杂乱的记录里抽取出我们需要的数据。
prompt工程
这里的prompt编写规则可以进行这样的理解:
任务目标——抽取数据定义——抽取内容引入——抽取规则强调
# prompt 设计
PROMPT_EXTRACT = """
你将获得一段群聊对话记录。你的任务是根据给定的表单格式从对话记录中提取结构化信息。在提取信息时,请确保它与类型信息完全匹配,不要添加任何没有出现在下面模式中的属性。
表单格式如下:
info: Array<Dict(
"基本信息-姓名": string | "", // 客户的姓名。
"基本信息-手机号码": string | "", // 客户的手机号码。
"基本信息-邮箱": string | "", // 客户的电子邮箱地址。
"基本信息-地区": string | "", // 客户所在的地区或城市。
"基本信息-详细地址": string | "", // 客户的详细地址。
"基本信息-性别": string | "", // 客户的性别。
"基本信息-年龄": string | "", // 客户的年龄。
"基本信息-生日": string | "", // 客户的生日。
"咨询类型": string[] | [], // 客户的咨询类型,如询价、答疑等。
"意向产品": string[] | [], // 客户感兴趣的产品。
"购买异议点": string[] | [], // 客户在购买过程中提出的异议或问题。
"客户预算-预算是否充足": string | "", // 客户的预算是否充足。示例:充足, 不充足
"客户预算-总体预算金额": string | "", // 客户的总体预算金额。
"客户预算-预算明细": string | "", // 客户预算的具体明细。
"竞品信息": string | "", // 竞争对手的信息。
"客户是否有意向": string | "", // 客户是否有购买意向。示例:有意向, 无意向
"客户是否有卡点": string | "", // 客户在购买过程中是否遇到阻碍或卡点。示例:有卡点, 无卡点
"客户购买阶段": string | "", // 客户当前的购买阶段,如合同中、方案交流等。
"下一步跟进计划-参与人": string[] | [], // 下一步跟进计划中涉及的人员(客服人员)。
"下一步跟进计划-时间点": string | "", // 下一步跟进的时间点。
"下一步跟进计划-具体事项": string | "" // 下一步需要进行的具体事项。
)>
请分析以下群聊对话记录,并根据上述格式提取信息:
**对话记录:**
```
{content}
```
请将提取的信息以JSON格式输出。
不要添加任何澄清信息。
输出必须遵循上面的模式。
不要添加任何没有出现在模式中的附加字段。
不要随意删除字段。
**输出:**
```
[{{
"基本信息-姓名": "姓名",
"基本信息-手机号码": "手机号码",
"基本信息-邮箱": "邮箱",
"基本信息-地区": "地区",
"基本信息-详细地址": "详细地址",
"基本信息-性别": "性别",
"基本信息-年龄": "年龄",
"基本信息-生日": "生日",
"咨询类型": ["咨询类型"],
"意向产品": ["意向产品"],
"购买异议点": ["购买异议点"],
"客户预算-预算是否充足": "充足或不充足",
"客户预算-总体预算金额": "总体预算金额",
"客户预算-预算明细": "预算明细",
"竞品信息": "竞品信息",
"客户是否有意向": "有意向或无意向",
"客户是否有卡点": "有卡点或无卡点",
"客户购买阶段": "购买阶段",
"下一步跟进计划-参与人": ["跟进计划参与人"],
"下一步跟进计划-时间点": "跟进计划时间点",
"下一步跟进计划-具体事项": "跟进计划具体事项"
}}, ...]
```
"""
数据抽取
使用prompt进行调试发现以下几个问题:
-
大模型总是不能直接输出python直接可读取的json格式,如:
```json [ { "基本信息-姓名": "张三", "基本信息-手机号码": "12345678901", "基本信息-邮箱": "zhangsan@example.com", "基本信息-地区": "北京市", "基本信息-详细地址": "朝阳区某街道", "基本信息-性别": "男", "基本信息-年龄": "30", "基本信息-生日": "1990-01-01", "咨询类型": ["询价"], "意向产品": ["产品A"], "购买异议点": ["价格高"], "客户预算-预算是否充足": "充足", "客户预算-总体预算金额": "10000", "客户预算-预算明细": "详细预算内容", "竞品信息": "竞争对手B", "客户是否有意向": "有意向", "客户是否有卡点": "无卡点", "客户购买阶段": "合同中", "下一步跟进计划-参与人": ["客服A"], "下一步跟进计划-时间点": "2024-07-01", "下一步跟进计划-具体事项": "沟通具体事项" } ] ```
故使用函数
convert_all_json_in_text_to_dict
对json数据进行提取def convert_all_json_in_text_to_dict(text): """提取LLM输出文本中的json字符串""" dicts, stack = [], [] for i in range(len(text)): if text[i] == '{': stack.append(i) elif text[i] == '}': begin = stack.pop() if not stack: dicts.append(json.loads(text[begin:i+1])) return dicts
-
大模型偶尔会出现缺少字段的情况,故使用
check_and_complete_json_format
函数对大模型抽取的结果进行字段格式的检查以及缺少的字段进行补全。from sparkai.llm.llm import ChatSparkLLM, ChunkPrintHandler from sparkai.core.messages import ChatMessage import json from tqdm import tqdm #星火认知大模型Spark3.5 Max的URL值,其他版本大模型URL值请前往文档(https://www.xfyun.cn/doc/spark/Web.html)查看 SPARKAI_URL = 'wss://spark-api.xf-yun.com/v3.5/chat' #星火认知大模型调用秘钥信息,请前往讯飞开放平台控制台(https://console.xfyun.cn/services/bm35)查看 SPARKAI_APP_ID = '' SPARKAI_API_SECRET = '' SPARKAI_API_KEY = '' #星火认知大模型Spark3.5 Max的domain值,其他版本大模型domain值请前往文档(https://www.xfyun.cn/doc/spark/Web.html)查看 SPARKAI_DOMAIN = 'generalv3.5' # prompt 设计 PROMPT_EXTRACT = """ 你将获得一段群聊对话记录。你的任务是根据给定的表单格式从对话记录中提取结构化信息。在提取信息时,请确保它与类型信息完全匹配,不要添加任何没有出现在下面模式中的属性。 表单格式如下: info: Array<Dict( "基本信息-姓名": string | "", // 客户的姓名。 "基本信息-手机号码": string | "", // 客户的手机号码。 "基本信息-邮箱": string | "", // 客户的电子邮箱地址。 "基本信息-地区": string | "", // 客户所在的地区或城市。 "基本信息-详细地址": string | "", // 客户的详细地址。 "基本信息-性别": string | "", // 客户的性别。 "基本信息-年龄": string | "", // 客户的年龄。 "基本信息-生日": string | "", // 客户的生日。 "咨询类型": string[] | [], // 客户的咨询类型,如询价、答疑等。 "意向产品": string[] | [], // 客户感兴趣的产品。 "购买异议点": string[] | [], // 客户在购买过程中提出的异议或问题。 "客户预算-预算是否充足": string | "", // 客户的预算是否充足。示例:充足, 不充足 "客户预算-总体预算金额": string | "", // 客户的总体预算金额。 "客户预算-预算明细": string | "", // 客户预算的具体明细。 "竞品信息": string | "", // 竞争对手的信息。 "客户是否有意向": string | "", // 客户是否有购买意向。示例:有意向, 无意向 "客户是否有卡点": string | "", // 客户在购买过程中是否遇到阻碍或卡点。示例:有卡点, 无卡点 "客户购买阶段": string | "", // 客户当前的购买阶段,如合同中、方案交流等。 "下一步跟进计划-参与人": string[] | [], // 下一步跟进计划中涉及的人员(客服人员)。 "下一步跟进计划-时间点": string | "", // 下一步跟进的时间点。 "下一步跟进计划-具体事项": string | "" // 下一步需要进行的具体事项。 )> 请分析以下群聊对话记录,并根据上述格式提取信息: **对话记录:** ``` {content} ``` 请将提取的信息以JSON格式输出。 不要添加任何澄清信息。 输出必须遵循上面的模式。 不要添加任何没有出现在模式中的附加字段。 不要随意删除字段。 **输出:** ``` [{{ "基本信息-姓名": "姓名", "基本信息-手机号码": "手机号码", "基本信息-邮箱": "邮箱", "基本信息-地区": "地区", "基本信息-详细地址": "详细地址", "基本信息-性别": "性别", "基本信息-年龄": "年龄", "基本信息-生日": "生日", "咨询类型": ["咨询类型"], "意向产品": ["意向产品"], "购买异议点": ["购买异议点"], "客户预算-预算是否充足": "充足或不充足", "客户预算-总体预算金额": "总体预算金额", "客户预算-预算明细": "预算明细", "竞品信息": "竞品信息", "客户是否有意向": "有意向或无意向", "客户是否有卡点": "有卡点或无卡点", "客户购买阶段": "购买阶段", "下一步跟进计划-参与人": ["跟进计划参与人"], "下一步跟进计划-时间点": "跟进计划时间点", "下一步跟进计划-具体事项": "跟进计划具体事项" }}, ...] ``` """ def read_json(json_file_path): """读取json文件""" with open(json_file_path, 'r') as f: data = json.load(f) return data def write_json(json_file_path, data): """写入json文件""" with open(json_file_path, 'w') as f: json.dump(data, f, ensure_ascii=False, indent=4) def get_completions(text): messages = [ChatMessage( role="user", content=text )] spark = ChatSparkLLM( spark_api_url=SPARKAI_URL, spark_app_id=SPARKAI_APP_ID, spark_api_key=SPARKAI_API_KEY, spark_api_secret=SPARKAI_API_SECRET, spark_llm_domain=SPARKAI_DOMAIN, streaming=False, ) handler = ChunkPrintHandler() a = spark.generate([messages], callbacks=[handler]) return a.generations[0][0].text def convert_all_json_in_text_to_dict(text): """提取LLM输出文本中的json字符串""" dicts, stack = [], [] for i in range(len(text)): if text[i] == '{': stack.append(i) elif text[i] == '}': begin = stack.pop() if not stack: dicts.append(json.loads(text[begin:i+1])) return dicts class JsonFormatError(Exception): def __init__(self, message): self.message = message super().__init__(self.message) def check_and_complete_json_format(data): required_keys = { "基本信息-姓名": str, "基本信息-手机号码": str, "基本信息-邮箱": str, "基本信息-地区": str, "基本信息-详细地址": str, "基本信息-性别": str, "基本信息-年龄": str, "基本信息-生日": str, "咨询类型": list, "意向产品": list, "购买异议点": list, "客户预算-预算是否充足": str, "客户预算-总体预算金额": str, "客户预算-预算明细": str, "竞品信息": str, "客户是否有意向": str, "客户是否有卡点": str, "客户购买阶段": str, "下一步跟进计划-参与人": list, "下一步跟进计划-时间点": str, "下一步跟进计划-具体事项": str } if not isinstance(data, list): raise JsonFormatError("Data is not a list") for item in data: if not isinstance(item, dict): raise JsonFormatError("Item is not a dictionary") for key, value_type in required_keys.items(): if key not in item: item[key] = [] if value_type == list else "" if not isinstance(item[key], value_type): raise JsonFormatError(f"Key '{key}' is not of type {value_type.__name__}") if value_type == list and not all(isinstance(i, str) for i in item[key]): raise JsonFormatError(f"Key '{key}' does not contain all strings in the list") if __name__ == "__main__": retry_count = 5 # 重试次数 result = [] error_data = [] # 读取数据 train_data = read_json("dataset/train.json") test_data = read_json("dataset/test_data.json") for index, data in tqdm(enumerate(test_data)): index += 1 is_success = False for i in range(retry_count): try: res = get_completions(PROMPT_EXTRACT.format(content=data["chat_text"])) infos = convert_all_json_in_text_to_dict(res) infos = check_and_complete_json_format(infos) result.append({ "infos": infos, "index": index }) is_success = True break except Exception as e: print("index:", index, ", error:", e) continue if not is_success: data["index"] = index error_data.append(data) write_json("output.json", result)
baseline2:
https://aistudio.baidu.com/projectdetail/8090135
数据集制作
需要先对原始群聊数据做初步抽取,我们需要准备一下讯飞3.5的api环境配置。和baseline1的配置一样。
!pip uninstall websocket-client
!pip install --upgrade spark_ai_python websocket-client
from sparkai.llm.llm import ChatSparkLLM, ChunkPrintHandler
from sparkai.core.messages import ChatMessage
import numpy as np
from tqdm import tqdm
def chatbot(prompt):
#星火认知大模型Spark3.5 Max的URL值,其他版本大模型URL值请前往文档(https://www.xfyun.cn/doc/spark/Web.html)查看
SPARKAI_URL = 'wss://spark-api.xf-yun.com/v3.5/chat'
#星火认知大模型调用秘钥信息,请前往讯飞开放平台控制台(https://console.xfyun.cn/services/bm35)查看
SPARKAI_APP_ID = ''
SPARKAI_API_SECRET = ''
SPARKAI_API_KEY = ''
#星火认知大模型Spark3.5 Max的domain值,其他版本大模型domain值请前往文档(https://www.xfyun.cn/doc/spark/Web.html)查看
SPARKAI_DOMAIN = 'generalv3.5'
spark = ChatSparkLLM(
spark_api_url=SPARKAI_URL,
spark_app_id=SPARKAI_APP_ID,
spark_api_key=SPARKAI_API_KEY,
spark_api_secret=SPARKAI_API_SECRET,
spark_llm_domain=SPARKAI_DOMAIN,
streaming=False,
)
messages = [ChatMessage(
role="user",
content=prompt
)]
handler = ChunkPrintHandler()
a = spark.generate([messages], callbacks=[handler])
return a.generations[0][0].message.content
数据处理Prompt
¶
这里我们对原群聊对话设计了一个总结Prompt,目的是将原始对话内容进行精简。方便做微调数据。
一方面直接将群聊对话作为数据集的话,会导致上下文过长,超过限制。还有上下文太长会导致抽取效果变差。
过长的上下文也会导致训练时长和费用倍增。
好了我们来说说prompt。这个prompt相较于baseline1区别比较明显,对需要抽取的任务做了一次总结。总结了四个方面:
客户基本信息:需要从中区分出客户角色,并得到客户基本信息,其中包括姓名、手机号码、邮箱、地区、详细地址、性别、年龄和生日
客户意向与预算信息: 客户意向与预算信息包括咨询类型、意向产品、购买异议点、预算是否充足、总体预算金额以及预算明细
客户购买准备情况:户购买准备情况包括竞品信息、客户是否有意向、客户是否有卡点以及客户购买阶段
跟进计划信息: 跟进计划信息包括参与人、时间点和具体事项,这些信息用于指导销售团队在未来的跟进工作中与客户互动
通过总结后的数据一方面节约了微调的运算资源,一方面也让数据被清洗后更容易被模型理解,达到更好的抽取效果。
content = ''
prompt = f'''
你是一个数据分析大师,你需要从群聊对话中进行分析,里面对话的角色中大部分是客服角色,你需要从中区分出有需求的客户,并得到以下四类数据。
****群聊对话****
{content}
****分析数据****
客户基本信息:需要从中区分出客户角色,并得到客户基本信息,其中包括姓名、手机号码、邮箱、地区、详细地址、性别、年龄和生日
客户意向与预算信息: 客户意向与预算信息包括咨询类型、意向产品、购买异议点、预算是否充足、总体预算金额以及预算明细
客户购买准备情况:户购买准备情况包括竞品信息、客户是否有意向、客户是否有卡点以及客户购买阶段
跟进计划信息: 跟进计划信息包括参与人、时间点和具体事项,这些信息用于指导销售团队在未来的跟进工作中与客户互动
****注意****
1.只输出客户基本信息、客户意向与预算信息、客户购买准备情况、跟进计划信息对应的信息,不要输出无关内容
2.不要输出分析内容
3.输出内容格式为md格式
'''
训练数据集制作
jsonl_data 是用来训练的规范单行数据,需要由训练数据组成一个jsonl文件(每行是一个json数据的文件),格式如下:
jsonl_data = {"instruction":"假设你是一个智能交互助手,基于用户的输入文本,解析其中语义,抽取关键信息,以json格式生成结构化的语义内容。","input":"请调小空气净化器的湿度到1","output":"{\"intent\":\"CONTROL\",\"slots\":[{\"name\":\"device\",\"normValue\":\"airCleaner\",\"value\":\"空气净化器\"},{\"name\":\"insType\",\"normValue\":\"set\",\"value\":\"调小\"},{\"name\":\"attr\",\"normValue\":\"humidity\",\"value\":\"湿度\"},{\"name\":\"attrValue\",\"normValue\":\"1\",\"value\":\"1\"}],\"sample\":\"请调小空气净化器的湿度到1\"}"}
print(jsonl_data)
print(jsonl_data["instruction"])
print(jsonl_data["input"])
print(jsonl_data["output"])
import json
# 打开并读取JSON文件
with open('train.json', 'r', encoding='utf-8') as file:
data = json.load(file)
这里我们通过星火3.5api清洗原来的数据,总结后按照刚才看到得单行jsonl存储格式将数据存入traindata.jsonl中。
这里的训练时长大概40min左右,请耐心等待。
# 训练集制作
# 打开一个文件用于写入,如果文件已存在则会被覆盖
with open('traindata.jsonl', 'w', encoding='utf-8') as file:
# 训练集行数(130)不符合要求,范围:1500~90000000
# 遍历数据列表,并将每一行写入文件
# 这里为了满足微调需求我们重复12次数据集 130*12=1560
for line_data in tqdm(data):
line_input = line_data["chat_text"]
line_output = line_data["infos"]
content = line_input
prompt = f'''
你是一个数据分析大师,你需要从群聊对话中进行分析,里面对话的角色中大部分是客服角色,你需要从中区分出有需求的客户,并得到以下四类数据。
****群聊对话****
{content}
****分析数据****
客户基本信息:需要从中区分出客户角色,并得到客户基本信息,其中包括姓名、手机号码、邮箱、地区、详细地址、性别、年龄和生日
客户意向与预算信息: 客户意向与预算信息包括咨询类型、意向产品、购买异议点、预算是否充足、总体预算金额以及预算明细
客户购买准备情况:户购买准备情况包括竞品信息、客户是否有意向、客户是否有卡点以及客户购买阶段
跟进计划信息: 跟进计划信息包括参与人、时间点和具体事项,这些信息用于指导销售团队在未来的跟进工作中与客户互动
****注意****
1.只输出客户基本信息、客户意向与预算信息、客户购买准备情况、跟进计划信息对应的信息,不要输出无关内容
2.不要输出分析内容
3.输出内容格式为md格式
'''
res = chatbot(prompt=prompt)
# print(res)
line_write = {
"instruction":jsonl_data["instruction"],
"input":json.dumps(res, ensure_ascii=False),
"output":json.dumps(line_output, ensure_ascii=False)
}
# 因为数据共有130行,为了能满足训练需要的1500条及以上,我们将正常训练数据扩充12倍。
for time in range(12):
file.write(json.dumps(line_write, ensure_ascii=False) + '\n') # '\n' 用于在每行末尾添加换行符
测试集数据制作
# 验证集制作(提交版本)
# input,target
import json
# 打开并读取JSON文件
with open('test_data.json', 'r', encoding='utf-8') as file:
data_test = json.load(file)
import csv
# 打开一个文件用于写入CSV数据
with open('test.csv', 'w', newline='', encoding='utf-8') as csvfile:
# 创建一个csv writer对象
csvwriter = csv.writer(csvfile)
csvwriter.writerow(["input","target"])
# 遍历数据列表,并将每一行写入CSV文件
for line_data in tqdm(data_test):
content = line_data["chat_text"]
prompt = f'''
你是一个数据分析大师,你需要从群聊对话中进行分析,里面对话的角色中大部分是客服角色,你需要从中区分出有需求的客户,并得到以下四类数据。
****群聊对话****
{content}
****分析数据****
客户基本信息:需要从中区分出客户角色,并得到客户基本信息,其中包括姓名、手机号码、邮箱、地区、详细地址、性别、年龄和生日
客户意向与预算信息: 客户意向与预算信息包括咨询类型、意向产品、购买异议点、预算是否充足、总体预算金额以及预算明细
客户购买准备情况:户购买准备情况包括竞品信息、客户是否有意向、客户是否有卡点以及客户购买阶段
跟进计划信息: 跟进计划信息包括参与人、时间点和具体事项,这些信息用于指导销售团队在未来的跟进工作中与客户互动
****注意****
1.只输出客户基本信息、客户意向与预算信息、客户购买准备情况、跟进计划信息对应的信息,不要输出无关内容
2.不要输出分析内容
3.输出内容格式为md格式
'''
res = chatbot(prompt=prompt)
# print(line_data["chat_text"])
## 文件内容校验失败: test.jsonl(不含表头起算)第1行的内容不符合规则,限制每组input和target字符数量总和上限为8000,当前行字符数量:10721
line_list = [res, "-"]
csvwriter.writerow(line_list)
# break
data_test
运行结束后可生成测试集CSV文件。
模型微调
回到:https://training.xfyun.cn/dataset/datasetIndex
这次我们选择测试集即可
平台微调
点击左侧导航栏的我的模型服务接着拿到resourceId、APPID、APIKey、APISecret
微调推理
# 定义写入函数
def write_json(json_file_path, data):
#"""写入json文件"""
with open(json_file_path, 'w') as f:
json.dump(data, f, ensure_ascii=False, indent=4)
在main.ipynb的微调推理部分填入APPID、APIKey、APISecret(注意顺序)
在SparkApi.py文件的108行,引号中填入你的resourceId
修改完成后按顺序运行即可
结果提交
快去提交结果吧 2024 iFLYTEK AI开发者大赛-讯飞开放平台