目录
听了baseline1的精读分享直播,对程序进行简要梳理。
一、环境配置
spark_ai_python要求python版本3.8以上
!pip install --upgrade -q spark_ai_python tqdm jsonschema python-dotenv
导入包,其中dotenv是一个零依赖的模块,它的主要功能是从.env文件中加载环境变量,此处从.env文件中加载ID、SECRET和KEY变量
from sparkai.llm.llm import ChatSparkLLM, ChunkPrintHandler
from sparkai.core.messages import ChatMessage
import json
from dotenv import load_dotenv
# 加载.env文件中的环境变量
"""
SPARKAI_APP_ID=""
SPARKAI_API_SECRET=""
SPARKAI_API_KEY=""
"""
load_dotenv()
在讯飞开放控制台获得大模型调用密钥信息
#星火认知大模型Spark3.5 Max的URL值,其他版本大模型URL值请前往文档(https://www.xfyun.cn/doc/spark/Web.html)查看
SPARKAI_URL = 'wss://spark-api.xf-yun.com/v3.5/chat'
#星火认知大模型调用秘钥信息,请前往讯飞开放平台控制台(https://console.xfyun.cn/services/bm35)查看
SPARKAI_APP_ID = os.getenv("SPARKAI_APP_ID")
SPARKAI_API_SECRET = os.getenv("SPARKAI_API_SECRET")
SPARKAI_API_KEY = os.getenv("SPARKAI_API_KEY")
#星火认知大模型Spark3.5 Max的domain值,其他版本大模型domain值请前往文档(https://www.xfyun.cn/doc/spark/Web.html)查看
SPARKAI_DOMAIN = 'generalv3.5'
测试星火大模型是否可以正常使用
def get_completions(text):
messages = [ChatMessage(
role="user",
content=text
)]
spark = ChatSparkLLM(
spark_api_url=SPARKAI_URL,
spark_app_id=SPARKAI_APP_ID,
spark_api_key=SPARKAI_API_KEY,
spark_api_secret=SPARKAI_API_SECRET,
spark_llm_domain=SPARKAI_DOMAIN,
streaming=False,
)
handler = ChunkPrintHandler()
a = spark.generate([messages], callbacks=[handler])
return a.generations[0][0].text
# 测试模型配置是否正确
text = "你好"
get_completions(text)
二、数据处理
目的:读取训练集和测试集的数据本身
定义两个函数用于读取json文件和写入json文件
def read_json(json_file_path):
"""读取json文件"""
with open(json_file_path, 'r') as f:
data = json.load(f)
return data
def write_json(json_file_path, data):
"""写入json文件"""
with open(json_file_path, 'w') as f:
json.dump(data, f, ensure_ascii=False, indent=4)
# 读取数据
train_data = read_json("dataset/train.json")
test_data = read_json("dataset/test_data.json")
查看对话数据集,是群聊数据,有部分影响大模型结果的信息存在:图片和引用消息等
# 查看对话数据
print(train_data[100]['chat_text'])
格式化为json格式
# 查看对话标签
def print_json_format(data):
"""格式化输出json格式"""
print(json.dumps(data, indent=4, ensure_ascii=False))
print_json_format(train_data[100]['infos'])
定义函数提取文本中的json字符串,type为list
def convert_all_json_in_text_to_dict(text):
"""提取LLM输出文本中的json字符串"""
dicts, stack = [], []
for i in range(len(text)):
if text[i] == '{':
stack.append(i)
elif text[i] == '}':
begin = stack.pop()
if not stack:
dicts.append(json.loads(text[begin:i+1]))
return dicts
llm_output = """
```json
[{
"基本信息-姓名": "李强1",
"基本信息-手机号码": "11059489858"
}]
```
"""
# 测试一下效果
json_res = convert_all_json_in_text_to_dict(llm_output)
print_json_format(json_res)
print(type(json_res))
三、promot工程
promot编写思路:任务目标-抽取数据定义-抽取内容引入-抽取规则强调
将群聊对话输入大模型
prompt 设计
PROMPT_EXTRACT = """
你将获得一段群聊对话记录。你的任务是根据给定的表单格式从对话记录中提取结构化信息。在提取信息时,请确保它与类型信息完全匹配,不要添加任何没有出现在下面模式中的属性。
运行测试
content = train_data[100]['chat_text']
res = get_completions(PROMPT_EXTRACT.format(content=content))
json_res = convert_all_json_in_text_to_dict(res)
print_json_format(json_res)
查看原格式,含有markdown标签
print(res)
查看数据对应的标签
# 查看训练数据对应的标签
print_json_format(train_data[100]['infos'])
四、数据抽取
检查json格式并补全,防止大模型将空字段删除以及输出格式异常
check_and_complete_json_format函数对大模型抽取的结果进行字段格式的检查以及缺少的字段进行补全
import json
class JsonFormatError(Exception):
def __init__(self, message):
self.message = message
super().__init__(self.message)
def check_and_complete_json_format(data):
required_keys = {
"基本信息-姓名": str,
"基本信息-手机号码": str,
"基本信息-邮箱": str,
"基本信息-地区": str,
"基本信息-详细地址": str,
"基本信息-性别": str,
"基本信息-年龄": str,
"基本信息-生日": str,
"咨询类型": list,
"意向产品": list,
"购买异议点": list,
"客户预算-预算是否充足": str,
"客户预算-总体预算金额": str,
"客户预算-预算明细": str,
"竞品信息": str,
"客户是否有意向": str,
"客户是否有卡点": str,
"客户购买阶段": str,
"下一步跟进计划-参与人": list,
"下一步跟进计划-时间点": str,
"下一步跟进计划-具体事项": str
}
if not isinstance(data, list):
raise JsonFormatError("Data is not a list")
for item in data:
if not isinstance(item, dict):
raise JsonFormatError("Item is not a dictionary")
for key, value_type in required_keys.items():
if key not in item:
item[key] = [] if value_type == list else ""
if not isinstance(item[key], value_type):
raise JsonFormatError(f"Key '{key}' is not of type {value_type.__name__}")
if value_type == list and not all(isinstance(i, str) for i in item[key]):
raise JsonFormatError(f"Key '{key}' does not contain all strings in the list")
return data
# Example usage:
json_data = '''
[
{
"基本信息-姓名": "张三",
"基本信息-邮箱": "zhangsan@example.com",
"基本信息-地区": "北京市",
"基本信息-详细地址": "朝阳区某街道",
"基本信息-性别": "男",
"基本信息-年龄": "30",
"基本信息-生日": "1990-01-01",
"咨询类型": "",
"意向产品": ["产品A"],
"购买异议点": ["价格高"],
"客户预算-预算是否充足": "充足",
"客户预算-总体预算金额": "10000",
"客户预算-预算明细": "详细预算内容",
"竞品信息": "竞争对手B",
"客户是否有意向": "有意向",
"客户是否有卡点": "无卡点",
"客户购买阶段": "合同中",
"下一步跟进计划-参与人": ["客服A"],
"下一步跟进计划-时间点": "2024-07-01",
"下一步跟进计划-具体事项": "沟通具体事项"
}
]
'''
try:
data = json.loads(json_data)
completed_data = check_and_complete_json_format(data)
print("Completed JSON:", json.dumps(completed_data, ensure_ascii=False, indent=4))
except JsonFormatError as e:
print(f"JSON format error: {e.message}")
使用另一个jsonschema库进行更加简便的格式验证
import json
from jsonschema import validate, Draft7Validator
from jsonschema.exceptions import ValidationError
class JsonFormatError(Exception):
def __init__(self, message):
self.message = message
super().__init__(self.message)
schema = {
"type": "array",
"items": {
"type": "object",
"properties": {
"基本信息-姓名": {"type": "string", "default": ""},
"基本信息-手机号码": {"type": "string", "default": ""},
"基本信息-邮箱": {"type": "string", "default": ""},
"基本信息-地区": {"type": "string", "default": ""},
"基本信息-详细地址": {"type": "string", "default": ""},
"基本信息-性别": {"type": "string", "default": ""},
"基本信息-年龄": {"type": "string", "default": ""},
"基本信息-生日": {"type": "string", "default": ""},
"咨询类型": {"type": "array", "items": {"type": "string"}, "default": []},
"意向产品": {"type": "array", "items": {"type": "string"}, "default": []},
"购买异议点": {"type": "array", "items": {"type": "string"}, "default": []},
"客户预算-预算是否充足": {"type": "string", "enum": ["充足", "不充足", ""], "default": ""},
"客户预算-总体预算金额": {"type": "string", "default": ""},
"客户预算-预算明细": {"type": "string", "default": ""},
"竞品信息": {"type": "string", "default": ""},
"客户是否有意向": {"type": "string", "enum": ["有意向", "无意向", ""], "default": ""},
"客户是否有卡点": {"type": "string", "enum": ["有卡点", "无卡点", ""], "default": ""},
"客户购买阶段": {"type": "string", "default": ""},
"下一步跟进计划-参与人": {"type": "array", "items": {"type": "string"}, "default": []},
"下一步跟进计划-时间点": {"type": "string", "default": ""},
"下一步跟进计划-具体事项": {"type": "string", "default": ""}
},
"required": [
"基本信息-姓名", "基本信息-手机号码", "基本信息-邮箱", "基本信息-地区",
"基本信息-详细地址", "基本信息-性别", "基本信息-年龄", "基本信息-生日",
"咨询类型", "意向产品", "购买异议点", "客户预算-预算是否充足",
"客户预算-总体预算金额", "客户预算-预算明细", "竞品信息",
"客户是否有意向", "客户是否有卡点", "客户购买阶段",
"下一步跟进计划-参与人", "下一步跟进计划-时间点", "下一步跟进计划-具体事项"
]
}
}
def validate_and_complete_json(data):
# Create a validator with the ability to fill in default values
validator = Draft7Validator(schema)
for item in data:
errors = sorted(validator.iter_errors(item), key=lambda e: e.path)
for error in errors:
# If the property is missing and has a default, apply the default value
for subschema in error.schema_path:
if 'default' in error.schema:
item[error.schema_path[-1]] = error.schema['default']
break
# Validate the completed data
try:
validate(instance=data, schema=schema)
except ValidationError as e:
raise JsonFormatError(f"JSON format error: {e.message}")
return data
# Example usage:
json_data = '''
[
{
"基本信息-姓名": "张三",
"基本信息-手机号码": "12345678901",
"基本信息-邮箱": "zhangsan@example.com",
"基本信息-地区": "北京市",
"基本信息-详细地址": "朝阳区某街道",
"基本信息-性别": "男",
"基本信息-年龄": "30",
"基本信息-生日": "1990-01-01",
"咨询类型": ["询价"],
"意向产品": ["产品A"],
"购买异议点": ["价格高"],
"客户预算-预算是否充足": "充足",
"客户预算-总体预算金额": "10000",
"客户预算-预算明细": "详细预算内容",
"竞品信息": "竞争对手B",
"客户是否有意向": "有意向",
"客户是否有卡点": "无卡点",
"客户购买阶段": "合同中",
"下一步跟进计划-参与人": ["客服A"],
"下一步跟进计划-时间点": "2024-07-01",
"下一步跟进计划-具体事项": "沟通具体事项"
}
]
'''
try:
data = json.loads(json_data)
completed_data = validate_and_complete_json(data)
print("Completed JSON:", json.dumps(completed_data, ensure_ascii=False, indent=4))
except JsonFormatError as e:
print(f"JSON format error: {e.message}")
防止数据格式异常,可重新调用API获取数据
if error_data:
retry_count = 10 # 重试次数
error_data_temp = []
while True:
if error_data_temp:
error_data = error_data_temp
error_data_temp = []
for data in tqdm(error_data):
is_success = False
for i in range(retry_count):
try:
res = get_completions(PROMPT_EXTRACT.format(content=data["chat_text"]))
infos = convert_all_json_in_text_to_dict(res)
infos = check_and_complete_json_format(infos)
result.append({
"infos": infos,
"index": data["index"]
})
is_success = True
break
except Exception as e:
print("index:", index, ", error:", e)
continue
if not is_success:
error_data_temp.append(data)
if not error_data_temp:
break
result = sorted(result, key=lambda x: x["index"])
如果有错误数据,重新请求
from tqdm import tqdm
retry_count = 5 # 重试次数
result = []
error_data = []
for index, data in tqdm(enumerate(test_data)):
index += 1
is_success = False
for i in range(retry_count):
try:
res = get_completions(PROMPT_EXTRACT.format(content=data["chat_text"]))
infos = convert_all_json_in_text_to_dict(res)
infos = check_and_complete_json_format(infos)
result.append({
"infos": infos,
"index": index
})
is_success = True
break
except Exception as e:
print("index:", index, ", error:", e)
continue
if not is_success:
data["index"] = index
error_data.append(data)
写入output.json并提交
write_json("output.json", result)
五、个人感悟
在运行过程中报错:
Did not find spark_app_id, please add an environment variable `IFLYTEK_SPARK_APP_ID` which contains it, or pass `spark_app_id` as a named parameter. (type=value_error)
经过多次尝试后发现是需要在目录下自己创建.env文件,将环境变量写进去。
不知道是不是这么做的,但是我添加了文本文件,重命名为.env,之后该文件就不见了,但是程序可以运行,猜测是因为这个名称改变了路径。