【Image Search】A simple similar-image search based on PyTorch's official pretrained models

A simple similar-image search algorithm.

[Figure] Image database (original images)

[Figure] Query results

Approach in this post

Pipeline

  1. Pretrained model + per-image features (saved one by one to form the database) + feature compression (optional)

  2. Extract the feature vector of the query image

  3. Compute the cosine distance between the query vector and every feature saved in the database

  4. Return the results
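The cosine top-k lookup in step 3 can be sketched with NumPy (the function and variable names below are illustrative, not from this repository):

```python
import numpy as np

def cosine_topk(query_feat, db_feats, k=5):
    """Indices of the k database rows most similar to the query,
    by cosine similarity. db_feats has shape (N, D)."""
    q = query_feat / np.linalg.norm(query_feat)
    db = db_feats / np.linalg.norm(db_feats, axis=1, keepdims=True)
    sims = db @ q                 # (N,) cosine similarities
    return np.argsort(-sims)[:k]  # negate to sort descending
```

Normalizing both sides first turns the cosine computation into a single matrix-vector product, which is much faster than a Python loop over files when the database fits in memory.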

In practice

  1. Write a custom image-dataset loader
  2. Load the PyTorch SwAV pretrained model (paper: Unsupervised Learning of Visual Features by Contrasting Cluster Assignments)
  3. Run inference on each image in the dataset to get a 4096-dimensional feature vector, and save each vector
  4. Get the 4096-dimensional vector of the image to be queried
  5. Compute the cosine distance between the query vector and all other images' vectors, and return the top-k nearest images to complete the query
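Steps 3–5 can be sketched end to end with random vectors standing in for the 4096-d SwAV features (paths and names here are hypothetical; real feature extraction requires the model):

```python
import os
import tempfile
import numpy as np

def save_feature(feat, out_dir, name):
    # Step 3: one .npy file per image forms the "database"
    np.save(os.path.join(out_dir, name), feat)

def query_topk(query_feat, out_dir, top_k=3):
    # Step 5: cosine similarity against every saved feature
    names, sims = [], []
    for fn in sorted(os.listdir(out_dir)):
        f = np.load(os.path.join(out_dir, fn)).flatten()
        sims.append(np.dot(query_feat, f)
                    / (np.linalg.norm(query_feat) * np.linalg.norm(f)))
        names.append(fn)
    order = np.argsort(-np.asarray(sims))[:top_k]
    return [names[i] for i in order]

with tempfile.TemporaryDirectory() as d:
    rng = np.random.default_rng(0)
    feats = {f"img_{i}.npy": rng.normal(size=4096) for i in range(5)}
    for name, f in feats.items():
        save_feature(f, d, name)
    # Querying with a stored vector should return that image first
    result = query_topk(feats["img_2.npy"], d, top_k=3)
```

One .npy file per image keeps memory flat no matter how many images there are, at the cost of a disk read per comparison.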

Main references:

Search algorithm from a paper with a complex architecture (Oxford dataset)

  • PyTorch + Flask demo

  • End-to-end Learning of Deep Visual Representations for Image Retrieval

  • Learns features of similar and dissimilar images separately:

  • https://github.com/keshik6/deep-image-retrieval#pytorch-source-code

Autoencoder (AE)-based search algorithm (CIFAR-10 example):

  • PyTorch

  • Includes result screenshots

  • https://blog.csdn.net/weixin_43786143/article/details/116137867

Search algorithm based on pairwise-similarity data (CIFAR-10 example)

  • PaddlePaddle (Baidu)
  • https://www.paddlepaddle.org.cn/documentation/docs/zh/tutorial/cv_case/image_search/image_search.html

Search algorithm based on a pretrained model (any small dataset)

  • Keras
  • Based on the VGG16 pretrained model
  • GitHub, 433 stars: https://github.com/willard-yuan/flask-keras-cnn-image-retrieval
  • Limitation: with a large number of images, the project cannot be used directly

Hash-based image feature extraction (multiple hash algorithms)

  • PyTorch
  • Extracts traditional image features and converts them into hash codes
  • https://github.com/JohannesBuchner/imagehash
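To give a flavor of what the imagehash library does, a minimal average hash (aHash) can be written with just Pillow and NumPy (a simplified sketch, not the library's actual implementation):

```python
import numpy as np
from PIL import Image

def average_hash(img, hash_size=8):
    """Shrink to hash_size x hash_size grayscale, threshold each pixel
    against the mean, and pack the bits into one integer."""
    small = img.convert("L").resize((hash_size, hash_size), Image.LANCZOS)
    px = np.asarray(small, dtype=np.float64)
    bits = (px > px.mean()).flatten()
    return int("".join("1" if b else "0" for b in bits), 2)

def hamming(h1, h2):
    # Similar images -> small Hamming distance between hashes
    return bin(h1 ^ h2).count("1")
```

Near-duplicate detection then reduces to comparing small integers, which scales far better than dense 4096-d feature vectors, at the cost of being much less robust to semantic similarity.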

Main code

Test the dataset loader

python test_read_img.py

Save each image's features to a folder

Edit the folder path and image extension, then run:

python save_fearures_2_npy.py

Search for identical and similar images

python torch_pretrain_swav_search_one_image.py

Code

test_read_img.py

from torch_dataset import GameDataset
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

IMG_FOLDER = r'./datasets/imgs'
IMG_WIDTH, IMG_HEIGHT=128,128

dataset_transform = transforms.Compose([
        transforms.Resize((IMG_WIDTH, IMG_HEIGHT)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Set the dataset size: num_img=-1 reads all images, num_img=10 reads the first 10
test_dataset=GameDataset(IMG_FOLDER,num_img=-1,transform=dataset_transform)
print("num of imgs",len(test_dataset))
# num of imgs 89
one_image=test_dataset[5]
print(f"type : {type(one_image)},size: {one_image.size()}")
# type : <class 'torch.Tensor'>, size: torch.Size([3, 128, 128])

# Fetch batches with a DataLoader
BATCH_SIZE = 10
train_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, drop_last=False, shuffle=False)
train_batch_images=next(iter(train_dataloader))
print(f"size: {train_batch_images.size()}")
# size: torch.Size([10, 3, 128, 128])

torch_pretrain_swav_search_one_image.py

import numpy as np
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import os
from tqdm import tqdm

#my own
from torch_dataset import GameDataset
import shutil

from PIL import Image
# ToTensor converts a PIL.Image with values in [0, 255] into a torch.FloatTensor of shape [C, H, W] with values in [0, 1.0]
transform = transforms.Compose(
    [
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def move_error_image(in_path):
    dirname,filename=os.path.split(in_path)
    outfile=os.path.join(dirname,"error/"+filename)
    shutil.move(in_path,outfile)

def get_query_feature(img_path, model):
    # preprocess the query image and run one forward pass
    img_src = Image.open(img_path).convert('RGB')
    input_img = transform(img_src).to(device).unsqueeze(0)
    with torch.no_grad():
        encode = model(input_img)
    return encode.cpu().numpy().flatten()

def copy_sort_by_query(group_filenames,out_dir):
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    else:
        # delete the previous results and recreate an empty folder
        shutil.rmtree(out_dir)
        os.makedirs(out_dir)

    for i in range(len(group_filenames)):
        img_path =(group_filenames[i])
        # print(img_path)
        img_dir,name = os.path.split(img_path)
        new_img_path=os.path.join(out_dir,str(i)+"_"+name)
        # copy the image
        shutil.copy(img_path, new_img_path)




def query_feat_in_datasets(query_feat, QUERY_IMAGES_FTS):
    similarity = []
    the_last_image_id = 0
    for i, file in enumerate(tqdm(QUERY_IMAGES_FTS)):
        the_last_image_id = i
        # cosine similarity between the query vector and one stored feature
        file_fts = np.load(file).flatten()
        cos_sim = np.dot(query_feat, file_fts) / (np.linalg.norm(query_feat) * np.linalg.norm(file_fts))
        similarity.append(cos_sim)
    return similarity, the_last_image_id

def get_query_results(similarity, query_img_file, QUERY_IMAGES, top_k=20):
    # Get the best matches by similarity
    similarity = np.asarray(similarity)
    # argsort() sorts ascending, so negate the similarities to sort descending
    indexes = np.squeeze(-similarity).argsort(axis=0)[:top_k]
    topk_similarity = [similarity[index] for index in indexes]
    print("topk_similarity", topk_similarity)
    best_matches_paths = [QUERY_IMAGES[index] for index in indexes]

    # save the results into a folder named after the query image
    out_query_folder = query_img_file.replace(".jpg", "")
    best_matches_paths.insert(0, query_img_file)
    copy_sort_by_query(best_matches_paths, out_query_folder)


IMG_FOLDER = r'./datasets/imgs'
img_fts_dir = r"./datasets/imgs/features"
query_img_file=r'./datasets/imgs/ia_100000581.jpg'

if __name__ == '__main__':
    # Raise PIL's text-chunk memory cap (~100x the default);
    # otherwise loading many images can raise an error
    import PIL
    PIL.PngImagePlugin.MAX_TEXT_MEMORY = 6710886400
    print("PIL.PngImagePlugin.MAX_TEXT_MEMORY", PIL.PngImagePlugin.MAX_TEXT_MEMORY)

    # Build the dataset
    youkia_dataset = GameDataset(IMG_FOLDER, num_img=-1, transform=transform)
    imgs_all_path = youkia_dataset.img_paths
    print("imgs_num:", len(youkia_dataset))

    BATCH_SIZE = 1
    train_dataloader = DataLoader(youkia_dataset, batch_size=BATCH_SIZE, drop_last=True, shuffle=False)

    model = torch.hub.load('facebookresearch/swav', 'resnet50w2').to(device)
    model.eval()  # eval() freezes BatchNorm and Dropout so they use the trained statistics

    # Create the image database
    QUERY_IMAGES_FTS = [os.path.join(img_fts_dir, file) for file in sorted(os.listdir(img_fts_dir))]
    QUERY_IMAGES = [os.path.join(IMG_FOLDER, file.replace(".npy", ".jpg")) for file in sorted(os.listdir(img_fts_dir))]

    query_feat=get_query_feature(query_img_file,model)
    # Compute similarities and run the query
    similarity = []
    the_last_image_id = 0
    try:
        similarity, the_last_image_id = query_feat_in_datasets(query_feat, QUERY_IMAGES_FTS)
        get_query_results(similarity, query_img_file, QUERY_IMAGES)
    except Exception as e:
        # Still write out whatever was computed, then report the failing image
        get_query_results(similarity, query_img_file, QUERY_IMAGES)
        print("caught exception:", e)
        print("exception raised at line:", e.__traceback__.tb_lineno)
        print("imgs_all_path[i]", the_last_image_id, imgs_all_path[the_last_image_id])
        # move_error_image(imgs_all_path[the_last_image_id])

torch_dataset.py

import os

from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import ToTensor
import torchvision.transforms as transforms
from PIL import Image
import numpy as np


# to tensor
# Converts a PIL Image or numpy.ndarray (H x W x C) in the range
#     [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].


class GameDataset(Dataset):
    def __init__(self, img_dir, num_img=-1, transform=None, target_transform=None):
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

        self.img_paths = []
        self.file_names = os.listdir(self.img_dir)
        for file in self.file_names:
            if os.path.splitext(file)[1].endswith(('jpg', 'png')):
                self.img_paths.append("%s/%s" % (self.img_dir, file))
        # num_img=-1 keeps all images; a positive value keeps only the first num_img
        if num_img > 0:
            self.img_paths = self.img_paths[:num_img]
        # print("img_paths",self.img_paths)

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):

        img_src = Image.open(self.img_paths[idx]).convert('RGB')
        if self.transform:
            img_src = self.transform(img_src)
        return img_src


save_fearures_2_npy.py

import torch
import torchvision.transforms as transforms
# from PIL import Image
import os
import numpy as np
from torch.utils.data import Dataset,DataLoader
#my own
from torch_dataset import GameDataset
from sklearn.decomposition import PCA
from tqdm import tqdm

IMG_FOLDER = r'./datasets/imgs'
IMG_WIDTH, IMG_HEIGHT = 128, 128


dataset_transform = transforms.Compose([
        transforms.Resize((IMG_WIDTH, IMG_HEIGHT)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])


def save_each_img_feature(train_dataloader, model, imgs_all_path):
    import PIL
    # Raise PIL's text-chunk memory cap (~100x the default); otherwise many images can raise an error
    PIL.PngImagePlugin.MAX_TEXT_MEMORY = 6710886400
    for i, img_batch in enumerate(tqdm(train_dataloader)):
        # run inference on one image and save its flattened feature vector
        with torch.no_grad():
            feature = model(img_batch.cuda()).cpu().numpy().flatten()
        if i == 0:
            print("feature.shape", feature.shape)
        save_feature(feature, imgs_all_path[i], outdir_name='features')

def save_feature(feature, img_in_path, outdir_name='features'):
    dirname, filename = os.path.split(img_in_path)
    out_dir = os.path.join(dirname, outdir_name)
    os.makedirs(out_dir, exist_ok=True)
    # strip the image extension; np.save appends .npy
    save_path = os.path.join(out_dir, os.path.splitext(filename)[0])
    np.save(save_path, feature)



if __name__ == "__main__":

    rn50w2 = torch.hub.load('facebookresearch/swav', 'resnet50w2').cuda()
    # rn50w4 = torch.hub.load('facebookresearch/swav', 'resnet50w4')
    # rn50w5 = torch.hub.load('facebookresearch/swav', 'resnet50w5')

    # load the data
    game_data = GameDataset(IMG_FOLDER, num_img=-1, transform=dataset_transform)
    imgs_all_path = game_data.img_paths
    train_dataloader = DataLoader(game_data, batch_size=1, shuffle=False)
    train_features = next(iter(train_dataloader))
    print(f"Feature batch shape: {train_features.size()}")

    # inference
    rn50w2.eval()
    save_each_img_feature(train_dataloader, rn50w2, imgs_all_path)
