#AI夏令营 #Datawhale #夏令营
一、数据集制作
首先导入库并且配置基础信息
from sparkai.llm.llm import ChatSparkLLM, ChunkPrintHandler
from sparkai.core.messages import ChatMessage
import numpy as np
from tqdm import tqdm
def chatbot(prompt):
#星火认知大模型Spark3.5 Max的URL值,其他版本大模型URL值请前往文档(https://www.xfyun.cn/doc/spark/Web.html)查看
SPARKAI_URL = 'wss://spark-api.xf-yun.com/v3.5/chat'
#星火认知大模型调用秘钥信息,请前往讯飞开放平台控制台(https://console.xfyun.cn/services/bm35)查看
SPARKAI_APP_ID = ''
SPARKAI_API_SECRET = ''
SPARKAI_API_KEY = ''
#星火认知大模型Spark3.5 Max的domain值,其他版本大模型domain值请前往文档(https://www.xfyun.cn/doc/spark/Web.html)查看
SPARKAI_DOMAIN = 'generalv3.5'
spark = ChatSparkLLM(
spark_api_url=SPARKAI_URL,
spark_app_id=SPARKAI_APP_ID,
spark_api_key=SPARKAI_API_KEY,
spark_api_secret=SPARKAI_API_SECRET,
spark_llm_domain=SPARKAI_DOMAIN,
streaming=False,
)
messages = [ChatMessage(
role="user",
content=prompt
)]
handler = ChunkPrintHandler()
a = spark.generate([messages], callbacks=[handler])
return a.generations[0][0].message.content
之后设计提示词,这里的提示词需要更精炼一些,过长的提示词会导致模型的训练时间过长,耗费也会很大,但这并不一定会有更好的训练效果。
样例prompt
你是一个数据分析大师,你需要从群聊对话中进行分析,里面对话的角色中大部分是客服角色,你需要从中区分出有需求的客户,并得到以下四类数据。
****群聊对话****
{content}
****分析数据****
客户基本信息:需要从中区分出客户角色,并得到客户基本信息,其中包括姓名、手机号码、邮箱、地区、详细地址、性别、年龄和生日
客户意向与预算信息: 客户意向与预算信息包括咨询类型、意向产品、购买异议点、预算是否充足、总体预算金额以及预算明细
客户购买准备情况:户购买准备情况包括竞品信息、客户是否有意向、客户是否有卡点以及客户购买阶段
跟进计划信息: 跟进计划信息包括参与人、时间点和具体事项,这些信息用于指导销售团队在未来的跟进工作中与客户互动
****注意****
1.只输出客户基本信息、客户意向与预算信息、客户购买准备情况、跟进计划信息对应的信息,不要输出无关内容
2.不要输出分析内容
3.输出内容格式为md格式
'''
可以看到他并没有给出后续需要输出的具体格式,因为这个模型是用来优化原有数据集的。
具体代码在飞桨AI Studio星河社区-人工智能学习与实训社区 (baidu.com)
二、平台微调
讯飞平台:https://training.xfyun.cn/model/add
可以选择Spark Lite或者Spark Pro,将获得的训练集和测试集上传,训练得到微调模型,将得到的四个信息,resourceId、APPID、APIKey、APISecret填入一下代码
import SparkApi
import json
#以下密钥信息从控制台获取
appid = "" #填写控制台中获取的 APPID 信息
api_secret = "" #填写控制台中获取的 APISecret 信息
api_key ="" #填写控制台中获取的 APIKey 信息
#调用Spark Lite微调大模型时,设置为“patch”
domain = "patch"
#调用Spark pro微调大模型时,设置为“patchv3”
domain = "patchv3"
#云端环境的服务地址
Spark_url = "wss://spark-api-n.xf-yun.com/v1.1/chat" # 微调v1.5环境的地址,对应Spark Lite
#Spark_url = "wss://spark-api-n.xf-yun.com/v3.1/chat" # 微调v3.0环境的地址,对应Spark Pro
def gen_params(appid, domain,question):
"""
通过appid和用户的提问来生成请参数
"""
data = {
"header": {
"app_id": appid,
"uid": "1234",
"patch_id": [""] #调用微调大模型时必传, 否则不传。对应resourceId
},
"parameter": {
"chat": {
"domain": domain,
"temperature": 0.1,
"max_tokens": 4096
}
},
"payload": {
"message": {
"text": question
}
}
}
return data
得到训练结果文件,得分为26.15,比之前只修改提示词时大约高了5、6分。
三、感悟
如果需要进一步提升分数,我认为有以下努力方向
- 进一步完善提示词
- 增添好的数据集训练
由于不清楚评分标准,感觉可以提供更明确的指向性提示词让模型训练指向更明确或者是用更宽泛的提示词让模型有更好的泛化能力,但是由于一天只能提交三次以及模型训练时间过久,我觉得并没有那么好调试。
后续我认为需要加强自己的代码能力,这次的模型调优是在星火大模型傻瓜式操作以及班助给出的样例代码下完成的。大模型真的很有用。