Datawhale夏令营大模型竞赛笔记(1):零基础入门Baseline学习

#ai夏令营#datawhale#夏令营
这篇文章是记录自己参加Datawhale夏令营期间的学习心得,欢迎大家交流讨论

1.Baseline运行链接

相关的baseline运行链接如下,个人认为链接介绍十分详细,小白跟着跑通是完全没有问题的
https://datawhaler.feishu.cn/wiki/VIy8ws47ii2N79kOt9zcXnbXnuS

2.Baseline代码学习

为了更好地学习,个人简单对Baseline代码进行解读

2.1 下载相关库

!pip install --upgrade -q spark_ai_python

这一步是用以下载python接入星火大模型的库,用以方便后续更好地调用星火大模型API

2.2 导入配置

from sparkai.llm.llm import ChatSparkLLM, ChunkPrintHandler
from sparkai.core.messages import ChatMessage
import json


#星火认知大模型Spark3.5 Max的URL值,其他版本大模型URL值请前往文档(https://www.xfyun.cn/doc/spark/Web.html)查看
SPARKAI_URL = 'wss://spark-api.xf-yun.com/v3.5/chat'
#星火认知大模型调用秘钥信息,请前往讯飞开放平台控制台(https://console.xfyun.cn/services/bm35)查看
SPARKAI_APP_ID = ''
SPARKAI_API_SECRET = ''
SPARKAI_API_KEY = ''
#星火认知大模型Spark3.5 Max的domain值,其他版本大模型domain值请前往文档(https://www.xfyun.cn/doc/spark/Web.html)查看
SPARKAI_DOMAIN = 'generalv3.5'

这一步是导入相关的配置,并填入调用模型的相关信息

2.3 模型测试

为了测试模型是否能够正常完成调用,定义了以下的代码:

def get_completions(text):
    messages = [ChatMessage(
        role="user",
        content=text
    )]
    spark = ChatSparkLLM(
        spark_api_url=SPARKAI_URL,
        spark_app_id=SPARKAI_APP_ID,
        spark_api_key=SPARKAI_API_KEY,
        spark_api_secret=SPARKAI_API_SECRET,
        spark_llm_domain=SPARKAI_DOMAIN,
        streaming=False,
    )
    handler = ChunkPrintHandler()
    a = spark.generate([messages], callbacks=[handler])
    return a.generations[0][0].text

# 测试模型配置是否正确
text = "你好"
get_completions(text)

函数接受一个文本输入text作为参数。在函数内部,首先会创建了一个包含用户消息的列表messages,然后初始化一个ChatSparkLLM对象(ChatSparkLLM是一个用于与SparkAI模型进行交互的类,我们需要把API URL、应用ID、API密钥等都传递给它),接着创建了一个ChunkPrintHandler对象handler用于处理生成的文本。然后调用spark.generate()方法来生成文本,将用户消息传递给模型,并将生成的文本通过handler处理。最后,返回生成的文本作为函数的输出结果。

2.4 数据读取

这一步是读取比赛提供相关的json文件,用以后续的处理

def read_json(json_file_path):
    """读取json文件"""
    with open(json_file_path, 'r') as f:
        data = json.load(f)
    return data

def write_json(json_file_path, data):
    """写入json文件"""
    with open(json_file_path, 'w') as f:
        json.dump(data, f, ensure_ascii=False, indent=4)

# 读取数据
train_data = read_json("dataset/train.json")
test_data = read_json("dataset/test_data.json")

2.5 Prompt设计

为了让大模型在接收到数据之后能够按照我们的需要进行输出,我们需要进行prompt设计,这段prompt主要有以下部分组成

  • 1.告诉大模型他的任务
  • 2.告诉大模型表单的格式
  • 3.告诉大模型聊天对话记录(由后续的PROMPT_EXTRACT.format(content=data["chat_text"])补充完整)
  • 4.告诉大模型输出的格式
    在这里插入图片描述

2.6 主函数启动

在正式调用大模型之前,代码还定义了两个函数

  • convert_all_json_in_text_to_dict:这个函数用于将大模型的文本输出转化为dict格式
  • check_and_complete_json_format: 这个函数用于检查提取的json下的每一项的格式是否正确,不正确则会报错(raise代码定义的JsonFormatError)
import json

class JsonFormatError(Exception):
    def __init__(self, message):
        self.message = message
        super().__init__(self.message)

def convert_all_json_in_text_to_dict(text):
    """提取LLM输出文本中的json字符串"""
    dicts, stack = [], []
    for i in range(len(text)):
        if text[i] == '{':
            stack.append(i)
        elif text[i] == '}':
            begin = stack.pop()
            if not stack:
                dicts.append(json.loads(text[begin:i+1]))
    return dicts

# 查看对话标签
def print_json_format(data):
    """格式化输出json格式"""
    print(json.dumps(data, indent=4, ensure_ascii=False))



def check_and_complete_json_format(data):
    required_keys = {
        "基本信息-姓名": str,
        "基本信息-手机号码": str,
        "基本信息-邮箱": str,
        "基本信息-地区": str,
        "基本信息-详细地址": str,
        "基本信息-性别": str,
        "基本信息-年龄": str,
        "基本信息-生日": str,
        "咨询类型": list,
        "意向产品": list,
        "购买异议点": list,
        "客户预算-预算是否充足": str,
        "客户预算-总体预算金额": str,
        "客户预算-预算明细": str,
        "竞品信息": str,
        "客户是否有意向": str,
        "客户是否有卡点": str,
        "客户购买阶段": str,
        "下一步跟进计划-参与人": list,
        "下一步跟进计划-时间点": str,
        "下一步跟进计划-具体事项": str
    }

    if not isinstance(data, list):
        raise JsonFormatError("Data is not a list")

    for item in data:
        if not isinstance(item, dict):
            raise JsonFormatError("Item is not a dictionary")
        for key, value_type in required_keys.items():
            if key not in item:
                item[key] = [] if value_type == list else ""
            if not isinstance(item[key], value_type):
                raise JsonFormatError(f"Key '{key}' is not of type {value_type.__name__}")
            if value_type == list and not all(isinstance(i, str) for i in item[key]):
                raise JsonFormatError(f"Key '{key}' does not contain all strings in the list")

    return data

完成了前面的准备工作,便可以正式开始调用大模型进行交互

from tqdm import tqdm

retry_count = 5 # 重试次数
result = []
error_data = []

for index, data in tqdm(enumerate(test_data)):
    index += 1
    is_success = False
    for i in range(retry_count):
        try:
            res = get_completions(PROMPT_EXTRACT.format(content=data["chat_text"]))
            infos = convert_all_json_in_text_to_dict(res)
            infos = check_and_complete_json_format(infos)
            result.append({
                "infos": infos,
                "index": index
            })
            is_success = True
            break
        except Exception as e:
            print("index:", index, ", error:", e)
            continue
    if not is_success:
        data["index"] = index
        error_data.append(data)

对于每一个test_data, 代码会尝试调用模型五次,由于模型的输出具有不确定性,因此每次都要对返回的文本进行解析然后检查,检查无误之后才会将数据保存到result中

2.7 生成提交文件

最后将result保存为output.json便完成了baseline的代码运行!

# 保存输出
write_json("output.json", result)

3 优化的想法

后续的优化个人认为主要就朝着两个方向进行优化

  • 优化prompt/修改temperature等参数
  • 微调大模型
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值