Annoy最近邻检索技术之 “图片检索”

本文主要介绍一下NN检索方式Annoy(Approximate Nearest Neighbors Oh Yeah)的应用,在前几篇的召回文章中(1.推荐系统召回模型之YouTubeNet;2. 推荐系统召回模型之MIND用户多兴趣网络实践)都涉及这个技术点,一直没有详细的介绍。本文以图片检索为应用场景,介绍一下Annoy。

1

Annoy算法原理

Annoy是Python的一个模块,用于空间检索近邻的数据。检索过程分成三步:

  • 建立索引过程;

  • 近邻查询过程;

  • 返回最终近邻节点;

首先先来一张2D数据分布图:

接下来按照步骤1,2和3进行分析。

1.1 建立索引过程

Annoy的目标是建立一个数据结构,使得查询一个点的最近邻点的时间复杂度是次线性。Annoy 通过建立一个二叉树来使得每个点查找时间复杂度是O(log n)。以下图为例,随机选择两个点,以这两个节点为初始中心节点,执行聚类数为2的kmeans过程,最终产生收敛后两个聚类中心点。这两个聚类中心点之间连一条线段(灰色短线),建立一条垂直于这条灰线,并且通过灰线中心点的线(黑色粗线)。这条黑色粗线把数据空间分成两部分。在多维空间的话,这条黑色粗线可以看成等距垂直超平面。

在划分的子空间内进行不停的递归迭代继续划分,直到每个子空间最多只剩下K个数据节点。

通过多次递归迭代划分的话,最终原始数据会形成类似下图的二叉树结构。二叉树底层是叶子节点记录原始数据节点,其他中间节点记录的是分割超平面的信息。Annoy建立这样的二叉树结构是希望满足这样的一个假设:  相似的数据节点应该在二叉树上位置更接近,一个分割超平面不应该把相似的数据节点分割二叉树的不同分支上。

根据上述步骤,建立多棵二叉树树,构成一个森林。

1.2 近邻查询过程

上面已完成节点索引建立过程。如何进行对一个数据点进行查找相似节点集合呢?比如下图的红色节点,查找的过程就是不断看他在分割超平面的哪一边。从二叉树索引结构来看,就是从根节点不停的往叶子节点遍历的过程。通过对二叉树每个中间节点(分割超平面相关信息)和查询数据节点进行相关计算来确定二叉树遍历过程是往这个中间节点左孩子节点走还是右孩子节点走。通过以上方式完成查询过程。

查询过程采用优先队列机制:采用一个优先队列来遍历二叉树,从根节点往下的路径,根据查询节点与当前分割超平面距离进行排序。

1.3 返回最终近邻节

步骤1会构建多棵二叉树树,每棵树都返回一堆近邻点后,如何得到最终的Top N相似集合呢?首先所有树返回近邻点都插入到优先队列中,求并集去重, 然后计算和查询点距离,最终根据距离值从近距离到远距离排序,返回Top-N近邻节点集合。

2

图片检索实践

先放一张本文检索的效果图:

检索结果:最相似的 Top-9张商品图片如下所示:

技术步骤:

  • 下载一批商品图片,本文使用的商品图片来源于某电商商城;

  • 下载vgg16模型;

  • 使用vgg16模型提取图片特征;

  • 使用Annoy技术对图片特征数据构建索引,及建树;

  • 输入一张图片特征数据,检索并返回最相似的Top-9张图片;

2.1 下载一批商品图片

本文使用的图片数据来源于某电商商城,下载了30个种类的图片数据,共计5130张。下载代码如下:

# encoding="utf-8"
from requests_html import HTMLSession
import re
import os
import time

sku_eng_list = ["Mobile-phone", "T-shirt", "Milk", "Mask", "Headset", \
"Wine", "Helmet", "Fan", "Sneaker", "Cup", \
"Glasses", "Backpack", "UAV", "Sofa", "Bicycle", \
"Cleanser", "Paper", "Bread", "Sausage", "Toilet", \
"Book", "Tire", "Clock", "Mango", "Shrimp", \
"Stroller", "Necklace", "Baby-bottle", "Yuba", "Pot"]


session = HTMLSession()

for inx, key in enumerate(["手机", "T恤", "牛奶", "口罩", "耳机", \
    "酒", "头盔", "风扇", "运动鞋", "杯子", \
    "眼镜", "背包", "无人机", "沙发", "自行车", \
    "洗面奶", "抽纸", "面包", "香肠", "马桶", \
    "书", "轮胎", "钟表", "芒果", "虾", \
    "童车", "项链", "奶瓶", "浴霸", "锅"]):

    for j in range(1, 10):
        
        time.sleep(2)
        url = 'https://search.jd.com/Search?keyword=%s&wq=%s&page=%s&s=90&click=0' % \
            (key, key, str(j))

        r = session.get(url)

        for i in range(1, 20):
            try:
                contain_pic_url = str(r.html.find('#J_goodsList > ul > li:nth-child('+str(i)+') > div > div > div.gl-i-tab-content > div.tab-content-item.tab-cnt-i-selected > div.p-img > a > img'))
                src_start = re.search('src',contain_pic_url).end() + 2
                src_end = int(re.search("'",contain_pic_url[src_start:]).start())
                pic_url = 'https:'+contain_pic_url[src_start:src_start + src_end]

                os.chdir('C:\\Users\\Desktop\\figures')
                pic = session.get(pic_url)
                open(sku_eng_list[inx]+'_page_'+str(j)+'_NO_'+str(i)+'.jpg','wb').write(pic.content)

            except:
                try:
                    contain_pic_url = str(r.html.find('#J_goodsList > ul > li:nth-child('+str(i)+') > div > div.p-img > a > img'))
                    src_start = re.search('src',contain_pic_url).end() + 2
                    src_end = int(re.search("'",contain_pic_url[src_start:]).start())
                    pic_url = 'https:'+contain_pic_url[src_start:src_start + src_end]

                    os.chdir('C:\\Users\\Desktop\\figures')
                    pic = session.get(pic_url)
                    open(sku_eng_list[inx]+'_page_'+str(j)+'_NO_'+str(i)+'.jpg','wb').write(pic.content)

                except:
                    pass

    print("Download %s done !!!" % sku_eng_list[inx])

2.2 下载vgg16模型

官方下载地址

https://github.com/fchollet/deep-learning-models/releases/download/v0.1/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5

如果无法在官方地址下载,可以从百度网盘中下载:

链接:https://pan.baidu.com/s/1Exa8g_q9hVmqOU9SBrIxrg
提取码:qtsb

将 vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5 模型放入 ~/.keras/models 路径中即可。

2.3 使用vgg16模型提取图片特征

该版本的vgg16模型可以将图片转化为维度为 [7, 7, 512] 的浮点型数据,将该数据“压平”保存。

from tensorflow.keras.applications.vgg16 import VGG16
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.vgg16 import preprocess_input
import numpy as np
from tqdm import tqdm
import pickle
import os

# 1. 加载vgg16模型
model = VGG16(weights='imagenet', include_top=False)
#print(model.summary())


# 2. 提取图片特征
img_path = "figures/"

img_name_list = []
img_feature_list = []

for file in tqdm(os.listdir(img_path)):
    img_name_list.append(file)
    file_path = img_path + file
    
    img = image.load_img(file_path, target_size=(224, 224))
    x = image.img_to_array(img)
    x = np.expand_dims(x, axis=0)
    x = preprocess_input(x)

    features = model.predict(x)
    img_feature_list.append(features.reshape((7*7*512,)))


# 3. 将图片名称和图片特征保存为pkl格式
f = open("img_feature_list.pkl", 'wb')
pickle.dump(img_feature_list, f)
f.close()

g = open("img_name_list.pkl", 'wb')
pickle.dump(img_name_list, g)
g.close()

此时将得到 img_name_list.pkl 和 img_feature_list.pkl 两个文件,分别保存图片的名称和图片的特征。

2.4 使用Annoy技术对图片特征数据构建索引,及建树

建索引及建树脚本如下所示:

# encoding:utf-8
from annoy import AnnoyIndex
import pickle
import numpy as np
np.random.seed(20200601)
import sys, time
from tqdm import tqdm

def build_ann(name_path=None, \
       vec_path=None, \
       index_to_name_dict_path=None, \
       ann_path=None, \
       dim=64, \
       n_trees=10):

    name_path = open(name_path, 'rb')
    vec_path = open(vec_path, 'rb')
    img_name_list = pickle.load(name_path)
    img_vec_list = pickle.load(vec_path)
    

    ann = AnnoyIndex(dim)
    idx = 0
    batch_size = 100 * 10000
    index_to_name_dict = {}
    

    for name, vec in tqdm(zip(img_name_list, img_vec_list)):
        ann.add_item(idx, vec)
        index_to_name_dict[idx] = name

        idx += 1
        if idx % batch_size == 0:
            print("%s00w" % (int(idx/batch_size)))

    print("Add items Done!\nStart building trees")

    ann.build(n_trees)
    print("Build Trees Done!")
    
    ann.save(ann_path)
    print("Save ann to %s Done!" % (ann_path))

    fd = open(index_to_name_dict_path, 'wb')
    pickle.dump(index_to_name_dict, fd)
    fd.close()
    print("Saving index_to_name mapping Done!")


if __name__ == '__main__':
    name_path = "img_name_list.pkl"
    vec_path = "img_feature_list.pkl"
    index_to_name_dict_path = "index_to_name_dict.pkl"
    ann_path = "img_feature_list.ann"
    dim = 25088
    n_trees = 10

    build_ann(name_path=name_path, \
        vec_path=vec_path, \
        index_to_name_dict_path=index_to_name_dict_path, \
        ann_path=ann_path, \
        dim=dim, \
        n_trees=n_trees)

本实验构建了10棵二叉树,此时将得到 index_to_name_dict.pkl 和 img_feature_list.ann 两个文件,分别保存图片索引Id与名称的映射数据,和图片特征的二叉树结构信息。

2.5 输入一张图片特征数据,检索并返回最相似的Top-9张图片

话不多说,代码如下:

# encoding:utf-8
from annoy import AnnoyIndex
import numpy as np
np.random.seed(20200601)
import pickle
import sys
from matplotlib import image as mpimg
from matplotlib import pyplot as plt

def load_ann(ann_path=None, index_to_name_dict_path=None, dim=64):
    ann = AnnoyIndex(dim)
    ann.load(ann_path)

    with open(index_to_name_dict_path, 'rb') as f:
        index_to_name_dict = pickle.load(f)
    return ann, index_to_name_dict


def query_ann(ann=None, index_to_name_dict=None, query_vec=None, topN=None):
    topN_item_idx_list = ann.get_nns_by_vector(query_vec, topN)

    topN_item_id_list = []

    for idx in topN_item_idx_list:
        item_id = index_to_name_dict[idx]
        topN_item_id_list.append(item_id)

    return topN_item_id_list


if __name__ == '__main__':
    index_to_name_dict_path = "index_to_name_dict.pkl"
    ann_path = "img_feature_list.ann"
    name_path = "img_name_list.pkl"
    vec_path = "img_feature_list.pkl"
    dim = 25088
    topN = 9
    
    name_path = open(name_path, 'rb')
    vec_path = open(vec_path, 'rb')
    img_name_list = pickle.load(name_path)
    img_vec_list = pickle.load(vec_path)
    
    idx = 126
    query_name = img_name_list[idx]
    query_vec = img_vec_list[idx]
    
    ann, index_to_name_dict = load_ann(ann_path=ann_path, \
        index_to_name_dict_path=index_to_name_dict_path, \
        dim=dim)

    topN_item_list = query_ann(ann=ann, \
        index_to_name_dict=index_to_name_dict, \
        query_vec=query_vec, \
        topN=topN)

    # query 商品图片
    print("query_image: \n")
    fig, axes = plt.subplots(1, 1)
    query_image = mpimg.imread("figures/" + query_name)
    axes.imshow(query_image/255)
    axes.axis('off')
    axes.axis('off')
    axes.set_title('%s' % query_name, fontsize=8, color='r')

    # Top-9 相似商品
    fig, axes = plt.subplots(3, 3)
    for idx, img_path in enumerate(topN_item_list):

        i = idx % 3   # Get subplot row
        j = idx // 3  # Get subplot column
        image = mpimg.imread("figures/" + img_path)
        axes[i, j].imshow(image/255)
        axes[i, j].axis('off')
        axes[i, j].axis('off')

        axes[i, j].set_title('%s' % img_path, fontsize=8, color='b')

本实验以idx=126为例进行测试,idx取值范围为[0, 5129]。

参考:

https://github.com/spotify/annoy

https://blog.csdn.net/hero_fantao/article/details/70245387

欢迎关注 “python科技园” 及 添加小编 进群交流。

文章好看点这里

  • 1
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值