Datawhale AI夏令营 baseline1精读分享直播

目录

一、环境配置

二、数据处理

三、promot工程

四、数据抽取 

五、个人感悟


听了baseline1的精读分享直播,对程序进行简要梳理。

一、环境配置

spark_ai_python要求python版本3.8以上

!pip install --upgrade -q spark_ai_python tqdm jsonschema python-dotenv

导入包,其中dotenv是一个零依赖的模块,它的主要功能是从.env文件中加载环境变量,此处从.env文件中加载ID、SECRET和KEY变量 

from sparkai.llm.llm import ChatSparkLLM, ChunkPrintHandler
from sparkai.core.messages import ChatMessage
import json
from dotenv import load_dotenv

# 加载.env文件中的环境变量
"""
SPARKAI_APP_ID=""
SPARKAI_API_SECRET=""
SPARKAI_API_KEY=""
"""
load_dotenv()

 在讯飞开放控制台获得大模型调用密钥信息

#星火认知大模型Spark3.5 Max的URL值,其他版本大模型URL值请前往文档(https://www.xfyun.cn/doc/spark/Web.html)查看
SPARKAI_URL = 'wss://spark-api.xf-yun.com/v3.5/chat'
#星火认知大模型调用秘钥信息,请前往讯飞开放平台控制台(https://console.xfyun.cn/services/bm35)查看
SPARKAI_APP_ID = os.getenv("SPARKAI_APP_ID")
SPARKAI_API_SECRET = os.getenv("SPARKAI_API_SECRET")
SPARKAI_API_KEY = os.getenv("SPARKAI_API_KEY")
#星火认知大模型Spark3.5 Max的domain值,其他版本大模型domain值请前往文档(https://www.xfyun.cn/doc/spark/Web.html)查看
SPARKAI_DOMAIN = 'generalv3.5'

 测试星火大模型是否可以正常使用

def get_completions(text):
    messages = [ChatMessage(
        role="user",
        content=text
    )]
    spark = ChatSparkLLM(
        spark_api_url=SPARKAI_URL,
        spark_app_id=SPARKAI_APP_ID,
        spark_api_key=SPARKAI_API_KEY,
        spark_api_secret=SPARKAI_API_SECRET,
        spark_llm_domain=SPARKAI_DOMAIN,
        streaming=False,
    )
    handler = ChunkPrintHandler()
    a = spark.generate([messages], callbacks=[handler])
    return a.generations[0][0].text

# 测试模型配置是否正确
text = "你好"
get_completions(text)

二、数据处理

目的:读取训练集和测试集的数据本身

定义两个函数用于读取json文件和写入json文件 

def read_json(json_file_path):
    """读取json文件"""
    with open(json_file_path, 'r') as f:
        data = json.load(f)
    return data

def write_json(json_file_path, data):
    """写入json文件"""
    with open(json_file_path, 'w') as f:
        json.dump(data, f, ensure_ascii=False, indent=4)

# 读取数据
train_data = read_json("dataset/train.json")
test_data = read_json("dataset/test_data.json")

查看对话数据集,是群聊数据,有部分影响大模型结果的信息存在:图片和引用消息等

# 查看对话数据
print(train_data[100]['chat_text'])

格式化为json格式 

# 查看对话标签
def print_json_format(data):
    """格式化输出json格式"""
    print(json.dumps(data, indent=4, ensure_ascii=False))

print_json_format(train_data[100]['infos'])

定义函数提取文本中的json字符串,type为list

def convert_all_json_in_text_to_dict(text):
    """提取LLM输出文本中的json字符串"""
    dicts, stack = [], []
    for i in range(len(text)):
        if text[i] == '{':
            stack.append(i)
        elif text[i] == '}':
            begin = stack.pop()
            if not stack:
                dicts.append(json.loads(text[begin:i+1]))
    return dicts

llm_output = """
```json
[{
    "基本信息-姓名": "李强1",
    "基本信息-手机号码": "11059489858"
}]
```
"""

# 测试一下效果
json_res = convert_all_json_in_text_to_dict(llm_output)
print_json_format(json_res)
print(type(json_res))

三、promot工程

promot编写思路:任务目标-抽取数据定义-抽取内容引入-抽取规则强调

将群聊对话输入大模型 

 prompt 设计
PROMPT_EXTRACT = """
你将获得一段群聊对话记录。你的任务是根据给定的表单格式从对话记录中提取结构化信息。在提取信息时,请确保它与类型信息完全匹配,不要添加任何没有出现在下面模式中的属性。

 运行测试

content = train_data[100]['chat_text']
res = get_completions(PROMPT_EXTRACT.format(content=content))
json_res = convert_all_json_in_text_to_dict(res)
print_json_format(json_res)

查看原格式,含有markdown标签

print(res)

 

 查看数据对应的标签

# 查看训练数据对应的标签
print_json_format(train_data[100]['infos'])

四、数据抽取 

检查json格式并补全,防止大模型将空字段删除以及输出格式异常

check_and_complete_json_format函数对大模型抽取的结果进行字段格式的检查以及缺少的字段进行补全

import json

class JsonFormatError(Exception):
    def __init__(self, message):
        self.message = message
        super().__init__(self.message)

def check_and_complete_json_format(data):
    required_keys = {
        "基本信息-姓名": str,
        "基本信息-手机号码": str,
        "基本信息-邮箱": str,
        "基本信息-地区": str,
        "基本信息-详细地址": str,
        "基本信息-性别": str,
        "基本信息-年龄": str,
        "基本信息-生日": str,
        "咨询类型": list,
        "意向产品": list,
        "购买异议点": list,
        "客户预算-预算是否充足": str,
        "客户预算-总体预算金额": str,
        "客户预算-预算明细": str,
        "竞品信息": str,
        "客户是否有意向": str,
        "客户是否有卡点": str,
        "客户购买阶段": str,
        "下一步跟进计划-参与人": list,
        "下一步跟进计划-时间点": str,
        "下一步跟进计划-具体事项": str
    }

    if not isinstance(data, list):
        raise JsonFormatError("Data is not a list")

    for item in data:
        if not isinstance(item, dict):
            raise JsonFormatError("Item is not a dictionary")
        for key, value_type in required_keys.items():
            if key not in item:
                item[key] = [] if value_type == list else ""
            if not isinstance(item[key], value_type):
                raise JsonFormatError(f"Key '{key}' is not of type {value_type.__name__}")
            if value_type == list and not all(isinstance(i, str) for i in item[key]):
                raise JsonFormatError(f"Key '{key}' does not contain all strings in the list")

    return data

# Example usage:
json_data = '''
[
    {
        "基本信息-姓名": "张三",
        "基本信息-邮箱": "zhangsan@example.com",
        "基本信息-地区": "北京市",
        "基本信息-详细地址": "朝阳区某街道",
        "基本信息-性别": "男",
        "基本信息-年龄": "30",
        "基本信息-生日": "1990-01-01",
        "咨询类型": "",
        "意向产品": ["产品A"],
        "购买异议点": ["价格高"],
        "客户预算-预算是否充足": "充足",
        "客户预算-总体预算金额": "10000",
        "客户预算-预算明细": "详细预算内容",
        "竞品信息": "竞争对手B",
        "客户是否有意向": "有意向",
        "客户是否有卡点": "无卡点",
        "客户购买阶段": "合同中",
        "下一步跟进计划-参与人": ["客服A"],
        "下一步跟进计划-时间点": "2024-07-01",
        "下一步跟进计划-具体事项": "沟通具体事项"
    }
]
'''

try:
    data = json.loads(json_data)
    completed_data = check_and_complete_json_format(data)
    print("Completed JSON:", json.dumps(completed_data, ensure_ascii=False, indent=4))
except JsonFormatError as e:
    print(f"JSON format error: {e.message}")

 使用另一个jsonschema库进行更加简便的格式验证

import json
from jsonschema import validate, Draft7Validator
from jsonschema.exceptions import ValidationError

class JsonFormatError(Exception):
    def __init__(self, message):
        self.message = message
        super().__init__(self.message)

schema = {
    "type": "array",
    "items": {
        "type": "object",
        "properties": {
            "基本信息-姓名": {"type": "string", "default": ""},
            "基本信息-手机号码": {"type": "string", "default": ""},
            "基本信息-邮箱": {"type": "string", "default": ""},
            "基本信息-地区": {"type": "string", "default": ""},
            "基本信息-详细地址": {"type": "string", "default": ""},
            "基本信息-性别": {"type": "string", "default": ""},
            "基本信息-年龄": {"type": "string", "default": ""},
            "基本信息-生日": {"type": "string", "default": ""},
            "咨询类型": {"type": "array", "items": {"type": "string"}, "default": []},
            "意向产品": {"type": "array", "items": {"type": "string"}, "default": []},
            "购买异议点": {"type": "array", "items": {"type": "string"}, "default": []},
            "客户预算-预算是否充足": {"type": "string", "enum": ["充足", "不充足", ""], "default": ""},
            "客户预算-总体预算金额": {"type": "string", "default": ""},
            "客户预算-预算明细": {"type": "string", "default": ""},
            "竞品信息": {"type": "string", "default": ""},
            "客户是否有意向": {"type": "string", "enum": ["有意向", "无意向", ""], "default": ""},
            "客户是否有卡点": {"type": "string", "enum": ["有卡点", "无卡点", ""], "default": ""},
            "客户购买阶段": {"type": "string", "default": ""},
            "下一步跟进计划-参与人": {"type": "array", "items": {"type": "string"}, "default": []},
            "下一步跟进计划-时间点": {"type": "string", "default": ""},
            "下一步跟进计划-具体事项": {"type": "string", "default": ""}
        },
        "required": [
            "基本信息-姓名", "基本信息-手机号码", "基本信息-邮箱", "基本信息-地区", 
            "基本信息-详细地址", "基本信息-性别", "基本信息-年龄", "基本信息-生日",
            "咨询类型", "意向产品", "购买异议点", "客户预算-预算是否充足", 
            "客户预算-总体预算金额", "客户预算-预算明细", "竞品信息", 
            "客户是否有意向", "客户是否有卡点", "客户购买阶段", 
            "下一步跟进计划-参与人", "下一步跟进计划-时间点", "下一步跟进计划-具体事项"
        ]
    }
}

def validate_and_complete_json(data):
    # Create a validator with the ability to fill in default values
    validator = Draft7Validator(schema)
    for item in data:
        errors = sorted(validator.iter_errors(item), key=lambda e: e.path)
        for error in errors:
            # If the property is missing and has a default, apply the default value
            for subschema in error.schema_path:
                if 'default' in error.schema:
                    item[error.schema_path[-1]] = error.schema['default']
                    break

    # Validate the completed data
    try:
        validate(instance=data, schema=schema)
    except ValidationError as e:
        raise JsonFormatError(f"JSON format error: {e.message}")

    return data

# Example usage:
json_data = '''
[
    {
        "基本信息-姓名": "张三",
        "基本信息-手机号码": "12345678901",
        "基本信息-邮箱": "zhangsan@example.com",
        "基本信息-地区": "北京市",
        "基本信息-详细地址": "朝阳区某街道",
        "基本信息-性别": "男",
        "基本信息-年龄": "30",
        "基本信息-生日": "1990-01-01",
        "咨询类型": ["询价"],
        "意向产品": ["产品A"],
        "购买异议点": ["价格高"],
        "客户预算-预算是否充足": "充足",
        "客户预算-总体预算金额": "10000",
        "客户预算-预算明细": "详细预算内容",
        "竞品信息": "竞争对手B",
        "客户是否有意向": "有意向",
        "客户是否有卡点": "无卡点",
        "客户购买阶段": "合同中",
        "下一步跟进计划-参与人": ["客服A"],
        "下一步跟进计划-时间点": "2024-07-01",
        "下一步跟进计划-具体事项": "沟通具体事项"
    }
]
'''

try:
    data = json.loads(json_data)
    completed_data = validate_and_complete_json(data)
    print("Completed JSON:", json.dumps(completed_data, ensure_ascii=False, indent=4))
except JsonFormatError as e:
    print(f"JSON format error: {e.message}")

防止数据格式异常,可重新调用API获取数据

if error_data:
    retry_count = 10 # 重试次数
    error_data_temp = []
    while True:
        if error_data_temp:
            error_data = error_data_temp
            error_data_temp = []
        for data in tqdm(error_data):
            is_success = False
            for i in range(retry_count):
                try:
                    res = get_completions(PROMPT_EXTRACT.format(content=data["chat_text"]))
                    infos = convert_all_json_in_text_to_dict(res)
                    infos = check_and_complete_json_format(infos)
                    result.append({
                        "infos": infos,
                        "index": data["index"]
                    })
                    is_success = True
                    break
                except Exception as e:
                    print("index:", index, ", error:", e)
                    continue
            if not is_success:
                error_data_temp.append(data)
        if not error_data_temp:
            break
    result = sorted(result, key=lambda x: x["index"])

如果有错误数据,重新请求 

from tqdm import tqdm

retry_count = 5 # 重试次数
result = []
error_data = []

for index, data in tqdm(enumerate(test_data)):
    index += 1
    is_success = False
    for i in range(retry_count):
        try:
            res = get_completions(PROMPT_EXTRACT.format(content=data["chat_text"]))
            infos = convert_all_json_in_text_to_dict(res)
            infos = check_and_complete_json_format(infos)
            result.append({
                "infos": infos,
                "index": index
            })
            is_success = True
            break
        except Exception as e:
            print("index:", index, ", error:", e)
            continue
    if not is_success:
        data["index"] = index
        error_data.append(data)

写入output.json并提交 

write_json("output.json", result)

五、个人感悟

在运行过程中报错:

Did not find spark_app_id, please add an environment variable `IFLYTEK_SPARK_APP_ID` which contains it, or pass `spark_app_id` as a named parameter. (type=value_error)

经过多次尝试后发现是需要在目录下自己创建.env文件,将环境变量写进去。

不知道是不是这么做的,但是我添加了文本文件,重命名为.env,之后该文件就不见了,但是程序可以运行,猜测是因为这个名称改变了路径。

  • 4
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值