想通过缩略图找到原图?想找回之前 P 过的图像的原图?汇报时用了压缩过的图像,现在想找原图?如何从大量图像文件中快速找到与目标图像最相似的那一张?使用 PyTorch 只需要几行代码就可以搞定。
模型的选取
一般使用图像分类网络提取特征即可。参看上一篇《使用 PyTorch 中的 ResNet 预训练模型进行快速图像分类》。
代码如下
分别提取查询图像和候选图像的特征,计算二者的余弦相似度,相似度越大则图像越相似。程序会输出相似度超过阈值的图像路径;取消代码末尾几行的注释后,还可以将这些相似图像复制到指定目录下。
# load model
import torch
import torchvision
# Pretrained ResNet-101 backbone, switched to inference mode immediately
# (Module.eval() returns the module itself, so the two calls can be chained).
# Any other depth can be substituted here:
# resnet18, resnet34, resnet50, resnet101, resnet152
# NOTE(review): newer torchvision releases appear to prefer `weights=` over
# `pretrained=` -- confirm against the installed torchvision version.
model = torchvision.models.resnet101(pretrained=True).eval()
from PIL import Image
from torchvision import transforms
from tqdm import tqdm
# Script settings.
# Input / output locations:
path_to_query = "target.jpg"      # the single query image
path_to_data = "/path/to/data/"   # root directory of the gallery images
target_dir = "targetdir"          # directory where similar images may be saved
# Search parameters:
BATCH_SIZE = 256                  # gallery images processed per forward pass
threshold = 0.8                   # keep an image when cosine similarity > threshold
# Build the gallery data loader.
# Preprocessing shared by the query image and every gallery image:
# resize, center-crop to 224x224, convert to a tensor, then normalize each
# channel with the fixed mean/std below.
_pipeline_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
preprocess = transforms.Compose(_pipeline_steps)

# ImageFolder keeps its (path, class_index) pairs in `.samples`; those paths
# are used later to report which gallery files matched the query.
test_data = torchvision.datasets.ImageFolder(root=path_to_data, transform=preprocess)
image_names = test_data.samples

# shuffle=False (the default) is spelled out on purpose: batches must arrive
# in the same order as `image_names` so batch positions map back to paths.
data_loader = torch.utils.data.DataLoader(
    test_data,
    batch_size=BATCH_SIZE,
    shuffle=False,
)
# Run on the GPU when one is available; otherwise fall back to the CPU so the
# script still works (more slowly) on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build the feature extractor.
# The classification head is replaced by an identity mapping so the network
# returns the pooled backbone embedding instead of class logits.  Using
# Identity (rather than a square Linear layer initialised to the identity
# matrix) guarantees the features are passed through unchanged -- a Linear
# layer left with its default random weights would scramble them differently
# on every run -- and it needs no hard-coded feature size, so it works for
# every ResNet depth.
feature_extractor = model
feature_extractor.fc = torch.nn.Identity()
for param in feature_extractor.parameters():
    param.requires_grad = False
feature_extractor = feature_extractor.to(device)

# Paths of the gallery images whose similarity to the query exceeds `threshold`.
result = []
# Number of gallery images already processed, i.e. the index into
# `image_names` of the first image of the current batch.
offset = 0

# test and query
with torch.no_grad():
    # Load the query image.  It is converted to RGB so that grayscale,
    # palette or RGBA files still yield the 3 channels that the
    # normalization step and the network expect.
    with Image.open(path_to_query) as query_file:
        input_image = query_file.convert('RGB')
    input_tensor = preprocess(input_image)
    # Add a leading batch dimension, as expected by the model.
    input_batch = input_tensor.unsqueeze(0).to(device)
    q_feature = feature_extractor(input_batch)

    # Walk through the gallery batch by batch; the class labels produced by
    # the loader are not needed for retrieval.
    for x, _ in tqdm(data_loader, desc="Evaluating", leave=False):
        output = feature_extractor(x.to(device))
        # q_feature has a single row and is broadcast against every row of
        # the batch, giving one similarity value per gallery image.
        similarity = torch.cosine_similarity(q_feature, output, dim=1)
        # Collect the in-batch positions above the threshold in one tensor
        # operation instead of comparing element by element in Python.
        hits = torch.nonzero(similarity > threshold).flatten().tolist()
        result.extend(image_names[offset + i][0] for i in hits)
        # Advance by the real batch size (the last batch may be smaller).
        offset += output.shape[0]
# Optional: uncomment the lines below to copy every matching image into
# `target_dir`.
# from shutil import copyfile
# import os
# os.makedirs(target_dir, exist_ok=True)
# for r in result:
#     copyfile(r, target_dir+'/'+r.split('/')[-1])

# Print the results, one matching path per line (nothing when there is no match).
if result:
    print(*result, sep='\n')