文章目录
简易的相似图像搜索算法
图片数据库
查询结果
本文主要方法
流程
-
预训练模型 + 单张图像的特征 ( 逐个保存,形成数据库)+ 特征压缩(选做)
-
获取查询图像的特征向量
-
将查询的特征向量与数据库保存的所有特征进行余弦距离计算
-
返回结果
实际
- 编写自定义图片数据集读取代码
- pytorch SWAV预训练模型 (paper Unsupervised Learning of Visual Features by Contrasting Cluster Assignments.)
- 抽取数据集每一张图片,模型推理得到,4096维度的特征向量 ,保存每一个特征向量
- 得到将需要查询的图片的4096维度的向量
- 计算查询向量与所有其他图片的余弦距离,并返回距离最近的topk个图片,完成查询
主要参考:
基于论文复杂结构_搜索算法(牛津数据集)
-
PyTorch+flask演示
-
End-to-end Learning of Deep Visual Representations for Image Retrieval
-
分别学习相似图像与不相似图像的特征:
-
https://github.com/keshik6/deep-image-retrieval#pytorch-source-code
基于自编码机(AE)_搜索算法(以cifar-10为例):
-
PyTorch
-
有效果图
-
https://blog.csdn.net/weixin_43786143/article/details/116137867
基于成对相似度数据_搜索算法(以cifar-10为例)
- PaddlePaddle(百度飞桨)
- https://www.paddlepaddle.org.cn/documentation/docs/zh/tutorial/cv_case/image_search/image_search.html
基于预训练模型_搜索算法(任一小数据集)
- Keras
- 基于vgg16预训练模型
- github 433 star:https://github.com/willard-yuan/flask-keras-cnn-image-retrieval
- 问题:图片较多时,无法直接使用该项目
Hash图像_特征的获取(多种hash算法)
- PyTorch
- 提取图像传统特征,并转换为hash编码
- https://github.com/JohannesBuchner/imagehash
主要代码
测试数据集导入
python test_read_img.py
保存数据集中的图片到文件夹
修改文件夹路径,以及图片后缀,运行:
python save_fearures_2_npy.py
搜索相同、相似图片
python torch_pretrain_swav_search_one_image.py
代码
test_read_img.py
from torch_dataset import GameDataset
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

IMG_FOLDER = r'./datasets/imgs'
IMG_WIDTH, IMG_HEIGHT = 128, 128

# Preprocessing pipeline: resize -> tensor -> ImageNet mean/std normalization.
dataset_transform = transforms.Compose([
    transforms.Resize((IMG_WIDTH, IMG_HEIGHT)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Dataset size: num_img=-1 reads every image, num_img=10 reads the first 10.
test_dataset = GameDataset(IMG_FOLDER, num_img=-1, transform=dataset_transform)
print("num of imgs", len(test_dataset))
# num of imgs 89

# Indexing the dataset returns a single transformed image tensor.
one_image = test_dataset[5]
print(f"type : {type(one_image)},size: {one_image.size()}")
# type : <class 'torch.Tensor'>,size: torch.Size([3, 128, 128]

# Pull one batch through a DataLoader to check batching works.
BATCH_SIZE = 10
train_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, drop_last=False, shuffle=False)
train_batch_images = next(iter(train_dataloader))
print(f"size: {train_batch_images.size()}")
# size: torch.Size([10, 3, 128, 128])
torch_pretrain_swav_search_one_image.py
import numpy as np
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import os
from tqdm import tqdm
#my own
from torch_dataset import GameDataset
import shutil
from PIL import Image
# Converts a PIL.Image with values in [0, 255] into a [C, H, W]
# torch.FloatTensor in [0.0, 1.0], then normalizes every channel with
# mean=std=0.5 (i.e. values end up roughly in [-1, 1]).
transform = transforms.Compose(
[
transforms.Resize((128, 128)),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
# Use the GPU when available; every tensor/model in this script uses `device`.
device = torch.device("cuda"if torch.cuda.is_available() else "cpu")
def move_error_image(in_path):
    """Quarantine an unreadable/corrupt image into an ``error`` subfolder.

    Args:
        in_path: path of the image file to move; it ends up at
            ``<dir-of-in_path>/error/<filename>``.
    """
    dirname, filename = os.path.split(in_path)
    error_dir = os.path.join(dirname, "error")
    # BUG FIX: shutil.move raises FileNotFoundError when the destination
    # directory does not exist yet, so create it (idempotently) first.
    os.makedirs(error_dir, exist_ok=True)
    shutil.move(in_path, os.path.join(error_dir, filename))
def get_query_feature(img_path, model):
    """Compute the flattened feature vector of a single query image.

    Uses the module-level ``transform`` and ``device``.

    Args:
        img_path: path of the query image (opened and converted to RGB).
        model: feature-extraction network already placed on ``device``.

    Returns:
        1-D numpy array of features (4096-d for the SWAV resnet50w2 model
        used by this script -- presumably; confirm against the checkpoint).
    """
    img_src = Image.open(img_path).convert('RGB')
    # Add the batch dimension expected by the model: [C,H,W] -> [1,C,H,W].
    input_img = transform(img_src).unsqueeze(0).to(device)
    # Inference only: torch.no_grad() skips autograd bookkeeping (less memory,
    # faster) -- the original built a graph it never used.
    with torch.no_grad():
        encode = model(input_img)
    return encode.cpu().numpy().flatten()
def copy_sort_by_query(group_filenames, out_dir):
    """Copy ranked result images into ``out_dir``, prefixed by their rank.

    ``out_dir`` is always recreated empty first, so results from a previous
    query do not mix with the current ones.

    Args:
        group_filenames: image paths ordered best-match-first.
        out_dir: destination folder; file ``i`` is copied as ``<i>_<name>``.
    """
    # Start from an empty output directory: drop any previous query results.
    if os.path.exists(out_dir):
        shutil.rmtree(out_dir)
    os.makedirs(out_dir)
    # enumerate instead of range(len(...)): rank and path in one step.
    for rank, img_path in enumerate(group_filenames):
        _, name = os.path.split(img_path)
        # Rank prefix makes an alphabetical listing show similarity order.
        shutil.copy(img_path, os.path.join(out_dir, str(rank) + "_" + name))
def query_feat_in_datasets(query_feat, QUERY_IMAGES_FTS):
    """Cosine similarity between ``query_feat`` and every saved feature file.

    Args:
        query_feat: 1-D numpy feature vector of the query image.
        QUERY_IMAGES_FTS: list of ``.npy`` feature-file paths.

    Returns:
        Tuple ``(similarity, the_last_image_id)`` -- the per-file cosine
        similarities, and the index of the last file processed (used by the
        caller to locate a corrupt file when an exception interrupts us).
    """
    similarity = []
    the_last_image_id = 0  # defined even when the feature list is empty
    # Hoisted: the query norm is invariant across the loop.
    query_norm = np.linalg.norm(query_feat)
    # BUG FIX: the original iterated QUERY_IMAGES_FTS[:-1], silently skipping
    # the last feature file in the database.
    for i, file in enumerate(tqdm(QUERY_IMAGES_FTS)):
        file_fts = np.load(file).flatten()
        cos_sim = np.dot(query_feat, file_fts) / (query_norm * np.linalg.norm(file_fts))
        similarity.append(cos_sim)
        the_last_image_id = i
    print("query_feat_in_datasets", similarity)
    return similarity, the_last_image_id
def get_query_results(similarity, query_img_file, QUERY_IMAGES, top_k=20):
    """Pick the ``top_k`` best matches and copy them into a result folder.

    The result folder is named after the query image (``.jpg`` stripped);
    the query image itself is copied first, followed by the matches in
    decreasing similarity order.
    """
    similarity = np.asarray(similarity)
    print(similarity)
    # argsort is ascending, so sort the NEGATED similarities to get the
    # highest-similarity indexes first.
    indexes = np.squeeze(-similarity).argsort(axis=0)[:top_k]
    print(indexes)
    topk_similarity = [similarity[idx] for idx in indexes]
    print("topk_similarity", topk_similarity)
    # Map the winning indexes back to image paths.
    best_matches_paths = [QUERY_IMAGES[idx] for idx in indexes]
    # Output folder shares the query image's name (extension removed).
    out_query_fodler = query_img_file.replace(".jpg", "")
    # The query itself leads the result list so it appears as rank 0.
    best_matches_paths.insert(0, query_img_file)
    copy_sort_by_query(best_matches_paths, out_query_fodler)
# Folder with the image database, the saved features, and the query image.
IMG_FOLDER = r'./datasets/imgs'
img_fts_dir = r"./datasets/imgs/features"
query_img_file=r'./datasets/imgs/ia_100000581.jpg'
if __name__ == '__main__':
    # device=torch.device("cuda:0")
    # Raise PIL's cached text memory cap (~100x the default) so that reading
    # many images does not abort.  NOTE(review): MAX_TEXT_MEMORY governs PNG
    # text-chunk metadata, not pixel data -- confirm this is the intended knob.
    import PIL
    PIL.PngImagePlugin.MAX_TEXT_MEMORY=6710886400
    print("PIL.PngImagePlugin.MAX_TEXT_MEMORY",PIL.PngImagePlugin.MAX_TEXT_MEMORY)
    # Build the dataset; num_img=-1 means "use every image in the folder".
    youkia_dataset=GameDataset(IMG_FOLDER,num_img=-1,transform=transform)
    imgs_all_path=youkia_dataset.img_paths
    print("imgs_num:",len(youkia_dataset))
    # dataset 2 train_tensor
    BATCH_SIZE=1
    train_dataloader = DataLoader(youkia_dataset, batch_size=BATCH_SIZE,drop_last=True, shuffle=False)
    # SWAV pretrained ResNet-50 (2x width) downloaded via torch.hub, on GPU.
    model = torch.hub.load('facebookresearch/swav', 'resnet50w2').cuda()
    model.eval()  # eval() freezes BatchNorm/Dropout to use the trained statistics
    # Create the image database: one saved .npy feature file per image.
    QUERY_IMAGES_FTS = [os.path.join(img_fts_dir, file) for file in sorted(os.listdir(img_fts_dir))]
    # Map each feature file back to its source image (assumes .jpg sources).
    QUERY_IMAGES = [os.path.join(IMG_FOLDER, file.replace(".npy",".jpg")) for file in sorted(os.listdir(img_fts_dir))]
    # Feature vector of the query image.
    query_feat=get_query_feature(query_img_file,model)
    # print("query_feat",query_feat.shape)
    # Create similarity list
    similarity = []
    the_last_image_id=0
    try:
        similarity,the_last_image_id = query_feat_in_datasets(query_feat, QUERY_IMAGES_FTS)
        # print("similarity",similarity)
        get_query_results(similarity, query_img_file, QUERY_IMAGES)
    except Exception as e:
        # Best effort: still report the results computed before the failure.
        get_query_results(similarity, query_img_file, QUERY_IMAGES)
        print("出现以下异常",e)
        print("an exception caught in line :", e.__traceback__.tb_lineno)  # line where the exception occurred
        print("imgs_all_path[i]",the_last_image_id,imgs_all_path[the_last_image_id])
        # move_error_image(imgs_all_path[i])
torch_dataset.py
import os
import os
from torch.utils.data import Dataset,DataLoader
from torchvision.transforms import ToTensor
import torchvision.transforms as transforms
from PIL import Image
# import PIL
import numpy as np
# to tensor
# Converts a PIL Image or numpy.ndarray (H x W x C) in the range
# [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
class GameDataset(Dataset):
    """Image-folder dataset yielding transformed RGB images (no labels).

    Collects every ``*.jpg`` / ``*.png`` file directly inside ``img_dir``
    (non-recursive; order is whatever ``os.listdir`` returns).

    Args:
        img_dir: directory containing the images.
        num_img: keep only the first ``num_img`` images; ``-1`` (default)
            keeps all of them.
        transform: optional callable applied to each PIL image.
        target_transform: unused; kept for API compatibility.
    """

    def __init__(self, img_dir, num_img=-1, transform=None, target_transform=None):
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform
        self.img_paths = []
        self.file_names = os.listdir(self.img_dir)
        for file in self.file_names:
            if os.path.splitext(file)[1].endswith(('jpg', 'png')):
                self.img_paths.append("%s/%s" % (self.img_dir, file))
        # BUG FIX: the original did img_paths[:num_img] unconditionally, so
        # the documented default num_img=-1 ("read ALL images") actually
        # dropped the LAST image.  Only truncate for non-negative counts.
        if num_img >= 0:
            self.img_paths = self.img_paths[:num_img]

    def __len__(self):
        # Dataset size == number of accepted image paths.
        return len(self.img_paths)

    def __getitem__(self, idx):
        # Decode lazily on access; convert to RGB so grayscale files also
        # yield 3-channel tensors after `transform`.
        img_src = Image.open(self.img_paths[idx]).convert('RGB')
        if self.transform:
            img_src = self.transform(img_src)
        return img_src
save_fearures_2_npy.py
import torch
import torchvision.transforms as transforms
# from PIL import Image
import os
import numpy as np
from torch.utils.data import Dataset,DataLoader
#my own
from torch_dataset import GameDataset
from sklearn.decomposition import PCA
from tqdm import tqdm
# Model input size; every image is resized to IMG_WIDTH x IMG_HEIGHT.
# (The original assigned this pair twice in a row; the duplicate is removed.)
IMG_WIDTH, IMG_HEIGHT = 128, 128
IMG_FOLDER = r'./datasets/imgs'
# Preprocessing: resize -> tensor -> ImageNet mean/std normalization,
# matching the statistics the pretrained backbone was trained with.
dataset_transform = transforms.Compose([
    transforms.Resize((IMG_WIDTH, IMG_HEIGHT)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
def save_each_img_feature(train_dataset, model, imgs_all_path):
    """Run every image through ``model`` and save each feature vector as .npy.

    Args:
        train_dataset: iterable of image batches (a DataLoader with batch_size=1).
        model: feature extractor, already in eval mode.
        imgs_all_path: source image paths aligned with the iteration order;
            the i-th feature is saved next to imgs_all_path[i].
    """
    import PIL
    # Raise PIL's cached text memory cap (~100x the default) so reading many
    # images does not abort.  NOTE(review): MAX_TEXT_MEMORY governs PNG text
    # metadata, not pixel data -- confirm this is the intended knob.
    PIL.PngImagePlugin.MAX_TEXT_MEMORY = 6710886400
    print("PIL.PngImagePlugin.MAX_TEXT_MEMORY", PIL.PngImagePlugin.MAX_TEXT_MEMORY)
    # BUG FIX: the model lives on CUDA (see __main__) but the original never
    # moved the input batch off the CPU; send each batch to the model's device.
    device = next(model.parameters()).device
    # Renamed the loop variable: the original called it `Image`, shadowing PIL.Image.
    for i, batch in enumerate(tqdm(train_dataset)):
        # Inference only: no_grad avoids building an unused autograd graph.
        with torch.no_grad():
            encode = model(batch.to(device))
        feature = encode.cpu().numpy().flatten()
        if i == 0:
            print("feature.shape", feature.shape)
        save_feature(feature, imgs_all_path[i], outdir_name='features')
def save_feature(feature, img_in_path, outdir_name='features'):
    """Save one feature vector beside its source image.

    The array is written to ``<dir-of-image>/<outdir_name>/<stem>.npy``.

    Args:
        feature: numpy array to save.
        img_in_path: path of the source image the feature came from.
        outdir_name: subfolder (created if needed) that holds the .npy files.
    """
    dirname, filename = os.path.split(img_in_path)
    out_dir = os.path.join(dirname, outdir_name)
    # exist_ok avoids the exists()/makedirs() race of the original.
    os.makedirs(out_dir, exist_ok=True)
    # FIX: splitext strips ANY extension; the original's replace(".jpg", "")
    # left ".png" in the saved name for png sources.
    stem = os.path.splitext(filename)[0]
    save_path = os.path.join(out_dir, stem)
    print(save_path)
    np.save(save_path, feature)
if __name__=="__main__":
    # SWAV pretrained ResNet-50 (2x width) downloaded via torch.hub, on GPU.
    rn50w2 = torch.hub.load('facebookresearch/swav', 'resnet50w2').cuda()
    # rn50w4 = torch.hub.load('facebookresearch/swav', 'resnet50w4')
    # rn50w5 = torch.hub.load('facebookresearch/swav', 'resnet50w5')
    # Load the dataset; num_img=-1 means "every image in the folder".
    game_data = GameDataset(IMG_FOLDER,num_img=-1, transform=dataset_transform)
    imgs_all_path=game_data.img_paths
    train_dataloader = DataLoader(game_data, batch_size=1, shuffle=False)
    # Sanity check: print the shape of one batch before the full pass.
    train_features = next(iter(train_dataloader))
    print(f"Feature batch shape: {train_features.size()}")
    # Inference mode: freeze BatchNorm/Dropout to the trained statistics.
    rn50w2.eval()
    ###################
    save_each_img_feature(train_dataloader,rn50w2,imgs_all_path)