通过Faiss和DINOv2进行场景识别

目标:通过Faiss和DINOv2进行场景识别,确保输入的照片和注册的图片,保持内容一致。

MetaAI 通过开源 DINOv2,在计算机视觉领域取得了一个显着的里程碑,这是一个在包含1.42 亿张图像的令人印象深刻的数据集上训练的模型。产生适用于图像级视觉任务(图像分类、实例检索、视频理解)以及像素级视觉任务(深度估计、语义分割)的通用特征。

Faiss是一个用于高效相似性搜索和密集向量聚类的库。它包含的算法可以搜索任意大小的向量集,甚至可能无法容纳在 RAM 中的向量集。

#!usr/bin/env python
# -*- coding:utf-8 -*-

# pip install transformers faiss-gpu torch Pillow
import torch
import os
import concurrent.futures
from transformers import AutoImageProcessor, AutoModel
from PIL import Image
import faiss
from tqdm import tqdm
import numpy as np
from utils_img import *

os.environ["CUDA_VISIBLE_DEVICES"] = "1"


class SceneRecognition:

    def __init__(self, dimension=384, threshold=0.8, batch_size=128):
        """
        初始化 SceneRecognition 类
        Parameters:
        dimension (int): 向量的维度,默认为 384
        """
        # 加载模型和处理器
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        self.processor = AutoImageProcessor.from_pretrained(
            'models/dinov2-small')
        self.model = AutoModel.from_pretrained('models/dinov2-small').to(
            self.device)
        # 保存特征向量的维度
        self.dimension = dimension
        # 搭配 faiss.normalize_L2 创建 Faiss 的余弦相似度索引
        self.index = faiss.IndexFlatIP(self.dimension)
        # 保存图片向量对应的 id
        self.ids = []
        # 保存特征向量数据库
        self.db_path = "vector.index"
        # 图片路径
        self.images_path = []
        # 相似度阈值
        self.threshold = threshold
        self.batch_size = batch_size
        # 初始化
        self.init()

    def read_image_open(self, image_path):
        """
        打开图像并返回图像和图像路径
        Parameters:
        image_path (str): 图像的路径
        Returns:
        Image: 打开的图像
        str: 图像的路径
        """
        # 默认以 RGB 模式打开
        image = Image.open(image_path)
        return image, image_path

    def read_images_from_folder(self, file_path):
        """
        从文件夹中读取图像
        Parameters:
        file_path (str or list): 文件夹的路径或文件路径列表
        Returns:
        list: 图像列表
        list: 图像路径列表
        """
        image_list = []
        image_path_list = []
        try:
            if type(file_path) is not list:
                task_list = get_files(file_path)
            else:
                task_list = file_path
            # 使用线程池执行器
            with concurrent.futures.ThreadPoolExecutor(
                    max_workers=6) as executor:
                res = executor.map(self.read_image_open, task_list)
            image_list, image_path_list = list(zip(*res))
            return image_list, image_path_list
        except Exception as e:
            return None, None

    def download_model(self):
        """
        从 huggingface_hub 下载模型
        """
        from huggingface_hub import snapshot_download
        snapshot_download(
            repo_id="facebook/dinov2-small",  # 模型 ID
            local_dir="./models/dinov2-small")  # 指定本地地址保存模型

    def init(self):
        """
        初始化模型
        """
        batch_images = np.zeros((self.batch_size, 1920, 1080, 3),
                                dtype=np.int8)
        batch_inputs = self.processor(images=batch_images,
                                      return_tensors="pt").to(self.device)
        batch_outputs = self.model(**batch_inputs)

    def add_vector_to_index(self, embedding, index):
        """
        将向量添加到索引中
        Parameters:
        embedding (torch.Tensor): 特征向量
        index (faiss.Index): Faiss 索引
        """
        vector = embedding.detach().cpu().numpy()
        vector = np.float32(vector)
        faiss.normalize_L2(vector)
        index.add(vector)

    def create_database_from_images(self, file_path):
        """
        从图像创建数据库
        Parameters:
        file_path (str or list): 图像文件夹的路径或图像路径列表
        """
        image_list, self.images_path = self.read_images_from_folder(file_path)
        self.extract_features_form_images(image_list)

    def create_database_from_batch_images(self, file_path):
        """
        从批量图像创建数据库
        Parameters:
        file_path (str or list): 图像文件夹的路径或图像路径列表
        """
        image_list, self.images_path = self.read_images_from_folder(file_path)
        self.extract_features_form_batch_images(image_list)

    def extract_features_form_images(self, images):
        """
        从图像列表提取特征
        Parameters:
        images (list): 图像列表
        """
        for image_id, img in enumerate(images):
            with torch.no_grad():
                inputs = self.processor(images=img,
                                        return_tensors="pt").to(self.device)
                outputs = self.model(**inputs)
                features = outputs.last_hidden_state.mean(dim=1)
                self.add_vector_to_index(features, self.index)
                # 记录特征向量的 id
                self.ids.append(image_id)
        # faiss.write_index(self.index, self.db_path)

    def extract_features_form_batch_images(self, image_list):
        """
        从批量图像中提取特征
        Parameters:
        image_list (list): 图像列表
        """
        img_num = len(image_list)
        for beg_img_no in tqdm(range(0, img_num, self.batch_size),
                               desc="Extracting features"):
            end_img_no = min(img_num, beg_img_no + self.batch_size)
            batch_image = image_list[beg_img_no:end_img_no]
            with torch.no_grad():
                batch_inputs = self.processor(images=batch_image,
                                              return_tensors="pt").to(
                                                  self.device)
                batch_outputs = self.model(**batch_inputs)
                batch_features = batch_outputs.last_hidden_state.mean(dim=1)
                for image_id, features in enumerate(batch_features):
                    self.add_vector_to_index(features.reshape(1, -1),
                                             self.index)
                    # 记录特征向量的 id
                    image_id += beg_img_no
                    self.ids.append(image_id)
        # faiss.write_index(self.index, self.db_path)

    def search_similar_batch_images(self, user_image_path, k=1):
        """
        批量搜索相似图像
        Parameters:
        user_image_path (str or list): 用户图像的路径或路径列表
        k (int): 搜索的近邻数量
        Returns:
        list: 相似图像的路径列表
        """
        image_list, image_path_list = self.read_images_from_folder(
            user_image_path)
        img_num = len(image_list)
        similar_base_images = []
        similar_query_images = []
        for beg_img_no in tqdm(range(0, img_num, self.batch_size),
                               desc="Searching similar images"):
            end_img_no = min(img_num, beg_img_no + self.batch_size)
            batch_image = image_list[beg_img_no:end_img_no]
            batch_image_path = image_path_list[beg_img_no:end_img_no]
            with torch.no_grad():
                batch_inputs = self.processor(images=batch_image,
                                              return_tensors="pt").to(
                                                  self.device)
                batch_outputs = self.model(**batch_inputs)
                query_features = batch_outputs.last_hidden_state.mean(dim=1)
                batch_query_vector = query_features.detach().cpu().numpy()
                batch_query_vector = np.float32(batch_query_vector)
                faiss.normalize_L2(batch_query_vector)
                batch_distances, batch_indices = self.index.search(
                    batch_query_vector, k)

                if len(batch_distances) > 0:
                    for image_path, dis, ind in zip(batch_image_path,
                                                    batch_distances,
                                                    batch_indices):
                        # 保存超过阈值的最相似特征向量
                        if dis[0] > self.threshold:
                            similar_base_images.append(
                                self.images_path[ind[0]])
                            similar_query_images.append(image_path)
        dissimilar_images = list(
            set(image_path_list).difference(set(similar_query_images)))
        return similar_query_images

    def search_similar_images(self, query_image_path, k=1):
        """
        搜索相似图像
        Parameters:
        query_image_path (str): 查询图像的路径
        k (int): 搜索的近邻数量
        Returns:
        list: 相似图像的路径列表
        """
        img = Image.open(query_image_path).convert('RGB')
        with torch.no_grad():
            inputs = self.processor(images=img,
                                    return_tensors="pt").to(self.device)
            outputs = self.model(**inputs)
            query_features = outputs.last_hidden_state.mean(dim=1)
            query_vector = query_features.detach().cpu().numpy()
            query_vector = np.float32(query_vector)
            faiss.normalize_L2(query_vector)
            distances, indices = self.index.search(query_vector, k)
            similar_base_images = []
            similar_query_images = []
            if len(distances) > 0:
                for dis, ind in zip(distances, indices):
                    # 保存超过阈值的最相似特征向量
                    if dis > self.threshold:
                        similar_base_images.append(self.images_path[ind[0]])
                        similar_query_images.append(query_image_path)

        return similar_query_images

    def remove_image_by_id(self, image_id):
        """
        根据 ID 删除图像
        Parameters:
        image_id (int): 图像的 ID
        """
        # 先删除高索引的元素,再删除低索引的元素,避免索引错位的问题。
        index_to_remove = [
            i for i, stored_id in enumerate(self.ids) if stored_id == image_id
        ]  # 找到需要删除的特征向量的索引
        for i in sorted(index_to_remove, reverse=True):
            # 从 Faiss 索引中删除对应的特征向量
            self.index.remove_ids(np.array([i]))
            # 从 id 列表中删除对应的 id
            del self.ids[i]

    def get_num_images(self):
        """
        获取保存的图像数量
        Returns:
        int: 保存的图像数量
        """
        # 返回保存的图片向量数量
        return len(self.ids)

    def clear_data_base(self):
        """
        清空数据库
        """
        # 重置 Faiss 索引
        self.index.reset()
        # 清空 id 列表
        self.ids.clear()


if __name__ == '__main__':
    # 创建一个 SceneRecognition 实例
    scene_rec = SceneRecognition()
   
    scene_rec.create_database_from_batch_images('imgs')
    scene_rec.search_similar_batch_images('images')

参考网址:

  1. https://blog.csdn.net/level_code/article/details/137772620
  2. https://blog.csdn.net/weixin_38739735/article/details/136979083
  3. https://blog.csdn.net/u010970956/article/details/134945210
  4. https://blog.csdn.net/hh1357102/article/details/135066581
  5. https://blog.csdn.net/sinat_34770838/article/details/137021023
  6. https://zhuanlan.zhihu.com/p/704250322
  7. https://zhuanlan.zhihu.com/p/668148439
  8. https://blog.51cto.com/u_14273/10165547
  9. https://blog.csdn.net/ResumeProject/article/details/135350945
  10. https://www.zhihu.com/question/637818872/answer/3380169469
  11. https://zhuanlan.zhihu.com/p/644077057
  • 3
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
要使用faiss调用IndexNSG,需要执行以下步骤: 1. 安装faiss和IndexNSG 要使用faiss调用IndexNSG,需要首先安装faiss和IndexNSG。可以通过以下命令在Linux上安装faiss和IndexNSG: ``` pip install faiss-gpu pip install faiss-IndexNSG ``` 2. 加载数据 在使用IndexNSG之前,需要将数据加载到faiss中。可以使用以下代码将数据加载到faiss中: ```python import numpy as np import faiss # Load data data = np.load('data.npy') n, d = data.shape # Create index index = faiss.IndexFlatL2(d) index.add(data) ``` 3. 创建IndexNSG 要创建IndexNSG,需要使用faiss.IndexNSG类。可以使用以下代码创建一个IndexNSG: ```python # Create IndexNSG index_nsg = faiss.IndexNSG(d, 32, faiss.METRIC_L2) index_nsg.train(data) index_nsg.add(data) ``` 这里的d是数据的维度,32是NSG中每个节点的最大子节点数,METRIC_L2表示使用欧几里得距离度量。 4. 搜索 使用IndexNSG进行搜索与使用其他faiss索引相同。以下是一个简单的搜索示例: ```python # Search k = 10 query = np.random.rand(1, d).astype('float32') distances, indices = index_nsg.search(query, k) print('Query:\n', query) print('Distances:\n', distances) print('Indices:\n', indices) ``` 这里的k是要返回的最近邻居数量,query是查询向量。搜索结果包括每个最近邻居的距离和索引。 这些是使用faiss调用IndexNSG的基本步骤。注意,使用IndexNSG需要一些额外的配置和调整,以便获得最佳性能。可以参考faiss的文档和示例进行更深入的了解。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值