使用多模态LLM进行半结构化图像检索

在本文中,我们将展示如何在图像上执行半结构化检索。

我们可以通过Gemini Pro Vision从图像中推断出结构化输出,然后将这些结构化输出索引到向量数据库中。通过自动检索功能,我们可以对这些数据进行结构化和语义查询。

环境设置

首先,安装所需的Python包:

%pip install llama-index-multi-modal-llms-gemini
%pip install llama-index-vector-stores-qdrant
%pip install llama-index-embeddings-gemini
%pip install llama-index-llms-gemini
!pip install llama-index 'google-generativeai>=0.3.0' matplotlib qdrant_client

获取Google API密钥

import os

GOOGLE_API_KEY = "<你的GOOGLE_API密钥>"
os.environ["GOOGLE_API_KEY"] = GOOGLE_API_KEY

下载图像

在这里我们下载了来自Kaggle的完整SROIE v2数据集,该数据集包含扫描的收据图像。

获取图像文件

from pathlib import Path
import random
from typing import Optional

def get_image_files(dir_path, sample: Optional[int] = 10, shuffle: bool = False):
    dir_path = Path(dir_path)
    image_paths = []
    for image_path in dir_path.glob("*.jpg"):
        image_paths.append(image_path)

    if shuffle:
        random.shuffle(image_paths)
    
    return image_paths[:sample] if sample else image_paths

image_files = get_image_files("SROIE2019/test/img", sample=100)

使用Gemini提取结构化输出

定义ReceiptInfo类

from pydantic import BaseModel, Field

class ReceiptInfo(BaseModel):
    company: str = Field(..., description="Company name")
    date: str = Field(..., description="Date field in DD/MM/YYYY format")
    address: str = Field(..., description="Address")
    total: float = Field(..., description="total amount")
    currency: str = Field(..., description="Currency of the country (in abbreviations)")
    summary: str = Field(..., description="Extracted text summary of the receipt")

定义pydantic_gemini函数

from llama_index.multi_modal_llms.gemini import GeminiMultiModal
from llama_index.core.program import MultiModalLLMCompletionProgram
from llama_index.core.output_parsers import PydanticOutputParser

prompt_template_str = "Can you summarize the image and return a response with the following JSON format:"

async def pydantic_gemini(output_class, image_documents, prompt_template_str):
    gemini_llm = GeminiMultiModal(api_key=GOOGLE_API_KEY, model_name="models/gemini-pro-vision")

    llm_program = MultiModalLLMCompletionProgram.from_defaults(
        output_parser=PydanticOutputParser(output_class),
        image_documents=image_documents,
        prompt_template_str=prompt_template_str,
        multi_modal_llm=gemini_llm,
        verbose=True
    )

    response = await llm_program.acall()
    return response

运行图像文件并提取数据

from llama_index.core import SimpleDirectoryReader
from llama_index.core.async_utils import run_jobs

async def aprocess_image_file(image_file):
    print(f"Image file: {image_file}")
    img_docs = SimpleDirectoryReader(input_files=[image_file]).load_data()
    output = await pydantic_gemini(ReceiptInfo, img_docs, prompt_template_str)
    return output

async def aprocess_image_files(image_files):
    tasks = [aprocess_image_file(image_file) for image_file in image_files]
    outputs = await run_jobs(tasks, show_progress=True, workers=5)
    return outputs

outputs = await aprocess_image_files(image_files)

转换结构化表示为TextNode对象

from llama_index.core.schema import TextNode
from typing import List

def get_nodes_from_objs(objs: List[ReceiptInfo], image_files: List[str]) -> List[TextNode]:
    nodes = []
    for image_file, obj in zip(image_files, objs):
        node = TextNode(
            text=obj.summary,
            metadata={
                "company": obj.company,
                "date": obj.date,
                "address": obj.address,
                "total": obj.total,
                "currency": obj.currency,
                "image_file": str(image_file),
            },
            excluded_embed_metadata_keys=["image_file"],
            excluded_llm_metadata_keys=["image_file"]
        )
        nodes.append(node)
    return nodes

nodes = get_nodes_from_objs(outputs, image_files)
print(nodes[0].get_content(metadata_mode="all"))

将这些节点索引到向量存储中

import qdrant_client
from llama_index.vector_stores.qdrant import QdrantVectorStore
from llama_index.core import StorageContext, VectorStoreIndex
from llama_index.embeddings.gemini import GeminiEmbedding
from llama_index.llms.gemini import Gemini
from llama_index.core import Settings

# 创建本地Qdrant向量存储
client = qdrant_client.QdrantClient(path="qdrant_gemini")
vector_store = QdrantVectorStore(client=client, collection_name="collection")

# 全局设置
Settings.embed_model = GeminiEmbedding(model_name="models/embedding-001", api_key=GOOGLE_API_KEY)
Settings.llm = Gemini(api_key=GOOGLE_API_KEY)

storage_context = StorageContext.from_defaults(vector_store=vector_store)

index = VectorStoreIndex(nodes=nodes, storage_context=storage_context)

定义自动检索器

from llama_index.core.vector_stores import MetadataInfo, VectorStoreInfo
from llama_index.core.retrievers import VectorIndexAutoRetriever

vector_store_info = VectorStoreInfo(
    content_info="Receipts",
    metadata_info=[
        MetadataInfo(name="company", description="The name of the store", type="string"),
        MetadataInfo(name="address", description="The address of the store", type="string"),
        MetadataInfo(name="date", description="The date of the purchase (in DD/MM/YYYY format)", type="string"),
        MetadataInfo(name="total", description="The final amount", type="float"),
        MetadataInfo(name="currency", description="The currency of the country the purchase was made (abbreviation)", type="string")
    ]
)

retriever = VectorIndexAutoRetriever(
    index=index,
    vector_store_info=vector_store_info,
    similarity_top_k=2,
    empty_query_top_k=10,
    verbose=True
)

运行一些查询

from IPython.display import Image

def display_response(nodes: List[TextNode]):
    for node in nodes:
        print(node.get_content(metadata_mode="all"))
        display(Image(filename=node.metadata["image_file"], width=200))

nodes = retriever.retrieve("Tell me about some restaurant orders of noodles with total < 25")
display_response(nodes)

nodes = retriever.retrieve("Tell me about some grocery purchases")
display_response(nodes)

可能遇到的错误

  1. API Key错误:确保你的GOOGLE API密钥是有效的,并已正确设置在环境变量中。
  2. 数据文件未找到:确保你下载的数据集路径正确无误。
  3. 包安装失败:确保你有稳定的网络连接以下载所需的包。

如果你觉得这篇文章对你有帮助,请点赞,关注我的博客,谢谢!

参考资料:

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值