随着人工智能技术的飞速发展,多模态信息检索与理解成为了一个热门的研究领域。本文将介绍如何使用CLIP(Contrastive Language–Image Pre-training)模型进行图像检索,并结合一个大型语言模型(MLLM)对检索到的图像进行理解并回答问题。之前写过一篇博客是用CLIP做图文检索的,链接如下:多模态图文检索实战——基于CLIP实现图文检索系统(附源码)当时有提到这是RAG的一种前身,本篇博客就再此基础上增加MLLM对检索到的图片做理解,以此来回答用户的问题!
一、导入相关依赖包
首先,我们需要导入所需的Python库。这些库包括用于处理CLIP模型的transformers
库,用于图像处理的PIL
库,以及用于数值计算的numpy
库。此外,我们还需要导入用于加载和推理大型语言模型的lmdeploy
库。
import time
from transformers import CLIPProcessor, CLIPModel
import torch
from PIL import Image
import numpy as np
import warnings
warnings.filterwarnings("ignore")
from lmdeploy import pipeline, TurbomindEngineConfig, GenerationConfig
from lmdeploy.vl import load_image
没安装使用pip安装即可,对于版本没有特殊的要求。
二、数据集下载
本文所使用的数据集是2017 AI Challenger,数据集对给定的每一张图片有五句话的中文描述。数据集包含30万张图片,150万句中文描述。训练集:210,000 张,验证集:30,000 张,测试集 A:30,000 张,测试集 B:30,000 张。为了演示和下载方便,本博客下载了验证集,并选择前2000张图片作为匹配的图片库。 数据集的下载链接和详细的描述如下:Ai Challenger Caption图像中文描述(2017)
三、加载CLIP模型与处理器
接下来,我们加载CLIP模型和处理器。这里使用的是clip-vit-large-patch14
模型,它结合了Vision Transformer(ViT)和对比学习技术,能够在图像和文本之间建立强大的关联。
model = CLIPModel.from_pretrained("/root/model/clip-vit-large-patch14")
processor = CLIPProcessor.from_pretrained("/root/model/clip-vit-large-patch14", max_length=77)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
模型下载可以参考我之前这篇博客,
LLM/深度学习Linux常用指令与代码(进阶),在这里也给出clip-vit-large-patch14的huggingface的下载链接,需要的小伙伴可以自行下载即可,clip-vit-large-patch14。
四、定义辅助函数
为了处理文本和图像的嵌入,我们定义了几个辅助函数,详细介绍可以参考这篇博客:多模态图文检索实战——基于CLIP实现图文检索系统(附源码)
text_embedding(text)
:生成文本嵌入。get_image_embedding(image_path)
:从本地路径读取图像并生成图像嵌入。cosine_similarity(vec1, vec2)
:计算两个向量之间的余弦相似度。calulate_similarity(query, candidates, query_type="text")
:计算查询与候选图像之间的相似度,并返回最匹配的图像。getImage_embedding(candidates)
:批量生成候选图像的嵌入。
# 函数:生成文本嵌入
def text_embedding(text):
inputs = processor(text=[text], return_tensors="pt", padding=True).to(device)
with torch.no_grad():
embedding = model.get_text_features(**inputs)
return embedding.cpu().numpy()
def get_image_embedding(image_path):
try:
# 从本地路径读入图片
image = Image.open(image_path).convert("RGB")
inputs = processor(images=image, return_tensors="pt").to(device)
with torch.no_grad():
image_features = model.get_image_features(**inputs)
return image_features.cpu().numpy()
except Exception as e:
print(f"Error loading image {
e}")