创建HNSW结构的小工具

上一篇讲解了HNSW算法。为了实现图像检索,这篇博客借助Python的GUI编程库tkinter做了一个小工具,用来创建HNSW中的index(也就是整个索引结构)。工具里另外嵌入了提取特征的部分,使用的是ResNet-50。
使用工具需要安装:

  • tkinter库:tkinter不能用pip安装,Ubuntu/Debian下请用 sudo apt-get install python3-tk
  • pytorch及torchvision:参见PyTorch官网 https://pytorch.org 的安装说明
  • numpy: sudo pip3 install numpy
  • natsort: sudo pip3 install natsort
  • hnswlib: sudo pip3 install hnswlib
  • opencv-python、matplotlib、Pillow(源码中导入了这些库): sudo pip3 install opencv-python matplotlib pillow

界面及使用介绍

(此处原为界面截图)
这么丑的界面,除了我没有人能做出来了。在路径中输入图片所在的文件夹名称以提取特征,比如“../faces/”,注意最后一定要加/,因为是文件夹。注意:工具会把该文件夹中的图片按序号重命名(如0.jpg、1.jpg……)并移动到程序中写定的../faces/目录下,请提前备份原图。
如果想直接创建,就先输入M和Max elems,再点击“创建索引”,最后会生成一个index.idx;如果只是想增量索引,那就输入文件夹名之后点“增量索引”就可以。最后的“删除索引”并不会真的删除,而是把原来的index.idx重命名为index_backup.idx,细心如我。

源码


from tkinter import *
import configparser
import cv2
import os
import natsort
import hnswlib
import torch
import torch.nn as nn
from torch.autograd import Variable
from torchvision import models, transforms
from PIL import Image
from matplotlib import pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np

class CreateTool():
    """Tk window that creates, extends and backs up an hnswlib index of image features.

    Images from the folder typed into the path field are moved into
    ``self.faces_dir`` and renamed ``<label><ext>``; that numeric label is the
    id stored in the HNSW index, so a search hit ``i`` maps back to the file
    ``<faces_dir>/i<ext>``.
    """

    def __init__(self, master):
        # --- state ---
        self.imageDir = ''
        self.features = []
        # Expected length of each feature vector returned by `extractor`.
        self.feature_dim = 2048
        self.index_path = '../index.idx'
        # Destination folder for the renamed images (relative path, adjust as needed).
        self.faces_dir = '../faces'
        # --- window ---
        self.parent = master
        self.parent.title('Create Index')
        self.frame = Frame(self.parent)

        self.frame.pack(fill=BOTH, expand=1)

        self.label = Label(self.frame, text="路径(文件夹):")
        self.label.grid(row=0, column=0)
        self.entryDir = Entry(self.frame)
        self.entryDir.grid(row=0, column=1, columnspan=3, sticky=W+E+N)

        self.labelM = Label(self.frame, text="M:")
        self.labelM.grid(row=1, column=1)
        self.entryM = Entry(self.frame)
        self.entryM.grid(row=1, column=2)

        self.label_max_elements = Label(self.frame, text="Max elems:")
        self.label_max_elements.grid(row=2, column=1)
        self.entry_max_elements = Entry(self.frame)
        self.entry_max_elements.grid(row=2, column=2)
        self.createBtn = Button(self.frame, text="创建索引", command=self.createIndex)
        self.createBtn.grid(row=1, rowspan=3, column=0, sticky=E+N+W+S)

        self.addBtn = Button(self.frame, text="增量索引", command=self.addIndex)
        self.addBtn.grid(row=4, column=0, columnspan=3, pady=20, sticky=W+E+N)

        self.delBtn = Button(self.frame, text='删除索引', command=self.delIndex)
        self.delBtn.grid(row=5, column=0, columnspan=3, sticky=W+E+N)

    def loadImage(self, elements_num):
        """Move every file of the selected folder into ``self.faces_dir`` with numeric names.

        Files are taken in natural-sort order and renamed ``elements_num<ext>``,
        ``elements_num + 1<ext>``, ... so that each file's numeric stem equals
        its HNSW label.

        Args:
            elements_num: number (label) given to the first file.

        Returns:
            The new file names, relative to ``self.faces_dir``, in label order.
        """
        self.imageDir = self.entryDir.get()
        # Natural sort so that "img2" is numbered before "img10".
        file_list = natsort.natsorted(os.listdir(self.imageDir))
        print(file_list)
        new_names = []
        for label, img_name in enumerate(file_list, start=elements_num):
            # splitext keeps the real extension even for names like "a.b.jpg"
            # (split('.')[1] would have produced "b") and tolerates names without a dot.
            ext = os.path.splitext(img_name)[1]
            new_name = str(label) + ext
            os.rename(os.path.join(self.imageDir, img_name),
                      os.path.join(self.faces_dir, new_name))
            new_names.append(new_name)
        # Return the names just created instead of re-listing faces_dir and slicing,
        # which went wrong as soon as faces_dir held any non-numeric file.
        return new_names

    def extractFeatures(self, files_list):
        """Run the ResNet feature extractor on each file in ``files_list``.

        Args:
            files_list: file names relative to ``self.faces_dir``.

        Returns:
            numpy array with one row per file, in the order of ``files_list``;
            shape ``(0, self.feature_dim)`` when ``files_list`` is empty.
        """
        print('==================extracting features==============================')
        print('file_list', files_list)
        if not files_list:
            # Nothing to extract: skip loading the network, return a correctly shaped empty array.
            return np.empty((0, self.feature_dim), dtype=np.float32)
        use_gpu = torch.cuda.is_available()
        model = net()
        # eval(): BatchNorm must use its running statistics instead of per-image batch statistics.
        model.eval()
        if use_gpu:
            # Only move to CUDA when a GPU exists (calling .cuda() unconditionally crashed on CPU-only hosts).
            model = model.cuda()
        features = []
        with torch.no_grad():
            for name in files_list:
                x_path = os.path.join(self.faces_dir, name)
                print("img_path-" + x_path)
                tmp_feature = extractor(x_path, model, use_gpu)
                features.append(tmp_feature[0])
        # The array used to be built but never returned, so callers received None.
        return np.array(features)

    @staticmethod
    def _read_positive_int(entry, name):
        """Return the positive integer typed into ``entry``, or None (after printing why) if it is invalid."""
        text = entry.get().strip()
        try:
            value = int(text)
        except ValueError:
            print('{} must be an integer, got {!r}'.format(name, text))
            return None
        if value <= 0:
            print('{} must be positive, got {}'.format(name, value))
            return None
        return value

    def createIndex(self):
        """Build a new HNSW index from the selected folder and save it to ``self.index_path``."""
        # Validate the numeric fields first so a typo does not leave images renamed but no index built.
        max_elements = self._read_positive_int(self.entry_max_elements, 'Max elems')
        M = self._read_positive_int(self.entryM, 'M')
        if max_elements is None or M is None:
            return
        img_list = self.loadImage(0)
        self.features = self.extractFeatures(img_list)
        print('==================building index==============================')
        n = len(self.features)
        if n > max_elements:
            # hnswlib refuses to add more items than max_elements; grow it rather than fail.
            print('Max elems {} < {} images, using {}'.format(max_elements, n, n))
            max_elements = n
        p = hnswlib.Index(space='l2', dim=self.feature_dim)
        p.init_index(max_elements=max_elements, ef_construction=100, M=M)
        p.set_ef(100)
        p.set_num_threads(4)
        if n:
            # Label i corresponds to file "<faces_dir>/i<ext>" because loadImage numbered files from 0.
            p.add_items(self.features, np.arange(n))
        print('saving...')
        p.save_index(self.index_path)
        print("==================building finished=======================")
        print("{} elements are built successfully".format(len(p.get_ids_list())))

    def addIndex(self):
        """Append the images of the selected folder to the index stored at ``self.index_path``."""
        p = hnswlib.Index(space='l2', dim=self.feature_dim)
        p.load_index(self.index_path)
        # Labels assigned so far are 0..elements_num-1, so new images continue from elements_num.
        elements_num = len(p.get_ids_list())
        img_list = self.loadImage(elements_num)
        print('img_list', img_list)
        features = self.extractFeatures(img_list)
        print('===========adding items========================')
        if len(features) == 0:
            print('nothing to add')
            return
        needed = elements_num + len(features)
        if needed > p.get_max_elements():
            # Capacity is fixed at init time; enlarge it instead of letting add_items fail.
            p.resize_index(needed)
        labels_index = np.arange(elements_num, needed)
        p.add_items(features, labels_index)
        p.save_index(self.index_path)
        print('adding finished!')

    def delIndex(self):
        """Back up (not delete) the index by renaming it to ``<name>_backup<ext>``; no-op if it is missing."""
        if os.path.exists(self.index_path):
            root, ext = os.path.splitext(self.index_path)
            # os.replace overwrites an older backup on every platform (os.rename fails on Windows).
            os.replace(self.index_path, root + '_backup' + ext)
            # os.remove(self.index_path)

class net(nn.Module):
    """ResNet-50 backbone with the classifier cut off.

    The whole pretrained torchvision model is stored as ``self.net``, but
    ``forward`` stops after ``avgpool`` and never applies ``fc``, so it returns
    the pooled feature map (presumably ``(N, 2048, 1, 1)`` — confirm against
    the torchvision version in use).
    """

    # Sub-modules of ``self.net`` applied, in this order, by ``forward``.
    _STAGES = ('conv1', 'bn1', 'relu', 'maxpool',
               'layer1', 'layer2', 'layer3', 'layer4', 'avgpool')

    def __init__(self):
        super().__init__()
        self.net = models.resnet50(pretrained=True)

    def forward(self, input):
        out = input
        for stage in self._STAGES:
            out = getattr(self.net, stage)(out)
        return out

def extractor(img_path, net, use_gpu):
    """Compute the feature vector of a single image.

    Args:
        img_path: path of the image file.
        net: feature network (an instance of ``net``); it should already be in
            eval mode and on CUDA when ``use_gpu`` is true.
        use_gpu: move the input tensor to CUDA before the forward pass.

    Returns:
        numpy array of shape ``(1, D)``: the squeezed network output as one row.
    """
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor()]
    )
    # NOTE(review): no ImageNet mean/std normalisation is applied. Adding it changes
    # every feature, so it must go together with rebuilding the whole index.
    # `with` closes the file handle. convert('RGB') gives grayscale, RGBA and palette
    # images the 3 channels conv1 expects instead of crashing the forward pass.
    with Image.open(img_path) as pil_img:
        tensor = transform(pil_img.convert('RGB'))
    x = torch.unsqueeze(tensor, dim=0).float()
    if use_gpu:
        x = x.cuda()
    # Inference only: don't build an autograd graph (replaces the deprecated Variable wrapper).
    with torch.no_grad():
        y = net(x).cpu()
    y = torch.squeeze(y).numpy()
    return y.reshape(1, -1)

def main():
    """Open the 300x200 Tk window hosting CreateTool and run the event loop."""
    window = Tk()
    window.geometry("300x200")
    # Keep a reference so the tool object lives as long as the window does.
    app = CreateTool(window)
    window.mainloop()
    return app


if __name__ == '__main__':
    main()

友情提醒:记得修改程序中的文件路径,我写的是相对路径。Have fun.

下一篇预告

(此处原为配图)
Python中的Flask搭建后台 + Bootstrap写静态网页 + hnsw近邻算法 = 图像检索Web应用

  • 4
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值