基于深度学习的以图搜图

  • 使用预训练的卷积神经网络提取图片中的特征,生成特征向量。
  • 利用图片库中所有图片数据构建 <id, feature vector> 数据。
  • 使用 Faiss 创建 Index ,利用 <id, feature vector> 数据生成索引。
  • 针对待检索图片,使用模型提取图片特征向量,然后使用 Index 检索 TopK 相似图片的 id。
  • 可视化检索结果

1. 导包

import os
import time
import torch
import faiss
import numpy as np
import matplotlib.pyplot as plt

from PIL import Image
from torchvision import models, transforms
from torch.utils.data import DataLoader, Dataset

%matplotlib inline
GPU 加速
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
# cuda

2.自定义数据集

transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])


class MyDataset(Dataset):
    def __init__(self, data_path, transform=None):
        super().__init__()
        self.transform = transform
        self.data_path = data_path
        self.data = []
        
        img_path = os.path.join(data_path, 'img.txt')
        with open(img_path, 'r', encoding='utf-8') as f:
            for line in f.readlines():
                line = line.strip()
                img_name = os.path.join(data_path, line)
                img = Image.open(img_name)
                if img.mode == 'RGB':
                    self.data.append(line)
    
    
    def __getitem__(self, idx):
        # take the data sample by it's index
        img_path = os.path.join(self.data_path, self.data[idx])
        # read image
        img = Image.open(img_path)
        # apply the transform
        if self.transform:
            img = self.transform(img)
            
        # return the image and index
        dict_data = {
            'index': idx,
            'img': img
        }
        return dict_data
    
    
    def __len__(self):
        return len(self.data)
img_folder = 'JPEGImages'
val_dataset = MyDataset(img_folder, transform=transform)
batch_size = 64
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
print('Val_dataset: ', val_dataset.__len__())
print('iter: ', int(val_dataset.__len__()/batch_size)+1)
Val_dataset:  17125
iter:  268

3.预训练模型+自定义特征值提取器

# 加载预训练模型
def load_model():
    model = models.resnet18(pretrained=True)
    model.to(device)
    model.eval()
    return model


# 定义 特征提取器
def feature_extract(model, x):
    x = model.conv1(x)
    x = model.bn1(x)
    x = model.relu(x)
    x = model.maxpool(x)
    x = model.layer1(x)
    x = model.layer2(x)
    x = model.layer3(x)
    x = model.layer4(x)
    x = model.avgpool(x)
    x = torch.flatten(x, 1)
    return x
model = load_model()

for idx, batch in enumerate(val_dataloader):
    img = batch['img']  # 图片数据表示 --> 图片特征
    index = batch['index']
    img = img.to(device)
    feature = feature_extract(model, img)
    feature = feature.data.cpu().numpy()
    imgs_path = [os.path.join(img_folder, val_dataset.data[i] + '.txt') for i in index]
    assert len(feature) == len(imgs_path)
    
    for i in range(len(imgs_path)):
        feature_list = [str(f) for f in feature[i]]
        img_path = imgs_path[i]
        
        with open(img_path, 'w', encoding='utf-8') as f:
            f.write(" ".join(feature_list))
    print('*' * 60)
    print(idx * batch_size)

4.图片向量化

# 获取图片特征¶
def img2feat(pic_file):
    feat = []
    with open(pic_file, 'r', encoding='utf-8') as f:
        lines = f.readlines()
        feat = [float(f) for f in lines[0].split()]
    return feat
ids = []
data = []

img_folder = 'VOC2012'#'VOC2012_small/'
img_path = os.path.join(img_folder,'img.txt')
with open(img_path,'r',encoding='utf-8') as f:
    for line in f.readlines():
        img_name = line.strip()
        img_id = img_name.split('.')[0]
        pic_txt_file = os.path.join( img_folder,"{}.txt".format(img_name) )
        
        if not os.path.exists(pic_txt_file):
            continue
            
        feat = img2feat(pic_txt_file)
        ids.append(int(img_id))
        data.append(np.array(feat))
        
# 构建数据<id,data> 
ids = np.array(ids)
data = np.array(data).astype('float32')
d = 512 # feature 特征长度(模型的结果) 
print(" 特征向量记录数: ",data.shape)
print(" 特征向量ID的记录数:",ids.shape)
 特征向量记录数:  (17125, 512)
 特征向量ID的记录数: (17125,)

5.创建 Faiss 索引 Index

# 创建图片特征索引 - 方案1
# index = faiss.index_factory(d,"IDMap,Flat")
# index.add_with_ids(data,ids)


# 创建图片特征索引-方案2(  资源有限,效果更好 )
###IDMap 支持add_with_ids 
###如果很在意,使用”PCARx,...,SQ8“ 如果保存全部原始数据的开销太大,可以用这个索引方式。包含三个部分,
# 1.降维
# 2.聚类
# 3.scalar 量化,每个向量编码为8bit 不支持GPU
index = faiss.index_factory(d, "IDMap,PCAR16,IVF50,SQ8") 
index.train(data)
index.add_with_ids(data, ids)


# 索引文件保存磁盘
faiss.write_index(index,'index_file.index') # 讲index保存index_file.index 的文件
# index = faiss.read_index("index_file.index")
# print(index.ntotal) # 查看索引库大小
加载 Faiss Index 索引文件
index = faiss.read_index('index_file.index')
print('索引记录数:', index.ntotal)
# 索引记录数: 17125

6.Faiss 相似 TopK 检索

def index_search(feat,topK ):
    """
        feat: 检索的图片特征
        topK: 返回最高topK相似的图片
        
    """
    
    feat = np.expand_dims( np.array(feat),axis=0 )
    feat = feat.astype('float32')
    
    start_time = time.time()
    dis,ind = index.search( feat,topK )
    end_time = time.time()
    
    print( 'index_search consume time:{}ms'.format(  int(end_time - start_time) * 1000  ) )
    return dis,ind # 距离,相似图片id

7.可视化检索结果

def visual_plot(ind,dis,topK,query_img = None):       
    # 相似照片
    cols = 4
    rows = int(topK / cols)
    idx = 0
    
    fig,axes = plt.subplots(rows,cols,figsize=(20 ,5*rows),tight_layout=True)
    #axes[0,0].imshow(query_img)
    
    for row in range(rows):
        for col in range(cols):
            _id = ind[0][idx]
            _dis = dis[0][idx]
            
            img_path = os.path.join(img_folder,'{}.jpg'.format(_id))
            #print(img_path)
            
            if query_img is not None and idx == 0:
                axes[row,col].imshow(query_img)
                axes[row,col].set_title( 'query',fontsize = 20  )
            else:
                img = plt.imread(  img_path   )
                axes[row,col].imshow(img)
                axes[row,col].set_title( 'matched_-{}_{}'.format(_id,int(_dis)) ,fontsize = 20  )
            idx+=1
            
    plt.savefig('pic')
img_folder = 'VOC2012/'
# img_id = '100211.jpg'
img_id = '100002.jpg'
topK = 20
img_path = os.path.join( img_folder,img_id)
print(img_path) # 查看  这个img_path 的相似图片

img = Image.open(img_path)
img = transform(img) # torch.Size([3, 224, 224])
img = img.unsqueeze(0) # torch.Size([1, 3, 224, 224])
img = img.to(device)

# 对我们的图片进行预测
with torch.no_grad():
    # 图片-> 图片特征
    print('1.图片特征提取')
    feature = feature_extract( model,img )
    # 特征-> 检索
    feature_list = feature.data.cpu().tolist()[0]
    print('2.基于特征的检索,从faiss获取相似度图片')
    # 相似图片可视化
    dis,ind = index_search( feature_list,topK=topK )
    print('ind = ',ind)
    print('3.图片可视化展示')
    # 当前图片
    query_img = plt.imread( img_path )
    visual_plot( ind,dis,topK,query_img)
VOC2012/100002.jpg
1.图片特征提取
2.基于特征的检索,从faiss获取相似度图片
index_search consume time:0ms
ind =  [[100002 101430 116500 101585 116528 100507 104768 107651 112514 102820
  112416 116458 106167 111781 116247 103299 103154 106012 115086 111156]]
3.图片可视化展示

评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值