采用向量数据库实现百万级图像搜索

最近接了一个图像搜索项目：通过爬虫抓取订单并将其中的图像入库，用户在查询历史记录时可以快速找到之前的成交记录。下面讲一下整个过程：

一开始打算直接用 VGG16 加 NumPy 实现：用 VGG16 提取图像特征，再用 NumPy 存储和查询。但在爬虫抓取的过程中发现数据量很大，要想快速查找并计算相似度需要很大的算力，耗时也比较长。查阅资料后发现了向量数据库 Milvus。目前网上关于它的资料要么太老，要么太少，基本无法参考，只能对照官方文档自己慢慢摸索。现在已经跑通：数据库中的数据量已达到 340 万条，单次查询耗时约 200 毫秒。

Milvus 官方网站：Vector database - Milvus（https://milvus.io/）

文档地址：Milvus documentation（https://milvus.io/docs）

以下代码在官方示例的基础上做了调整。官方示例没有提供单条特征的更新和删除功能，我对此做了补充。

import sys
from config import MILVUS_HOST, MILVUS_PORT, VECTOR_DIMENSION, METRIC_TYPE,DEFAULT_TABLE
from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility
from logs import LOGGER


class MilvusHelper:
    """
    MilvusHelper class to manage a Milvus collection.

    On any error, each method logs through ``LOGGER`` and then calls
    ``sys.exit(1)`` instead of raising, so callers never receive the exception.

    Args:
        host (`str`):
            Milvus server host.
        port (`str|int`):
            Milvus server port.
    """

    def __init__(self, host=MILVUS_HOST, port=MILVUS_PORT):
        """Open the default pymilvus connection to ``host:port``."""
        try:
            # Bound later by set_collection() / create_collection().
            self.collection = None
            connections.connect(host=host, port=port)
            # Report the address actually used; it may differ from the config defaults.
            LOGGER.debug(f"Successfully connect to Milvus with IP:{host} and PORT:{port}")
        except Exception as e:
            LOGGER.error(f"Failed to connect Milvus: {e}")
            sys.exit(1)

    def set_collection(self, collection_name):
        """Bind ``self.collection`` to the collection named ``collection_name``."""
        try:
            self.collection = Collection(name=collection_name)
        except Exception as e:
            LOGGER.error(f"Failed to set collection {collection_name} in Milvus: {e}")
            sys.exit(1)

    def has_collection(self, collection_name):
        """Return whether Milvus has a collection named ``collection_name``."""
        try:
            return utility.has_collection(collection_name)
        except Exception as e:
            LOGGER.error(f"Failed to check collection {collection_name} in Milvus: {e}")
            sys.exit(1)

    def create_collection(self, collection_name):
        """Create ``collection_name`` (and its index) if it is missing, then bind it.

        Schema: ``path`` is a caller-supplied VARCHAR primary key (max length 500,
        ``auto_id=False``); ``embedding`` is a FLOAT_VECTOR of ``VECTOR_DIMENSION``
        dimensions.

        Returns:
            The string ``"OK"``.
        """
        try:
            if not self.has_collection(collection_name):
                # The FieldSchema keyword is ``description``; a misspelled keyword
                # would be swallowed by **kwargs and the description lost.
                path_field = FieldSchema(name='path', dtype=DataType.VARCHAR, description='path to image',
                                         max_length=500, is_primary=True, auto_id=False)
                embedding_field = FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR,
                                              description="image embedding vectors",
                                              dim=VECTOR_DIMENSION, is_primary=False)
                schema = CollectionSchema(fields=[path_field, embedding_field],
                                          description="collection description")
                self.collection = Collection(name=collection_name, schema=schema)
                self.create_index(collection_name)
                LOGGER.debug(f"Create Milvus collection: {collection_name}")
            else:
                self.set_collection(collection_name)
            return "OK"
        except Exception as e:
            LOGGER.error(f"Failed to create collection {collection_name} in Milvus: {e}")
            sys.exit(1)

    def insert(self, collection_name, path, vectors):
        """Batch-insert ``vectors`` keyed by the parallel list ``path``.

        Milvus 2.2 does not enforce primary-key uniqueness, so inserting a key that
        already exists adds another entity; use ``update`` to replace instead.

        Args:
            collection_name: Target collection (must already exist).
            path: List of primary-key strings, one per vector.
            vectors: List of embeddings, parallel to ``path``.

        Returns:
            The primary keys reported by Milvus for the inserted rows.
        """
        try:
            # Column order must match the schema: path, then embedding.
            data = [path, vectors]
            self.set_collection(collection_name)
            mr = self.collection.insert(data)
            ids = mr.primary_keys
            # search_vectors() never loads the collection itself, so load here to make
            # the new rows searchable.
            self.collection.load()
            LOGGER.debug(
                    f"Insert vectors to Milvus in collection: {collection_name} with {len(vectors)} rows")
            return ids
        except Exception as e:
            LOGGER.error(f"Failed to insert data to Milvus: {e}")
            sys.exit(1)

    def update(self, collection_name, path, vectors):
        """Replace the vectors stored under the primary keys in ``path``.

        pymilvus 2.2 has no upsert, so this deletes any entities with these keys
        and then inserts the new ones. The two steps are not atomic: a concurrent
        search between them may miss the keys.

        Args:
            collection_name: Target collection (must already exist).
            path: List of primary-key strings, one per vector.
            vectors: List of embeddings, parallel to ``path``.

        Returns:
            The primary keys returned by ``insert``.
        """
        self.delete_vectors(collection_name, path)
        return self.insert(collection_name, path, vectors)

    def delete_vectors(self, collection_name, path):
        """Delete every entity whose primary key ``path`` is in the given key(s).

        Args:
            collection_name: Target collection (must already exist).
            path: A single primary-key string or an iterable of them.

        Returns:
            The result of ``Collection.delete``, or ``None`` when no key was given.
        """
        import json

        try:
            keys = [path] if isinstance(path, str) else list(path)
            if not keys:
                return None
            self.set_collection(collection_name)
            # json.dumps renders each key as a double-quoted, escaped string literal,
            # the form Milvus boolean expressions use for VARCHAR values.
            expr = f"path in {json.dumps(keys, ensure_ascii=False)}"
            res = self.collection.delete(expr)
            LOGGER.debug(f"Delete vectors from Milvus in collection: {collection_name} with expr: {expr}")
            return res
        except Exception as e:
            LOGGER.error(f"Failed to delete vectors in Milvus: {e}")
            sys.exit(1)

    def create_index(self, collection_name):
        """Build an IVF_SQ8 index (``nlist=16384``) on the ``embedding`` field.

        Returns:
            The value returned by ``Collection.create_index`` when its ``code`` is
            falsy (assumed to be a status object with ``code``/``message`` in the
            pymilvus version targeted here; verify after upgrading).
        """
        try:
            self.set_collection(collection_name)
            default_index = {"index_type": "IVF_SQ8", "metric_type": METRIC_TYPE, "params": {"nlist": 16384}}
            status = self.collection.create_index(field_name="embedding", index_params=default_index)
            if not status.code:
                LOGGER.debug(
                    f"Successfully create index in collection:{collection_name} with param:{default_index}")
                return status
            else:
                raise Exception(status.message)
        except Exception as e:
            LOGGER.error(f"Failed to create index: {e}")
            sys.exit(1)

    def delete_collection(self, collection_name):
        """Drop the whole collection ``collection_name``; returns ``"ok"``."""
        try:
            self.set_collection(collection_name)
            self.collection.drop()
            LOGGER.debug("Successfully drop collection!")
            return "ok"
        except Exception as e:
            LOGGER.error(f"Failed to drop collection: {e}")
            sys.exit(1)

    def search_vectors(self, collection_name, vectors, top_k):
        """Return the ``top_k`` nearest neighbours for each vector in ``vectors``.

        Searches the ``embedding`` field with ``METRIC_TYPE`` and ``nprobe=16``,
        i.e. 16 of the index's 16384 IVF lists are probed per query. The collection
        is not loaded here; ``insert`` loads it.
        """
        try:
            self.set_collection(collection_name)
            search_params = {"metric_type": METRIC_TYPE, "params": {"nprobe": 16}}
            res = self.collection.search(vectors, anns_field="embedding", param=search_params, limit=top_k)
            LOGGER.debug(f"Successfully search in collection: {res}")
            return res
        except Exception as e:
            LOGGER.error(f"Failed to search vectors in Milvus: {e}")
            sys.exit(1)

    def count(self, collection_name):
        """Return ``num_entities`` of the collection ``collection_name``."""
        try:
            self.set_collection(collection_name)
            num = self.collection.num_entities
            LOGGER.debug(f"Successfully get the num:{num} of the collection:{collection_name}")
            return num
        except Exception as e:
            LOGGER.error(f"Failed to count vectors in Milvus: {e}")
            sys.exit(1)

向量数据库目前仍在不断地完善中，不同版本之间的 API 并不互相兼容。

本文使用的是 2.2 版本。它虽然有主键，但并不像传统数据库的主键那样保证唯一性：重复插入相同主键的数据会得到多条记录。

以图搜图主要用到的 API 模块是 connections，大家参考时也主要基于它。
import sys,os
from glob import glob
from diskcache import Cache
from config import DEFAULT_TABLE
from logs import LOGGER


def do_upload(table_name, img_path, model, milvus_client):
    """Extract the feature of a single image and insert it into Milvus.

    The primary key is the image's basename cut at its first ``.``
    (``/a/photo.v2.jpg`` -> ``photo``).

    Args:
        table_name: Collection name; ``DEFAULT_TABLE`` is used when falsy.
        img_path: Path of the image file.
        model: Object exposing ``resnet50_extract_feat(path)``.
        milvus_client: Client exposing ``create_collection`` and ``insert``.

    Returns:
        Whatever ``milvus_client.insert`` returns (the inserted primary keys).
        On any error the failure is logged and the process exits with status 1.
    """
    try:
        collection = table_name or DEFAULT_TABLE
        milvus_client.create_collection(collection)
        embedding = model.resnet50_extract_feat(img_path)
        key = os.path.basename(img_path).split('.')[0]
        return milvus_client.insert(collection, [key], [embedding])
    except Exception as e:
        LOGGER.error(f"Error with upload : {e}")
        sys.exit(1)


def extract_features(img_dir, model):
    """Extract a feature vector for every image directly inside ``img_dir``.

    Matches ``*.png``, ``*.jpg`` and ``*.jpeg`` in lower and upper case. Images
    whose extraction raises are logged and skipped. Progress is written to a
    ``diskcache`` cache at ``./tmp`` under the keys ``'total'`` and ``'current'``.

    Args:
        img_dir: Directory to scan (not recursive).
        model: Object exposing ``resnet50_extract_feat(path)``.

    Returns:
        ``(feats, names)``: parallel lists, where each name is the image's
        basename cut at its first ``.``. If no image matches, or the cache cannot
        be used, the error is logged and the process exits with status 1.
    """
    patterns = ['/*.png', '/*.jpg', '/*.jpeg', '/*.PNG', '/*.JPG', '/*.JPEG']
    # Where glob matching is case-insensitive (e.g. Windows), '*.png' and '*.PNG'
    # return the same files. dict.fromkeys drops those repeats while keeping the
    # glob order, so no image is extracted (and later inserted) twice.
    img_list = list(dict.fromkeys(
        img_path for pattern in patterns for img_path in glob(img_dir + pattern)))
    try:
        if not img_list:
            raise FileNotFoundError(f"There is no image file in {img_dir} and endswith ['/*.png', '/*.jpg', '/*.jpeg', '/*.PNG', '/*.JPG', '/*.JPEG']")
        feats = []
        names = []
        total = len(img_list)
        # The context manager closes the cache when done instead of leaking it.
        with Cache('./tmp') as cache:
            cache['total'] = total
            for i, img_path in enumerate(img_list, start=1):
                try:
                    norm_feat = model.resnet50_extract_feat(img_path)
                    feats.append(norm_feat)
                    names.append(os.path.basename(img_path).split('.')[0])
                    cache['current'] = i
                    print(f"Extracting feature from image No. {i} , {total} images in total")
                except Exception as e:
                    # One unreadable image should not abort the whole batch.
                    LOGGER.error(f"Error with extracting feature from image:{img_path}, error: {e}")
                    continue
        return feats, names
    except Exception as e:
        LOGGER.error(f"Error with extracting feature from image {e}")
        sys.exit(1)


def do_load(table_name, image_dir, model, milvus_client):
    """Extract features for every image in ``image_dir`` and bulk-insert them.

    Args:
        table_name: Collection name; ``DEFAULT_TABLE`` is used when falsy.
        image_dir: Directory handed to ``extract_features``.
        model: Object exposing ``resnet50_extract_feat(path)``.
        milvus_client: Client exposing ``create_collection`` and ``insert``.

    Returns:
        The number of primary keys returned by ``milvus_client.insert``.
    """
    target = table_name or DEFAULT_TABLE
    milvus_client.create_collection(target)
    feats, keys = extract_features(image_dir, model)
    inserted = milvus_client.insert(target, keys, feats)
    return len(inserted)

def do_update(table_name, img_path, model, milvus_client):
    """Re-extract one image's feature and replace its stored vector.

    The primary key is the image's basename cut at its first ``.``, the same
    key ``do_upload`` uses.

    Args:
        table_name: Collection name; ``DEFAULT_TABLE`` is used when falsy.
        img_path: Path of the image file.
        model: Object exposing ``resnet50_extract_feat(path)``.
        milvus_client: Client exposing ``create_collection`` and ``update``.

    Returns:
        Whatever ``milvus_client.update`` returns. On any error the failure is
        logged and the process exits with status 1.
    """
    try:
        collection = table_name or DEFAULT_TABLE
        milvus_client.create_collection(collection)
        embedding = model.resnet50_extract_feat(img_path)
        key = os.path.basename(img_path).split('.')[0]
        return milvus_client.update(collection, [key], [embedding])
    except Exception as e:
        LOGGER.error(f"Error with update : {e}")
        sys.exit(1)

def do_search(table_name, img_path, top_k, model, milvus_client):
    """Find the ``top_k`` stored images most similar to the image at ``img_path``.

    Args:
        table_name: Collection name; ``DEFAULT_TABLE`` is used when falsy.
        img_path: Path of the query image.
        top_k: Number of neighbours to return.
        model: Object exposing ``resnet50_extract_feat(path)``.
        milvus_client: Client exposing ``search_vectors``.

    Returns:
        ``(paths, distances)``: the hit ids as strings and their distances, in
        the order the search returned them. On any error the failure is logged
        and the process exits with status 1.
    """
    try:
        collection = table_name or DEFAULT_TABLE
        query = model.resnet50_extract_feat(img_path)
        result = milvus_client.search_vectors(collection, [query], top_k)
        # A single query vector was sent, so its hits are in result[0].
        paths = [str(hit.id) for hit in result[0]]
        distances = [hit.distance for hit in result[0]]
        return paths, distances
    except Exception as e:
        LOGGER.error(f"Error with search : {e}")
        sys.exit(1)

def do_delete(table_name, milvus_client, path):
    """Delete the entities stored under ``path`` from the collection.

    Args:
        table_name: Collection name; ``DEFAULT_TABLE`` is used when falsy.
        milvus_client: Client exposing ``delete_vectors(collection_name, path)``.
        path: Primary key(s) to delete, passed through unchanged.

    Returns:
        None. On any error the failure is logged and the process exits with
        status 1.
    """
    try:
        if not table_name:
            table_name = DEFAULT_TABLE
        milvus_client.delete_vectors(table_name, path)
    except Exception as e:
        # Label the failure by its real operation so it is not confused with
        # do_search errors in the log.
        LOGGER.error(f"Error with delete : {e}")
        sys.exit(1)

def do_count(table_name, milvus_cli):
    """Return the entity count of the collection, or ``None`` if it does not exist.

    Args:
        table_name: Collection name; ``DEFAULT_TABLE`` is used when falsy.
        milvus_cli: Client exposing ``has_collection`` and ``count``.

    On any error the failure is logged and the process exits with status 1.
    """
    table_name = table_name or DEFAULT_TABLE
    try:
        exists = milvus_cli.has_collection(table_name)
        return milvus_cli.count(table_name) if exists else None
    except Exception as e:
        LOGGER.error(f"Error with count table {e}")
        sys.exit(1)


def do_drop(table_name, milvus_cli):
    """Drop the collection if it exists.

    Args:
        table_name: Collection name; ``DEFAULT_TABLE`` is used when falsy.
        milvus_cli: Client exposing ``has_collection`` and ``delete_collection``.

    Returns:
        The status from ``delete_collection``, or an explanatory message when
        the collection does not exist. On any error the failure is logged and
        the process exits with status 1.
    """
    table_name = table_name or DEFAULT_TABLE
    try:
        if milvus_cli.has_collection(table_name):
            return milvus_cli.delete_collection(table_name)
        return f"Milvus doesn't have a collection named {table_name}"
    except Exception as e:
        LOGGER.error(f"Error with drop table: {e}")
        sys.exit(1)

我在官方代码的基础上实现了向量的更新和删除，如有需要可以留言。

参考:

Milvus documentation

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值