目标:通过Faiss和DINOv2进行场景识别,确保输入的照片和注册的图片,保持内容一致。
MetaAI 通过开源 DINOv2,在计算机视觉领域取得了一个显著的里程碑,这是一个在包含1.42 亿张图像的令人印象深刻的数据集上训练的模型。产生适用于图像级视觉任务(图像分类、实例检索、视频理解)以及像素级视觉任务(深度估计、语义分割)的通用特征。
Faiss是一个用于高效相似性搜索和密集向量聚类的库。它包含的算法可以搜索任意大小的向量集,甚至可能无法容纳在 RAM 中的向量集。
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# pip install transformers faiss-gpu torch Pillow
import torch
import os
import concurrent.futures
from transformers import AutoImageProcessor, AutoModel
from PIL import Image
import faiss
from tqdm import tqdm
import numpy as np
from utils_img import *
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
class SceneRecognition:
    """Scene recognition with DINOv2 embeddings and a Faiss index.

    Registered ("base") images are embedded with a DINOv2 backbone and
    stored in an inner-product Faiss index. Query images whose cosine
    similarity to some registered image exceeds ``threshold`` are
    reported as matches.
    """

    def __init__(self, dimension=384, threshold=0.8, batch_size=128):
        """
        Initialize the SceneRecognition instance.

        Parameters:
            dimension (int): dimensionality of the feature vectors
                (384 matches dinov2-small). Default 384.
            threshold (float): cosine-similarity threshold above which a
                query image counts as a match. Default 0.8.
            batch_size (int): number of images per forward pass in the
                batch APIs. Default 128.
        """
        # Load model and processor from a local copy (see download_model()).
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        self.processor = AutoImageProcessor.from_pretrained(
            'models/dinov2-small')
        self.model = AutoModel.from_pretrained('models/dinov2-small').to(
            self.device)
        # Inference only: disable dropout / training-mode behavior.
        self.model.eval()
        # Feature-vector dimensionality.
        self.dimension = dimension
        # Inner-product index; combined with faiss.normalize_L2 on both
        # stored and query vectors this yields cosine similarity.
        self.index = faiss.IndexFlatIP(self.dimension)
        # ids[i] is the image id of the i-th vector stored in the index.
        self.ids = []
        # Path used when persisting the index (persistence currently
        # commented out in the extract_* methods).
        self.db_path = "vector.index"
        # Paths of the registered images, aligned with index positions.
        self.images_path = []
        # Cosine-similarity match threshold.
        self.threshold = threshold
        self.batch_size = batch_size
        # Warm up the model so the first real request is not slow.
        self.init()

    def read_image_open(self, image_path):
        """
        Open an image and return it together with its path.

        Parameters:
            image_path (str): path of the image file.
        Returns:
            (PIL.Image.Image, str): the opened RGB image and its path.
        """
        # Force RGB so mixed-mode inputs (RGBA / grayscale / palette)
        # can be batched together by the processor; this also matches
        # what search_similar_images() does for single queries.
        image = Image.open(image_path).convert('RGB')
        return image, image_path

    def read_images_from_folder(self, file_path):
        """
        Read images from a folder or from an explicit list of paths.

        Parameters:
            file_path (str or list): folder path, or list of image paths.
        Returns:
            (list, list): images and their paths, or (None, None) when
            reading fails (unreadable file, empty folder, ...).
        """
        try:
            if isinstance(file_path, list):
                task_list = file_path
            else:
                # assumes get_files() lists the image files in the
                # folder — defined in utils_img; verify its contract
                task_list = get_files(file_path)
            # Opening files is I/O-bound, so a thread pool overlaps the
            # reads despite the GIL.
            with concurrent.futures.ThreadPoolExecutor(
                    max_workers=6) as executor:
                res = executor.map(self.read_image_open, task_list)
                image_list, image_path_list = zip(*res)
            return list(image_list), list(image_path_list)
        except Exception:
            # Best-effort API: callers check for None instead of
            # handling exceptions.
            return None, None

    def download_model(self):
        """
        Download the DINOv2 weights from the Hugging Face Hub into the
        local directory that __init__ loads from.
        """
        from huggingface_hub import snapshot_download
        snapshot_download(
            repo_id="facebook/dinov2-small",      # model id on the Hub
            local_dir="./models/dinov2-small")    # local save location

    def init(self):
        """
        Warm up the model with one dummy batch.

        The first forward pass pays one-time costs (CUDA context,
        kernel selection); running it here keeps the first real
        request fast.
        """
        # uint8 is the conventional dtype for raw image arrays fed to
        # an image processor; the dummy content (zeros) is irrelevant.
        dummy = np.zeros((self.batch_size, 1920, 1080, 3), dtype=np.uint8)
        # no_grad avoids building an autograd graph for the warm-up.
        with torch.no_grad():
            inputs = self.processor(images=dummy,
                                    return_tensors="pt").to(self.device)
            self.model(**inputs)

    def add_vector_to_index(self, embedding, index):
        """
        L2-normalize an embedding and add it to a Faiss index.

        Parameters:
            embedding (torch.Tensor): feature vector(s), shape (n, dim).
            index (faiss.Index): target Faiss index.
        """
        vector = embedding.detach().cpu().numpy().astype(np.float32)
        # In-place L2 normalization; with IndexFlatIP this makes the
        # inner product equal to cosine similarity.
        faiss.normalize_L2(vector)
        index.add(vector)

    def create_database_from_images(self, file_path):
        """
        Build the vector database, embedding images one at a time.

        Parameters:
            file_path (str or list): image folder path or list of paths.
        """
        image_list, image_path_list = self.read_images_from_folder(file_path)
        if image_list is None:
            # Reading failed; keep the current database untouched.
            return
        self.images_path = list(image_path_list)
        self.extract_features_form_images(image_list)

    def create_database_from_batch_images(self, file_path):
        """
        Build the vector database, embedding images in batches.

        Parameters:
            file_path (str or list): image folder path or list of paths.
        """
        image_list, image_path_list = self.read_images_from_folder(file_path)
        if image_list is None:
            # Reading failed; keep the current database untouched.
            return
        self.images_path = list(image_path_list)
        self.extract_features_form_batch_images(image_list)

    def extract_features_form_images(self, images):
        """
        Embed images one by one and add them to the index.

        (Method name keeps the historical "form" spelling for backward
        compatibility.)

        Parameters:
            images (list): list of PIL images.
        """
        for image_id, img in enumerate(images):
            with torch.no_grad():
                inputs = self.processor(images=img,
                                        return_tensors="pt").to(self.device)
                outputs = self.model(**inputs)
                # Mean-pool the patch tokens into one vector per image.
                features = outputs.last_hidden_state.mean(dim=1)
            self.add_vector_to_index(features, self.index)
            # Record the id of the stored vector.
            self.ids.append(image_id)
        # faiss.write_index(self.index, self.db_path)

    def extract_features_form_batch_images(self, image_list):
        """
        Embed images in batches and add them to the index.

        (Method name keeps the historical "form" spelling for backward
        compatibility.)

        Parameters:
            image_list (list): list of PIL images.
        """
        img_num = len(image_list)
        for beg_img_no in tqdm(range(0, img_num, self.batch_size),
                               desc="Extracting features"):
            end_img_no = min(img_num, beg_img_no + self.batch_size)
            batch_image = image_list[beg_img_no:end_img_no]
            with torch.no_grad():
                batch_inputs = self.processor(images=batch_image,
                                              return_tensors="pt").to(
                                                  self.device)
                batch_outputs = self.model(**batch_inputs)
                # Mean-pool the patch tokens into one vector per image.
                batch_features = batch_outputs.last_hidden_state.mean(dim=1)
            for offset, features in enumerate(batch_features):
                self.add_vector_to_index(features.reshape(1, -1),
                                         self.index)
                # Global id = batch start + offset within the batch.
                self.ids.append(beg_img_no + offset)
        # faiss.write_index(self.index, self.db_path)

    def search_similar_batch_images(self, user_image_path, k=1):
        """
        Search the database for matches of many query images at once.

        Parameters:
            user_image_path (str or list): query folder path or list of
                query image paths.
            k (int): number of nearest neighbors to retrieve per query.
        Returns:
            list: paths of the query images whose best match exceeds
            the similarity threshold.
        """
        image_list, image_path_list = self.read_images_from_folder(
            user_image_path)
        img_num = len(image_list)
        similar_base_images = []
        similar_query_images = []
        for beg_img_no in tqdm(range(0, img_num, self.batch_size),
                               desc="Searching similar images"):
            end_img_no = min(img_num, beg_img_no + self.batch_size)
            batch_image = image_list[beg_img_no:end_img_no]
            batch_image_path = image_path_list[beg_img_no:end_img_no]
            with torch.no_grad():
                batch_inputs = self.processor(images=batch_image,
                                              return_tensors="pt").to(
                                                  self.device)
                batch_outputs = self.model(**batch_inputs)
                query_features = batch_outputs.last_hidden_state.mean(dim=1)
            batch_query_vector = query_features.detach().cpu().numpy()
            batch_query_vector = batch_query_vector.astype(np.float32)
            # Normalize queries so the inner-product search returns
            # cosine similarity.
            faiss.normalize_L2(batch_query_vector)
            batch_distances, batch_indices = self.index.search(
                batch_query_vector, k)
            for image_path, dis, ind in zip(batch_image_path,
                                            batch_distances,
                                            batch_indices):
                # dis/ind are sorted best-first; only the top hit needs
                # to clear the threshold.
                if dis[0] > self.threshold:
                    similar_base_images.append(self.images_path[ind[0]])
                    similar_query_images.append(image_path)
        return similar_query_images

    def search_similar_images(self, query_image_path, k=1):
        """
        Search the database for matches of one query image.

        Parameters:
            query_image_path (str): path of the query image.
            k (int): number of nearest neighbors to retrieve.
        Returns:
            list: [query_image_path] if its best match exceeds the
            similarity threshold, else [].
        """
        img = Image.open(query_image_path).convert('RGB')
        with torch.no_grad():
            inputs = self.processor(images=img,
                                    return_tensors="pt").to(self.device)
            outputs = self.model(**inputs)
            query_features = outputs.last_hidden_state.mean(dim=1)
        query_vector = query_features.detach().cpu().numpy()
        query_vector = query_vector.astype(np.float32)
        faiss.normalize_L2(query_vector)
        distances, indices = self.index.search(query_vector, k)
        similar_base_images = []
        similar_query_images = []
        for dis, ind in zip(distances, indices):
            # dis is a row of k similarities sorted best-first; compare
            # only the top hit (the old `dis > threshold` worked solely
            # for k == 1 and raised for larger k).
            if dis[0] > self.threshold:
                similar_base_images.append(self.images_path[ind[0]])
                similar_query_images.append(query_image_path)
        return similar_query_images

    def remove_image_by_id(self, image_id):
        """
        Remove all vectors registered under the given image id.

        Parameters:
            image_id (int): id of the image to remove.
        """
        # Positions (sequential Faiss ids) of the vectors to remove.
        positions = [
            i for i, stored_id in enumerate(self.ids) if stored_id == image_id
        ]
        if not positions:
            return
        # One batched removal: Faiss interprets all ids against the
        # pre-removal layout and compacts the index internally.
        self.index.remove_ids(np.array(positions, dtype=np.int64))
        # Mirror the removal in the bookkeeping lists, highest index
        # first so earlier deletions do not shift later ones.
        for i in sorted(positions, reverse=True):
            del self.ids[i]
            # Keep image paths aligned with the index positions.
            if i < len(self.images_path):
                del self.images_path[i]

    def get_num_images(self):
        """
        Return the number of stored image vectors.

        Returns:
            int: number of vectors currently registered.
        """
        return len(self.ids)

    def clear_data_base(self):
        """
        Clear the database: index, ids and registered image paths.
        """
        # Drop all vectors from the Faiss index.
        self.index.reset()
        # Drop the bookkeeping so positions stay consistent.
        self.ids.clear()
        self.images_path = []
if __name__ == '__main__':
    # Demo: register every image under "imgs" as the scene database,
    # then report which photos under "images" match a registered scene.
    recognizer = SceneRecognition()
    recognizer.create_database_from_batch_images('imgs')
    recognizer.search_similar_batch_images('images')
参考网址:
- https://blog.csdn.net/level_code/article/details/137772620
- https://blog.csdn.net/weixin_38739735/article/details/136979083
- https://blog.csdn.net/u010970956/article/details/134945210
- https://blog.csdn.net/hh1357102/article/details/135066581
- https://blog.csdn.net/sinat_34770838/article/details/137021023
- https://zhuanlan.zhihu.com/p/704250322
- https://zhuanlan.zhihu.com/p/668148439
- https://blog.51cto.com/u_14273/10165547
- https://blog.csdn.net/ResumeProject/article/details/135350945
- https://www.zhihu.com/question/637818872/answer/3380169469
- https://zhuanlan.zhihu.com/p/644077057