在这篇文章中,我们将展示如何使用LlamaIndex、GPT4-V和CLIP构建一个图像到图像的检索系统。我们将从维基百科页面下载文本和图像,建立多模态索引和向量存储,并使用多模态检索器根据图像查询检索相关图像。最后,我们将使用GPT4-V来推理输入图像和检索到的图像之间的相关性。
环境配置
首先,确保你已经安装了必要的库:
%pip install llama-index-multi-modal-llms-openai
%pip install llama-index-vector-stores-qdrant
%pip install llama_index ftfy regex tqdm
%pip install git+https://github.com/openai/CLIP.git
%pip install torch torchvision
%pip install matplotlib scikit-image
%pip install -U qdrant_client
下载维基百科的图像和文本
import os
import wikipedia
import urllib.request
from pathlib import Path
image_path = Path("mixed_wiki")
image_uuid = 0
image_metadata_dict = {}
MAX_IMAGES_PER_WIKI = 30
wiki_titles = [
"Vincent van Gogh",
"San Francisco",
"Batman",
"iPhone",
"Tesla Model S",
"BTS band",
]
if not image_path.exists():
Path.mkdir(image_path)
for title in wiki_titles:
images_per_wiki = 0
print(title)
try:
page_py = wikipedia.page(title)
list_img_urls = page_py.images
for url in list_img_urls:
if url.endswith(".jpg") or url.endswith(".png"):
image_uuid += 1
image_file_name = title + "_" + url.split("/")[-1]
image_metadata_dict[image_uuid] = {
"filename": image_file_name,
"img_path": "./" + str(image_path / f"{image_uuid}.jpg"),
}
urllib.request.urlretrieve(
url, image_path / f"{image_uuid}.jpg"
)
images_per_wiki += 1
if images_per_wiki > MAX_IMAGES_PER_WIKI:
break
except:
print(str(Exception("No images found for Wikipedia page: ")) + title)
continue
绘制维基百科的图像
from PIL import Image
import matplotlib.pyplot as plt
import os
image_paths = []
for img_path in os.listdir("./mixed_wiki"):
image_paths.append(str(os.path.join("./mixed_wiki", img_path)))
def plot_images(image_paths):
images_shown = 0
plt.figure(figsize=(16, 9))
for img_path in image_paths:
if os.path.isfile(img_path):
image = Image.open(img_path)
plt.subplot(3, 3, images_shown + 1)
plt.imshow(image)
plt.xticks([])
plt.yticks([])
images_shown += 1
if images_shown >= 9:
break
plot_images(image_paths)
构建多模态索引和向量存储
from llama_index.core.indices import MultiModalVectorStoreIndex
from llama_index.vector_stores.qdrant import QdrantVectorStore
from llama_index.core import SimpleDirectoryReader, StorageContext
import qdrant_client
client = qdrant_client.QdrantClient(path="qdrant_img_db")
text_store = QdrantVectorStore(client=client, collection_name="text_collection")
image_store = QdrantVectorStore(client=client, collection_name="image_collection")
storage_context = StorageContext.from_defaults(vector_store=text_store, image_store=image_store)
documents = SimpleDirectoryReader("./mixed_wiki/").load_data()
index = MultiModalVectorStoreIndex.from_documents(documents, storage_context=storage_context)
检索图像
input_image = "./mixed_wiki/2.jpg"
plot_images([input_image])
retriever_engine = index.as_retriever(image_similarity_top_k=4)
retrieval_results = retriever_engine.image_to_image_retrieve(input_image)
retrieved_images = []
for res in retrieval_results:
retrieved_images.append(res.node.metadata["file_path"])
plot_images(retrieved_images[1:])
使用GPT4-V推理图像相关性
from llama_index.multi_modal_llms.openai import OpenAIMultiModal
from llama_index.core.schema import ImageDocument
image_documents = [ImageDocument(image_path=input_image)]
for res_img in retrieved_images[1:]:
image_documents.append(ImageDocument(image_path=res_img))
openai_mm_llm = OpenAIMultiModal(
model="gpt-4-vision-preview", api_key="YOUR_API_KEY", max_new_tokens=1500
)
response = openai_mm_llm.complete(
prompt="Given the first image as the base image, what the other images correspond to?",
image_documents=image_documents,
)
print(response)
可能遇到的错误
- 网络错误: 下载维基百科图像时可能会遇到网络连接问题,请确保网络畅通。
- 图像文件损坏: 下载的图像文件可能损坏,导致无法打开,可以添加异常处理机制。
- API密钥错误: 使用GPT4-V时,API密钥可能无效或过期,请确保使用有效的API密钥。
如果你觉得这篇文章对你有帮助,请点赞,关注我的博客,谢谢!
参考资料: