不微调范式模拟官网评分
比赛说明:
#AI夏令营 #Datawhale #夏令营
主要参考datawhale夏令营活动:零基础入门大模型技术竞赛。
连接:https://datawhaler.feishu.cn/wiki/VIy8ws47ii2N79kOt9zcXnbXnuS
比赛网址:
https://challenge.xfyun.cn/topic/info?type=role-element-extraction&ch=dw24_y0SCtd
说明:
- 1,主要适用于不微调的范式。
- 2,针对每次在修改prompt,或者COT之后,想要查看性能如何时,都要提交到官网等待。但是受限于官网每个人每天只能提交3次,无法得到更多的反馈。
- 3,在这里主要从训练集train.json中,随机挑选数据,作为验证集,模仿官网的评分细则,用于验证性能指标。当验证性能满意后,再放到test.json数据进行推理,并提交官网。
步骤:
- 1,模型api配置及加载测试:
- 2,数据加载:加载训练集,数据预处理,数据分析,可设置验证集比例
- 3,prompt设计:提示工程或者COT的方式,根据数据分析设计提示;
- 4,模型推理:输出符合格式预测;
- 5,结果测试:采用与讯飞比赛官网相同的得分计算策略。
为了方便展示,参考群里某大佬画的不微调范式的概要图:
step1: 模型api配置及加载测试
# api配置
from sparkai.llm.llm import ChatSparkLLM, ChunkPrintHandler
from sparkai.core.messages import ChatMessage
import json
#星火认知大模型Spark3.5 Max的URL值,其他版本大模型URL值请前往文档(https://www.xfyun.cn/doc/spark/Web.html)查看
SPARKAI_URL = 'wss://spark-api.xf-yun.com/v3.5/chat'
#星火认知大模型调用秘钥信息,请前往讯飞开放平台控制台(https://console.xfyun.cn/services/bm35)查看
SPARKAI_APP_ID = ''
SPARKAI_API_SECRET = ''
SPARKAI_API_KEY = ''
#星火认知大模型Spark3.5 Max的domain值,其他版本大模型domain值请前往文档(https://www.xfyun.cn/doc/spark/Web.html)查看
SPARKAI_DOMAIN = 'generalv3.5'
# 模型对话测试
def get_completions(text):
messages = [ChatMessage(
role="user",
content=text
)]
spark = ChatSparkLLM(
spark_api_url=SPARKAI_URL,
spark_app_id=SPARKAI_APP_ID,
spark_api_key=SPARKAI_API_KEY,
spark_api_secret=SPARKAI_API_SECRET,
spark_llm_domain=SPARKAI_DOMAIN,
streaming=False,
)
handler = ChunkPrintHandler()
a = spark.generate([messages], callbacks=[handler])
return a.generations[0][0].text
# 测试模型配置是否正确
text = "你好,请问你是谁?"
get_completions(text)
step2: 数据加载与数据分析:
加载训练集,部分化为验证集,可设置验证集比例
def read_json(json_file_path):
"""读取json文件"""
with open(json_file_path, 'r') as f:
data = json.load(f)
return data
def write_json(json_file_path, data):
"""写入json文件"""
with open(json_file_path, 'w') as f:
json.dump(data, f, ensure_ascii=False, indent=4)
# 读取数据
train_data = read_json("dataset/train.json")
print('done!')
# 查看数据格式
print(train_data[1]['chat_text'])
# 简单的数据清洗:将对话的无关信息删除。不要让“[图片]”这种信息干扰;
# 如:[链接],[图片],[玫瑰],以及:????上线功能,H5红包,【收集表】2023年度满意度评价等与内容无关的字段去除掉。
import re
def clean_chat_text(chat_text):
# 定义正则表达式用于匹配链接、图片、特殊表情和无关字段
patterns = [
r"【收集表】 2023年度服务满意度评价",
r"https?://\S+",
r"\{[\w\W]*?\}",
r'\[.*?\]'
]
# 移除匹配到的内容
for pattern in patterns:
chat_text = re.sub(pattern, '', chat_text)
# 移除多余的空格和换行符
chat_text = re.sub(r'\n+', '\n', chat_text).strip()
return chat_text
# 遍历每个样本,清洗chat_text字段
for sample in train_data:
if "chat_text" in sample:
sample["chat_text"] = clean_chat_text(sample["chat_text"])
# 验证集划分
import json
import random
def split_data(data, validation_size):
# # 随机打乱数据
# random.shuffle(data)
# 划分数据
validation_data = data[:validation_size]
train_data = data[validation_size:]
return train_data, validation_data
validation_size