使用LLamaIndex微调跨编码器的指南

使用LLamaIndex微调跨编码器的指南

本文将介绍如何使用LLamaIndex和Gradient平台对跨编码器进行微调。我们将以QASPER数据集为例,详细解释从数据加载、数据生成、微调、评估等一系列步骤。

前提条件

如果您在colab上打开此Notebook,首先需要安装LlamaIndex和相关依赖。

# 安装所需依赖
%pip install llama-index-finetuning-cross-encoders
%pip install llama-index-llms-openai
!pip install llama-index
!pip install datasets --quiet
!pip install sentence-transformers --quiet
!pip install openai --quiet

加载数据集

我们将从HuggingFace Hub下载QASPER数据集。

from datasets import load_dataset
import random

# 从HuggingFace 下载QASPER数据集
dataset = load_dataset("allenai/qasper")

# 分割数据集为训练集、验证集和测试集
train_dataset = dataset["train"]
validation_dataset = dataset["validation"]
test_dataset = dataset["test"]

random.seed(42)  # 设置随机种子以确保可复现性

# 随机抽样800行训练数据
train_sampled_indices = random.sample(range(len(train_dataset)), 800)
train_samples = [train_dataset[i] for i in train_sampled_indices]

# 随机抽样100行测试数据
test_sampled_indices = random.sample(range(len(test_dataset)), 80)
test_samples = [test_dataset[i] for i in test_sampled_indices]

生成微调数据集

接下来,我们将从训练数据中提取需要的文本和问题,以生成适用于跨编码器微调的数据集。

from typing import List

def get_full_text(sample: dict) -> str:
    title = sample["title"]
    abstract = sample["abstract"]
    sections_list = sample["full_text"]["section_name"]
    paragraph_list = sample["full_text"]["paragraphs"]
    combined_sections_with_paras = ""
    if len(sections_list) == len(paragraph_list):
        combined_sections_with_paras += title + "\t"
        combined_sections_with_paras += abstract + "\t"
        for index in range(len(sections_list)):
            combined_sections_with_paras += str(sections_list[index]) + "\t"
            combined_sections_with_paras += "".join(paragraph_list[index])
        return combined_sections_with_paras
    else:
        print("Not the same number of sections as paragraphs list")

def get_questions(sample: dict) -> List[str]:
    questions_list = sample["qas"]["question"]
    return questions_list

doc_qa_dict_list = []

for train_sample in train_samples:
    full_text = get_full_text(train_sample)
    questions_list = get_questions(train_sample)
    local_dict = {"paper": full_text, "questions": questions_list}
    doc_qa_dict_list.append(local_dict)

df_train = pd.DataFrame(doc_qa_dict_list)
df_train.to_csv("train.csv")

微调跨编码器

我们将使用LlamaIndex来微调跨编码器模型。

from llama_index.finetuning.cross_encoders import CrossEncoderFinetuneEngine

# 初始化跨编码器微调引擎
finetuning_engine = CrossEncoderFinetuneEngine(dataset=final_finetuning_data_list, epochs=2, batch_size=8)

# 开始微调
finetuning_engine.finetune()

# 将微调后的模型推送到HuggingFace Hub
finetuning_engine.push_to_hub(repo_id="你的用户名/微调模型名称")

评估

我们使用不同的度量标准对微调后的模型进行评估。

import pandas as pd
import ast

# 加载评估数据集
df_test = pd.read_csv("test.csv", index_col=0)
df_test["questions"] = df_test["questions"].apply(ast.literal_eval)
df_test["answers"] = df_test["answers"].apply(ast.literal_eval)

# 评估
from llama_index.core import VectorStoreIndex, Document

without_reranker_hits = 0
finetuned_reranker_hits = 0
total_number_of_context = 0

# 初始化评估引擎
for index, row in df_test.iterrows():
    documents = [Document(text=row["paper"])]
    query_list = row["questions"]
    context_list = row["context"]

    assert len(query_list) == len(context_list)
    vector_index = VectorStoreIndex.from_documents(documents)

    # 构建检索器
    retriever_with_finetuned_reranker = vector_index.as_query_engine(
        similarity_top_k=8, response_mode="no_text", node_postprocessors=[rerank_finetuned]
    )

    for index in range(len(query_list)):
        query = query_list[index]
        context = context_list[index]
        total_number_of_context += 1

        response_with_finetuned_reranker = retriever_with_finetuned_reranker.query(query)
        with_finetuned_reranker_nodes = response_with_finetuned_reranker.source_nodes

        for node in with_finetuned_reranker_nodes:
            if context in node.node.text or node.node.text in context:
                finetuned_reranker_hits += 1

results_dict = {
    "Metric": "Hits",
    "Finetuned_cross_encoder": finetuned_reranker_hits,
    "Total Relevant Context": total_number_of_context,
}
df_reranker_eval_results = pd.DataFrame(results_dict)
display(df_reranker_eval_results)

可能遇到的错误

  1. 数据加载错误:确保数据集路径和文件名正确,且数据格式符合预期。
  2. API连接错误:使用中专API地址替换其他API地址,确保能正常访问。
  3. 模型微调错误:检查模型配置和参数设置是否正确,确保训练数据格式符合模型要求。

如果你觉得这篇文章对你有帮助,请点赞,关注我的博客,谢谢!

参考资料:

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值