最近接了一个图像搜索的项目,通过爬虫爬取订单并把图像进行入库,然后用户在查询历史记录的时候可快速找到之前的成交记录。讲一下历程:
最初打算直接用 VGG16 加 NumPy 来实现:用 VGG16 提取图像特征,再用 NumPy 进行存储和查询。但在爬取过程中发现数据量很大,要做快速查找和相似度计算需要很大的算力,耗时也比较长。查阅资料后发现了向量数据库 Milvus。目前网上关于 Milvus 的资料要么太老,要么太少,基本无法参考,所以只能对照官方文档自己慢慢摸索。现在已经跑通,库中的向量数量达到了 340 万条,单次查询耗时约 200 毫秒。
milvus的官方地址:Vector database - Milvus
文档地址:Milvus documentation
下面的代码是在官方示例代码的基础上调整而来的;官方示例没有提供单条特征的更新和删除,本人对此也做了补充完善。
import sys
from config import MILVUS_HOST, MILVUS_PORT, VECTOR_DIMENSION, METRIC_TYPE,DEFAULT_TABLE
from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility
from logs import LOGGER
class MilvusHelper:
    """
    MilvusHelper class to manage a Milvus collection of image embeddings.

    The constructor opens a pymilvus connection. Every other method takes a
    collection name, binds ``self.collection`` to it where needed and then
    performs one operation. If the underlying pymilvus call raises, the
    method logs the error and terminates the process with ``sys.exit(1)``.

    Args:
        host (`str`):
            Milvus server Host.
        port (`str|int`):
            Milvus server port.
    """

    def __init__(self, host=MILVUS_HOST, port=MILVUS_PORT):
        try:
            self.collection = None
            connections.connect(host=host, port=port)
            # Report the endpoint that was actually used, which may differ
            # from the configured defaults when the caller passes host/port.
            LOGGER.debug(f"Successfully connect to Milvus with IP:{host} and PORT:{port}")
        except Exception as e:
            LOGGER.error(f"Failed to connect Milvus: {e}")
            sys.exit(1)

    def set_collection(self, collection_name):
        """Bind ``self.collection`` to the collection named ``collection_name``."""
        try:
            self.collection = Collection(name=collection_name)
        except Exception as e:
            LOGGER.error(f"Failed to set collection in Milvus: {e}")
            sys.exit(1)

    def has_collection(self, collection_name):
        """Return whether Milvus has a collection named ``collection_name``."""
        try:
            return utility.has_collection(collection_name)
        except Exception as e:
            LOGGER.error(f"Failed to check collection in Milvus: {e}")
            sys.exit(1)

    def create_collection(self, collection_name):
        """
        Create the collection (and its index) if it does not exist yet.

        The schema has two fields: ``path`` (VARCHAR primary key, up to 500
        characters, supplied by the caller) and ``embedding`` (float vector of
        ``VECTOR_DIMENSION`` dimensions). When the collection already exists it
        is only bound to ``self.collection``.

        Returns:
            The string ``"OK"``.
        """
        try:
            if not self.has_collection(collection_name):
                field1 = FieldSchema(name='path', dtype=DataType.VARCHAR, description='path to image',
                                     max_length=500, is_primary=True, auto_id=False)
                field2 = FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR,
                                     description="image embedding vectors",
                                     dim=VECTOR_DIMENSION, is_primary=False)
                schema = CollectionSchema(fields=[field1, field2], description="collection description")
                self.collection = Collection(name=collection_name, schema=schema)
                self.create_index(collection_name)
                LOGGER.debug(f"Create Milvus collection: {collection_name}")
            else:
                self.set_collection(collection_name)
            return "OK"
        except Exception as e:
            LOGGER.error(f"Failed to create collection in Milvus: {e}")
            sys.exit(1)

    def insert(self, collection_name, path, vectors):
        """
        Batch insert vectors into the collection.

        Args:
            collection_name: Name of the target collection.
            path: List of primary-key values, one per vector.
            vectors: List of embedding vectors, in the same order as ``path``.

        Returns:
            The primary keys reported by the insert result.
        """
        try:
            # Column-oriented payload, in schema order: [path, embedding].
            data = [path, vectors]
            self.set_collection(collection_name)
            mr = self.collection.insert(data)
            ids = mr.primary_keys
            self.collection.load()
            LOGGER.debug(
                f"Insert vectors to Milvus in collection: {collection_name} with {len(vectors)} rows")
            return ids
        except Exception as e:
            LOGGER.error(f"Failed to load data to Milvus: {e}")
            sys.exit(1)

    def create_index(self, collection_name):
        """
        Create an IVF_SQ8 index (nlist=16384) on the ``embedding`` field.

        Returns:
            The status object returned by ``Collection.create_index``.
        """
        try:
            self.set_collection(collection_name)
            default_index = {"index_type": "IVF_SQ8", "metric_type": METRIC_TYPE, "params": {"nlist": 16384}}
            status = self.collection.create_index(field_name="embedding", index_params=default_index)
            # A falsy status code is treated as success.
            if not status.code:
                LOGGER.debug(
                    f"Successfully create index in collection:{collection_name} with param:{default_index}")
                return status
            else:
                raise Exception(status.message)
        except Exception as e:
            LOGGER.error(f"Failed to create index: {e}")
            sys.exit(1)

    def delete_collection(self, collection_name):
        """Drop the collection and return the string ``"ok"``."""
        try:
            self.set_collection(collection_name)
            self.collection.drop()
            LOGGER.debug("Successfully drop collection!")
            return "ok"
        except Exception as e:
            LOGGER.error(f"Failed to drop collection: {e}")
            sys.exit(1)

    def search_vectors(self, collection_name, vectors, top_k):
        """
        Search the ``embedding`` field for the nearest neighbours of ``vectors``.

        Args:
            collection_name: Name of the collection to search.
            vectors: List of query vectors.
            top_k: Maximum number of hits to return per query vector.

        Returns:
            The search result object returned by ``Collection.search``.
        """
        try:
            self.set_collection(collection_name)
            search_params = {"metric_type": METRIC_TYPE, "params": {"nprobe": 16}}
            res = self.collection.search(vectors, anns_field="embedding", param=search_params, limit=top_k)
            LOGGER.debug(f"Successfully search in collection: {res}")
            return res
        except Exception as e:
            LOGGER.error(f"Failed to search vectors in Milvus: {e}")
            sys.exit(1)

    def count(self, collection_name):
        """Return ``num_entities`` of the collection."""
        try:
            self.set_collection(collection_name)
            # NOTE(review): num_entities may lag behind very recent inserts
            # until they are flushed -- confirm against the Milvus version in use.
            num = self.collection.num_entities
            LOGGER.debug(f"Successfully get the num:{num} of the collection:{collection_name}")
            return num
        except Exception as e:
            LOGGER.error(f"Failed to count vectors in Milvus: {e}")
            sys.exit(1)
向量数据库目前还在不断的完善过程中,版本之间互不兼容。
本文使用的是 2.2 版本。该版本虽然有主键,但并不像传统数据库的主键那样保证唯一性:用相同的主键重复插入不会报错,也不会覆盖原有数据。
以图搜图主要用到的 API 模块是 connections,大家参考的话也主要基于此。
import sys,os
from glob import glob
from diskcache import Cache
from config import DEFAULT_TABLE
from logs import LOGGER
def do_upload(table_name, img_path, model, milvus_client):
    """
    Extract the feature of a single image and insert it into Milvus.

    The collection is created first if needed. The primary key is the image
    file name up to its first dot. Falls back to ``DEFAULT_TABLE`` when
    ``table_name`` is empty.

    Returns:
        Whatever ``milvus_client.insert`` returns (the inserted keys).

    On any exception the error is logged and the process exits with code 1.
    """
    try:
        collection = table_name or DEFAULT_TABLE
        milvus_client.create_collection(collection)
        embedding = model.resnet50_extract_feat(img_path)
        file_name = os.path.basename(img_path)
        key = file_name.split('.')[0]
        return milvus_client.insert(collection, [key], [embedding])
    except Exception as e:
        LOGGER.error(f"Error with upload : {e}")
        sys.exit(1)
def extract_features(img_dir, model):
    """
    Extract a feature vector for every png/jpg/jpeg image directly in ``img_dir``.

    Progress is written to a disk cache at ``./tmp`` under the keys
    ``'total'`` and ``'current'``. Images whose extraction raises are logged
    and skipped.

    Args:
        img_dir: Directory to scan (not recursive).
        model: Object exposing ``resnet50_extract_feat(path)``.

    Returns:
        A tuple ``(feats, names)``: the extracted features and, in the same
        order, each file's name up to its first dot.

    If no image is found, or another unexpected error occurs, the error is
    logged and the process exits with code 1.
    """
    patterns = ['/*.png', '/*.jpg', '/*.jpeg', '/*.PNG', '/*.JPG', '/*.JPEG']
    img_list = []
    for path in patterns:
        img_list.extend(glob(img_dir + path))
    # On case-insensitive filesystems the lower- and upper-case patterns match
    # the same files; drop the duplicates while keeping the discovery order.
    img_list = list(dict.fromkeys(img_list))
    try:
        if not img_list:
            raise FileNotFoundError(f"There is no image file in {img_dir} and endswith {patterns}")
        feats = []
        names = []
        total = len(img_list)
        # Use the cache as a context manager so its handle is closed on exit.
        with Cache('./tmp') as cache:
            cache['total'] = total
            for i, img_path in enumerate(img_list):
                try:
                    norm_feat = model.resnet50_extract_feat(img_path)
                    feats.append(norm_feat)
                    names.append(os.path.basename(img_path).split('.')[0])
                    cache['current'] = i + 1
                    print(f"Extracting feature from image No. {i + 1} , {total} images in total")
                except Exception as e:
                    LOGGER.error(f"Error with extracting feature from image:{img_path}, error: {e}")
                    continue
        return feats, names
    except Exception as e:
        LOGGER.error(f"Error with extracting feature from image {e}")
        sys.exit(1)
def do_load(table_name, image_dir, model, milvus_client):
    """
    Bulk-load every image found in ``image_dir`` into Milvus.

    Creates the collection if needed (``DEFAULT_TABLE`` when ``table_name``
    is empty), extracts the features with ``extract_features`` and inserts
    them in one batch.

    Returns:
        The number of keys returned by the insert.
    """
    collection = table_name or DEFAULT_TABLE
    milvus_client.create_collection(collection)
    embeddings, names = extract_features(image_dir, model)
    inserted = milvus_client.insert(collection, names, embeddings)
    return len(inserted)
def do_update(table_name,img_path, model, milvus_client):
    """
    Re-extract the feature of one image and pass it to ``milvus_client.update``.

    The collection is created first if needed. The key is the image file name
    up to its first dot. Falls back to ``DEFAULT_TABLE`` when ``table_name``
    is empty.

    Returns:
        Whatever ``milvus_client.update`` returns.

    On any exception the error is logged and the process exits with code 1.
    """
    try:
        collection = table_name or DEFAULT_TABLE
        milvus_client.create_collection(collection)
        embedding = model.resnet50_extract_feat(img_path)
        file_name = os.path.basename(img_path)
        key = file_name.split('.')[0]
        return milvus_client.update(collection, [key], [embedding])
    except Exception as e:
        LOGGER.error(f"Error with update : {e}")
        sys.exit(1)
def do_search(table_name, img_path, top_k, model, milvus_client):
    """
    Find the ``top_k`` stored entries most similar to the image at ``img_path``.

    Falls back to ``DEFAULT_TABLE`` when ``table_name`` is empty.

    Returns:
        A tuple ``(paths, distances)`` built from the hits of the single
        query: each hit's ``id`` as a string, and each hit's ``distance``.

    On any exception the error is logged and the process exits with code 1.
    """
    try:
        collection = table_name or DEFAULT_TABLE
        query = model.resnet50_extract_feat(img_path)
        results = milvus_client.search_vectors(collection, [query], top_k)
        # Only one query vector was sent, so only the first hit list is used.
        paths = [str(hit.id) for hit in results[0]]
        distances = [hit.distance for hit in results[0]]
        return paths, distances
    except Exception as e:
        LOGGER.error(f"Error with search : {e}")
        sys.exit(1)
def do_delete(table_name, milvus_client, path):
    """
    Delete the entry identified by ``path`` via ``milvus_client.delete_vectors``.

    Falls back to ``DEFAULT_TABLE`` when ``table_name`` is empty. Returns
    ``None``. On any exception the error is logged and the process exits
    with code 1.
    """
    try:
        if not table_name:
            table_name = DEFAULT_TABLE
        milvus_client.delete_vectors(table_name, path)
    except Exception as e:
        # The message previously said "search" (copy-paste from do_search).
        LOGGER.error(f"Error with delete : {e}")
        sys.exit(1)
def do_count(table_name, milvus_cli):
    """
    Return the number of entities in the collection.

    Falls back to ``DEFAULT_TABLE`` when ``table_name`` is empty. Returns
    ``None`` when the collection does not exist. On any exception the error
    is logged and the process exits with code 1.
    """
    collection = table_name if table_name else DEFAULT_TABLE
    try:
        if milvus_cli.has_collection(collection):
            return milvus_cli.count(collection)
        return None
    except Exception as e:
        LOGGER.error(f"Error with count table {e}")
        sys.exit(1)
def do_drop(table_name, milvus_cli):
    """
    Drop the collection if it exists.

    Falls back to ``DEFAULT_TABLE`` when ``table_name`` is empty.

    Returns:
        The result of ``milvus_cli.delete_collection``, or an explanatory
        message string when the collection does not exist.

    On any exception the error is logged and the process exits with code 1.
    """
    collection = table_name if table_name else DEFAULT_TABLE
    try:
        if milvus_cli.has_collection(collection):
            return milvus_cli.delete_collection(collection)
        return f"Milvus doesn't have a collection named {collection}"
    except Exception as e:
        LOGGER.error(f"Error with drop table: {e}")
        sys.exit(1)
本人在官方代码的基础上实现了向量的更新和删除(即上面 do_update、do_delete 所调用的 update 和 delete_vectors 方法,本文给出的 MilvusHelper 中未包含这部分代码),如有需要可留言。
参考: