【以图搜图代码实现2】--faiss工具实现犬类以图搜图

第一篇:【以图搜图代码实现】–犬类以图搜图示例 使用保存成h5文件,使用向量积来度量相似性,实现了以图搜图,说明了可以优化的点。
第二篇:【使用resnet18训练自己的数据集】 准对模型问题进行了优化,取得了显著性的效果。
本篇继续第一篇中所说的优化方向,使用faiss实现以图搜图。

1.faiss使用介绍

Faiss的全称是Facebook AI Similarity Search,是FaceBook针对大规模相似度检索问题开发的一个工具,底层是使用C++代码实现的,提供了python的接口,号称对10亿量级的索引可以做到毫秒级检索。

使用faiss的基本步骤
1、数据转换:把原始数据转换为"float32"数据类型的向量。
2、index构建:用 faiss 构建index
3、数据添加:将向量add到创建的index中
4、通过创建的index进行检索

1.创建索引

import faiss

def create_index(datas_embedding):
    # 构建索引,L2代表构建的index采用的相似度度量方法为L2范数
    # 必须传入一个向量的维度,创建一个空的索引
    index = faiss.IndexFlatL2(datas_embedding.shape[1])  
    # 把向量数据加入索引
    index.add(datas_embedding)   
    return index

2.保存索引

def faiss_index_save(faiss_index, save_file_location):
    faiss.write_index(faiss_index, save_file_location)

3.加载索引

def faiss_index_load(faiss_index_save_file_location):
    index = faiss.read_index(faiss_index_save_file_location)
    return index

4.向索引中添加向量

def index_data_add(faiss_index, img_path):
    # 获得索引向量的数量
    print(faiss_index.ntotal)
    img_embedding = extract_image_features(img_path)
    faiss_index.add(img_embedding)
    print(faiss_index.ntotal)

5.删除索引中的向量

def index_data_delete(faiss_index):
    print(faiss_index.ntotal)
    # remove, 指定要删除的向量id,是一个np的array
    faiss_index.remove_ids(np.array([0]))
    print(faiss_index.ntotal)

可以看出使用Faiss工具更加的灵活,可以向索引中添加和删除向量。

2.faiss实现以图搜图

本篇代码有部分是在前两篇的基础之上的,这里使用11类犬类数据集微调之后的resnet18进行特征提取。
第一篇:【以图搜图代码实现】–犬类以图搜图示例
第二篇:【使用resnet18训练自己的数据集】

数据集准备和下载可以去看第二篇文章。

1.模型加载

为了更好的适配,对第一篇中的resnet18的初始化方法进行了修改,如下:

@Project :ImageRec
@File    :resnet18.py
@IDE     :PyCharm
@Author  :菜菜2024
@Date    :2024/9/30
'''
from PIL import Image
from torchvision import transforms
import torch
import torch.nn as nn
from torchvision import models

class ResNet18:
    def __init__(self,
                 out_feature = 11,
                 model_path='E:\\xxx\\ImageRec\\weights\\resnet18.pth'):
        self.trans = transforms.Compose([
        transforms.Resize(size=(256, 256)),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
        print("-----------loading resnet18------------")
        self.model = models.resnet18()
        num_feats = self.model.fc.in_features
        self.model.fc = nn.Linear(num_feats, out_feature)
        self.model.load_state_dict(torch.load(model_path))
        self.model.eval()


    def extract_image_features(self, img_path):

        image = Image.open(img_path).convert('RGB')
        image_tensor = self.trans(image).unsqueeze(0)
        with torch.no_grad():
            features = self.model(image_tensor)
        return features

其中out_feature 根据自己的数据集的类别个数进行更改,我这里的犬类是11种。model_path是训练好的保存的权重文件【训练过程可以去看第二篇】

2.文件名映射

在第一篇:【以图搜图代码实现】–犬类以图搜图示例 中使用的是保存成h5文件,索引是没有要求是整数的,这里faiss要求是整数,搞了一个映射方法,同时也是为了在后面可视化的时候,能根据索引再解码得到对应的文件路径。

#!/usr/bin/env python
# -*- coding: UTF-8 -*-
'''
@Project :ImageRec 
@File    :Imgmap.py
@IDE     :PyCharm 
@Author  :菜菜2024
@Date    :2024/9/29 18:02 
'''
import os
import uuid
import numpy as np


def getImgMap(img_path):
	# 为类别生成一个映射文件
    subnames = [f.split('\\')[-1] for f in os.listdir(img_path)]
    element_mapping = {}
 
    for i in range(len(subnames)):
        unique_id = str(i+2024)
        element_mapping[unique_id] = subnames[i]

    return element_mapping

def valueGetKey(mapping, target_value):


    for key, value in mapping.items():
        if value == target_value:
            # print(f"值 '{target_value}' 对应的键是: {key}")
            break
    return key


def nameMap(imgnames, img_path='E:\\xxx\\datas\\pet_dog\\train'):
    '''
    getImagVector函数得到的image_ids在保存为h5文件时进行了编码
    现在faiss工具中index需要是int类型的,这里进行映射转化
    :param img_path: 数据集目录,来得到类别映射
    :param imgnames: 需要映射的图片名称,解码之后是“中华田园犬_0”格式
    这里传参是列表
    :return:
    '''
    element_mapping = getImgMap(img_path)
    decode_names = [imgname.decode('utf-8') for imgname in imgnames]

    name_ids=[]
    for decode_name in decode_names:
        cla_name = decode_name.split("_")[0]
        img_name = decode_name.split("_")[-1]
        key = valueGetKey(element_mapping, cla_name)
        name_id = key+img_name
        name_ids.append(name_id)

    name_ids=np.array(name_ids).astype('int32')

    return name_ids




if __name__ == "__main__":

    database = 'E:\\xxx\\datas\\pet_dog\\train'
    element_mapping = getImgMap(database)
    print(element_mapping)
    print(element_mapping.get("2024"))

映射文件:

{‘2024’: ‘中华田园犬’, ‘2025’: ‘吉娃娃’, ‘2026’: ‘哈士奇’, ‘2027’: ‘德牧’, ‘2028’: ‘拉布拉多’, ‘2029’: ‘杜宾’, ‘2030’: ‘柴犬’, ‘2031’: ‘法国斗牛’, ‘2032’: ‘萨摩耶’, ‘2033’: ‘藏獒’, ‘2034’: ‘金毛’}
nameMap函数是将之前编码的图像名称进行解码,然后重新编码,编码成20240,20301,分别表示的中华田园犬文件夹下的0.jpg, 柴犬下面的1.jpg。这都是为了可视化的时候进行追溯,得到文件路径。
在这里插入图片描述

3.以图搜图实现

定义了一个类ImageRetrival,使用faiss实现创建索引,保存索引,加载索引和图像检索功能

#!/usr/bin/env python
# -*- coding: UTF-8 -*-
'''
@Project :ImageRec 
@File    :faiss_index.py
@IDE     :PyCharm 
@Author  :菜菜2024
@Date    :2024/9/30 15:04 
'''


import os
import faiss
from utils.split_data import array_norm
from utils.Imgmap import nameMap, getImgMap
from model import ResNet18
from save_feature import getImagVectors
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import rc
# 设置全局字体为支持中文的字体
rc('font', family='SimHei')  # 黑体

class ImageRetrival:
    def __init__(self, model_path,
                 index_dim=None):
        self.index_dim = index_dim
        self.index = faiss.IndexFlatL2(self.index_dim)
        self.model_path = model_path

    def build_index(self, image_files):
        # image_vectors图片特征,image_ids对应的标签
        image_vectors, image_ids = getImagVectors(image_files)

        # image_ids 在之前保存为h5文件时进行了编码,这里进行映射
        name_ids = nameMap(image_ids)

        index = faiss.IndexIDMap(self.index)
        index.add_with_ids(image_vectors, name_ids)
        return index

    def save_index(self, index, index_path):
        faiss.write_index(index, index_path)

    def load_index(self, index_path):
        return faiss.read_index(index_path)

    def image_topK_search(self, index, input_image, topK=None):

        resnet18 = ResNet18(out_feature=11,
                            model_path=self.model_path)
        queryVec = resnet18.extract_image_features(input_image)


        dist, ind = index.search(queryVec, topK)
        dist, ind = dist.flatten(), ind.flatten()
        res = array_norm(dist, ind)
        return res

4.运行调用

if __name__=="__main__":

    model_path='E:\\xxx\\Pycharm_files\\ImageRec\\weights\\resnet18.pth'
    # 1.创建索引
    imageRetrival = ImageRetrival(model_path=model_path,
                                  index_dim=11)
    image_files = 'E:\\xxx\\datas\\pet_dog\\train'
    save_index = "./weights/dog.index"
    index = imageRetrival.build_index(image_files)

    # # 2.保存索引
    imageRetrival.save_index(index, save_index)

    # 3.加载索引
    index_load = imageRetrival.load_index(save_index)
    #
    # # 4.相似度匹配
    input_image = './data/pic/德牧.jpg'
    out = imageRetrival.image_topK_search(index_load, input_image, topK=3)
    print(out)
    showFaissRes(image_files, input_image, out)

运行时选择性注销其中的某一步骤。
最后是可视化实现showFaissRes

5.可视化实现


def showFaissRes(image_files, input_image, faissRes):
    '''
    对faiss得到的结果进行可视化
    :param image_files: 图片数据库
    :param input_image: 查询图片路径
    :param faissRes: 返回的topk跟距离最近的结果[(ind, score), (ind, score)]
    :return:
    '''
    scores = []
    imgs = []
    info = []

    # 1.得到图片名称的映射
    element_mapping = getImgMap(image_files)
    imgs.append(mpimg.imread(input_image))
    info.append(input_image.split("/")[-1])

    for i in range(len(faissRes)):
        score = faissRes[i][1]
        ind = str(faissRes[i][0])
        scores.append(score)

        # 根据索引构建原本的图像路径ind格式:20276,前四个是类别表示
        claName = element_mapping.get(ind[:4])
        imgName = ind[4:]+".jpg"
        imgpath = image_files +"\\"+ claName+ "\\"+imgName
        imgs.append(mpimg.imread(imgpath))

        info.append(claName+"_"+ imgName+"_"+ str(score))
        print("图片名称是: " + claName+ imgName + " 对应得分是: %f" %score)

    num = int((len(faissRes) + 1) // 2)+1
    fig, axs = plt.subplots(nrows=num, ncols=num, figsize=(10, 10))

    # 确保即使只有一个子图,也可以进行索引
    if not isinstance(axs, np.ndarray):
        axs = np.array([[axs]])

    # 显示图像
    flat_index = 0
    for i in range(num):
        for j in range(num):
            if flat_index < len(imgs):
                img = imgs[flat_index]
                axs[i, j].imshow(img, cmap='gray')
                axs[i, j].axis('off')
                axs[i, j].set_title(info[flat_index])
                flat_index += 1
            else:
                axs[i, j].set_visible(False)

    plt.tight_layout()
    plt.show()

3.效果对比

第一篇:【以图搜图代码实现】–犬类以图搜图示例 预训练的resnet18

第二篇:【使用resnet18训练自己的数据集】 微调的resnet18
在这里插入图片描述

本章 Faiss实现: 分数不重要,本篇对分数进行了归一化。
在这里插入图片描述
准确性更高了。

Java中实现搜索(也称为像检索或相似度匹配)通常涉及到计算机视觉和机器学习技术,特别是使用深度学习模型如卷积神经网络(CNN)。这里我们不直接给出完整的代码,但可以提供一个基本框架和技术要点: 1. 像预处理:首先,你需要将片转换为一维向量,这通常是通过使用特征提取工具(例如`OpenCV`库)进行的。常用的预处理步骤包括缩放、裁剪、归一化等。 ```java import org.opencv.core.Core; import org.opencv.imgcodecs.Imgcodecs; import org.opencv.imgproc.Imgproc; Mat image = Imgcodecs.imread("image.jpg"); // Resize, crop, and normalize the image ``` 2. 特征提取:使用深度学习库(如`TensorFlow`, `Keras`, 或者`Dlib`的`face_recognition`模块)提取片的特征向量。对于CNN,这些向量代表了片的主要内容。 ```java FeatureExtractor featureExtractor = new FeatureExtractor(); MatOfFloat featureVector = featureExtractor.extract(image); ``` 3. 建立索引:将提取到的特征向量存储在一个数据结构中,比如FLANN(Fast Library for Approximate Nearest Neighbors)或Annoy(Approximate Nearest Neighbors Oh Yeah),用于快速查询相似片。 ```java Indexer indexer = new Indexer(); indexer.add(featureVector); // When searching, use indexer.search(queryFeatureVector, k) to find top-k similar images. ``` 4. 查询阶段:对新来的像执行同样的特征提取,并使用索引来查找最相似的像。 ```java Mat queryImage = Imgcodecs.imread("query.jpg"); MatOfFloat queryFeatureVector = featureExtractor.extract(queryImage); TopKNearest neighbors = indexer.search(queryFeatureVector, k); // k表示想要找到的最接近的片数量 ```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值