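Introduction
This article shows how to use LLaVa (Large Language and Vision Assistant) together with LlamaIndex for multimodal retrieval-augmented image captioning (RAG). We walk through how to access a LLaVa model through LlamaIndex and how to combine its output with a knowledge base to understand and describe images. The article covers data extraction, building a retriever, image understanding, and enriching the image description with the knowledge base.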
Environment Setup
First, install the packages needed for data processing and multimodal analysis.
%pip install llama-index-vector-stores-qdrant
%pip install llama-index-readers-file
%pip install llama-index-multi-modal-llms-replicate
%pip install unstructured replicate
%pip install llama_index ftfy regex tqdm
%pip install git+https://github.com/openai/CLIP.git
%pip install torch torchvision
%pip install matplotlib scikit-image
%pip install -U qdrant_client
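Before moving on, it can help to confirm that the installed packages are actually importable. The sketch below is one possible check and is not part of the original workflow: `missing_modules` is a hypothetical helper, and the module names in the last line are assumptions about the import names the packages above expose.

```python
import importlib.util


def missing_modules(names):
    """Return the module names from `names` that cannot be found on the import path."""
    missing = []
    for name in names:
        try:
            found = importlib.util.find_spec(name) is not None
        except (ImportError, ValueError):
            # Raised for broken parent packages or modules without a spec
            found = False
        if not found:
            missing.append(name)
    return missing


# Assumed import names; adjust the list to match your environment
print(missing_modules(["llama_index", "qdrant_client", "torch", "matplotlib"]))
```

An empty list means every module was found; any name that is printed points at a package that still needs to be installed.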
Data Extraction and Processing
We can use Unstructured to parse the table and non-table elements of the 10-K filing:
import os
import pickle
from pathlib import Path

from llama_index.readers.file import FlatReader
from llama_index.core.node_parser import UnstructuredElementNodeParser

# Load the raw 10-K HTML file
reader = FlatReader()
docs_2021 = reader.load_data(Path("tesla_2021_10k.htm"))
node_parser = UnstructuredElementNodeParser()

# Cache the parsed nodes on disk so later runs can skip the parsing step
if not os.path.exists("2021_nodes.pkl"):
    raw_nodes_2021 = node_parser.get_nodes_from_documents(docs_2021)
    with open("2021_nodes.pkl", "wb") as f:
        pickle.dump(raw_nodes_2021, f)
else:
    with open("2021_nodes.pkl", "rb") as f:
        raw_nodes_2021 = pickle.load(f)

# Separate the parsed result into nodes and objects for indexing
nodes_2021, objects_2021 = node_parser.get_nodes_and_objects(raw_nodes_2021)
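The check-then-pickle pattern used above can be factored into a small reusable function. This is a minimal sketch, assuming the cached object is picklable; `load_or_compute` is a hypothetical helper name, not a LlamaIndex API.

```python
import os
import pickle


def load_or_compute(cache_path, compute):
    """Return the pickled object at cache_path, or call compute() and cache its result."""
    if os.path.exists(cache_path):
        with open(cache_path, "rb") as f:
            return pickle.load(f)
    result = compute()
    with open(cache_path, "wb") as f:
        pickle.dump(result, f)
    return result


# Possible use with the names from the snippet above:
# raw_nodes_2021 = load_or_compute(
#     "2021_nodes.pkl",
#     lambda: node_parser.get_nodes_from_documents(docs_2021),
# )
```

Passing the expensive step as a callable means it only runs when no cache file exists. Delete the pickle file whenever the source document changes, since this sketch does not detect stale caches.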
Building a Composable Retriever
We build a composable retriever on top of the extracted tables and their summaries:
from llama_index.core import VectorStoreIndex
vector_index = VectorStoreIndex(nodes=nodes_2021, objects=objects_2021)
query_engine = vector_index.as_query_engine(similarity_top_k=5, verbose=True)
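To make `similarity_top_k=5` more concrete: the query is embedded, the indexed nodes are scored by similarity to it, and only the best-scoring ones are handed to the LLM. The following is a conceptual sketch in plain Python, not LlamaIndex's implementation. The helpers `cosine_similarity` and `top_k`, the toy two-dimensional vectors, and the node names are all hypothetical, and cosine similarity is an assumption made for illustration.

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length vectors (0.0 if either is all zeros)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0 or norm_b == 0:
        return 0.0
    return dot / (norm_a * norm_b)


def top_k(query_vec, node_vecs, k):
    """Return the ids of the k nodes most similar to query_vec, best first."""
    scored = [(cosine_similarity(query_vec, vec), node_id) for node_id, vec in node_vecs.items()]
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return [node_id for _, node_id in scored[:k]]


# Toy embeddings standing in for real node embeddings
toy_nodes = {
    "revenue_table": [1.0, 0.0],
    "factory_list": [0.0, 1.0],
    "risk_factors": [0.7, 0.7],
}
print(top_k([0.1, 0.9], toy_nodes, k=2))  # → ['factory_list', 'risk_factors']
```

A larger `k` gives the LLM more context at the cost of longer prompts; a smaller `k` is cheaper but may miss the relevant table.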
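Image Understanding and Retrieval-Augmented Description Generation
We use the LLaVa model to understand the image, then pass its answer to the knowledge base to generate an augmented description: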
from llama_index.multi_modal_llms.replicate import ReplicateMultiModal
from llama_index.core.schema import ImageDocument
from llama_index.multi_modal_llms.replicate.base import REPLICATE_MULTI_MODAL_LLM_MODELS

# Local path of the image to describe (the same file the complete example uses)
imageUrl = "./texas.jpg"

llava_multi_modal_llm = ReplicateMultiModal(
    model=REPLICATE_MULTI_MODAL_LLM_MODELS["llava-13b"],
    max_new_tokens=200,
    temperature=0.1,
)

prompt = "which Tesla factory is shown in the image? Please answer just the name of the factory."

# Ask LLaVa to name the factory shown in the image
llava_response = llava_multi_modal_llm.complete(
    prompt=prompt,
    image_documents=[ImageDocument(image_path=imageUrl)],
)
rag_response = query_engine.query(llava_response.text)
print(rag_response)
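The two steps above (caption the image, then use the caption as the retrieval query) can be wrapped in one function. This is a minimal sketch, assuming only that the multimodal LLM exposes `complete(prompt=..., image_documents=...)` returning an object with a `.text` attribute and that the query engine exposes `query(...)`, as in the snippet above. `describe_with_rag` is a hypothetical helper, and stripping whitespace from the answer is an added design choice that the original code does not make.

```python
def describe_with_rag(mm_llm, query_engine, prompt, image_documents):
    """Ask the multimodal LLM about the image, then use its answer as the retrieval query."""
    caption = mm_llm.complete(prompt=prompt, image_documents=image_documents).text.strip()
    if not caption:
        # An empty answer would make the retrieval query meaningless
        raise ValueError("The multimodal LLM returned an empty answer.")
    return caption, query_engine.query(caption)


# Possible use with the names from the snippet above:
# caption, rag_response = describe_with_rag(
#     llava_multi_modal_llm,
#     query_engine,
#     prompt,
#     [ImageDocument(image_path=imageUrl)],
# )
```

Returning the caption alongside the retrieved answer makes it easier to see whether a poor final result comes from the image understanding step or from retrieval.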
Complete Code Example
The complete example below shows the whole workflow for multimodal retrieval-augmented image captioning with LLaVa and LlamaIndex:
import os
import pickle
from pathlib import Path
import openai
import requests
from PIL import Image
import matplotlib.pyplot as plt
from llama_index.readers.file import FlatReader
from llama_index.core.node_parser import UnstructuredElementNodeParser
from llama_index.core import VectorStoreIndex
from llama_index.multi_modal_llms.replicate import ReplicateMultiModal
from llama_index.core.schema import ImageDocument
from llama_index.multi_modal_llms.replicate.base import REPLICATE_MULTI_MODAL_LLM_MODELS
# Environment variables
REPLICATE_API_TOKEN = "..." # Your Replicate API token here
os.environ["REPLICATE_API_TOKEN"] = REPLICATE_API_TOKEN
OPENAI_API_KEY = "..." # Your OpenAI API key here
os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY
# Data loading and processing
reader = FlatReader()
docs_2021 = reader.load_data(Path("tesla_2021_10k.htm"))
node_parser = UnstructuredElementNodeParser()
raw_nodes_2021 = node_parser.get_nodes_from_documents(docs_2021)
nodes_2021, objects_2021 = node_parser.get_nodes_and_objects(raw_nodes_2021)
vector_index = VectorStoreIndex(nodes=nodes_2021, objects=objects_2021)
query_engine = vector_index.as_query_engine(similarity_top_k=5, verbose=True)
# Load and display the image
imageUrl = "./texas.jpg"
image = Image.open(imageUrl).convert("RGB")
plt.figure(figsize=(16, 5))
plt.imshow(image)
plt.show()
# LLaVa model and image understanding
llava_multi_modal_llm = ReplicateMultiModal(
    model=REPLICATE_MULTI_MODAL_LLM_MODELS["llava-13b"],
    max_new_tokens=200,
    temperature=0.1,
)
prompt = "which Tesla factory is shown in the image? Please answer just the name of the factory."
llava_response = llava_multi_modal_llm.complete(
    prompt=prompt,
    image_documents=[ImageDocument(image_path=imageUrl)],
)
print(llava_response.text)
# Retrieval-augmented description generation
rag_response = query_engine.query(llava_response.text)
print(rag_response)
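Potential Errors and Solutions
- API token error:
  - Error message: Authentication Error
  - Solution: make sure REPLICATE_API_TOKEN and OPENAI_API_KEY are set correctly.
- Data file loading error:
  - Error message: FileNotFoundError
  - Solution: check that the data file path is correct and that the file exists.
- Image loading error:
  - Error message: OSError: cannot identify image file
  - Solution: make sure the image file path is correct and that the file is in a supported format (such as JPG or PNG).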
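Some of these problems can be caught before any API call is made. The sketch below checks that the environment variables are set and that the input files exist; `preflight` is a hypothetical helper, and treating the placeholder value "..." as unset is an assumption based on the placeholders used in the complete example. It does not verify that a token is valid or that an image file can be decoded.

```python
import os
from pathlib import Path


def preflight(required_env, required_files, env=None):
    """Return a list of human-readable problems; an empty list means all checks passed."""
    env = os.environ if env is None else env
    problems = []
    for name in required_env:
        value = env.get(name, "")
        # Treat a missing value or the "..." placeholder as not configured
        if not value or value == "...":
            problems.append(f"environment variable {name} is not set")
    for path in required_files:
        if not Path(path).is_file():
            problems.append(f"file not found: {path}")
    return problems


for problem in preflight(
    ["REPLICATE_API_TOKEN", "OPENAI_API_KEY"],
    ["tesla_2021_10k.htm", "./texas.jpg"],
):
    print(problem)
```

Running this at the top of the script turns a late Authentication Error or FileNotFoundError into an early, readable message.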