向量数据库Milvus字符串查询

        因为项目需要,用到了向量数据库Milvus,刚开始都没有遇到问题,直到一个表的主键是字符串(VARCHAR),在查询时刚好要以该主键作为查询条件,此时会出现异常,特此记录一下。

        记住,字符串查询,构建表达式时要加上单引号,比如下面的'{face_id}',其实face_id本来就是一个字符串类型了,如果不加会出现如下的异常:
        # pymilvus.exceptions.MilvusException: <MilvusException: (code=65535, message=cannot parse expression: face_id == 2_0, error: invalid expression: face_id == 2_0)>

  具体看下面的代码(milvus_demo.py),其中exists()函数中构建查询表达式时做了特殊处理:

from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility, Partition
import time
from datetime import datetime
from typing import List


#用于测试字符串查询的demo


# MILVUS向量数据库地址
MILVUS_HOST_ONLINE = '127.0.0.1'
MILVUS_PORT = 19530

# 检索时返回的匹配内容条数
VECTOR_SEARCH_TOP_K = 100


class MilvusAvatar:
    # table_name 表名
    # partition_names  分区名,使用默认即可
    def __init__(self, mode, table_name, *, partition_names=["default"], threshold=1.1, client_timeout=3):
        self.table_name = table_name
        self.partition_names = partition_names
        
        self.host = MILVUS_HOST_ONLINE
        self.port = MILVUS_PORT
        self.client_timeout = client_timeout
        self.threshold = threshold
        self.sess: Collection = None
        self.partitions: List[Partition] = []
        self.top_k = VECTOR_SEARCH_TOP_K
        self.search_params = {"metric_type": "L2", "params": {"nprobe": 256}}
        self.create_params = {"metric_type": "L2", "index_type": "IVF_FLAT", "params": {"nlist": 2048}}
                
        self.init()

    @property
    def fields(self):
        fields = [
            FieldSchema(name='face_id', dtype=DataType.VARCHAR, max_length=640, is_primary=True, auto_id = False),
            FieldSchema(name='media_id', dtype=DataType.INT64),
            FieldSchema(name='file_path', dtype=DataType.VARCHAR, max_length=640),  #原图片保存路径
            FieldSchema(name='name', dtype=DataType.VARCHAR, max_length=640),  #姓名
            FieldSchema(name='count', dtype=DataType.INT64),  #数量
            FieldSchema(name='save_path', dtype=DataType.VARCHAR, max_length=640),  #现保存的绝对路径,包含文件名
            FieldSchema(name='embedding', dtype=DataType.FLOAT_VECTOR, dim=512) 
        ]
        return fields

    @property
    def output_fields(self):
        return ['face_id','media_id', 'file_path', 'name', 'count', 'save_path','embedding']

    def init(self):
        try:
            connections.connect(host=self.host, port=self.port)  # timeout=3 [cannot set]
            if utility.has_collection(self.table_name):
                self.sess = Collection(self.table_name)
                print(f'collection {self.table_name} exists')
            else:
                schema = CollectionSchema(self.fields)
                print(f'create collection {self.table_name} {schema}')
                self.sess = Collection(self.table_name, schema)
                self.sess.create_index(field_name="embedding", index_params=self.create_params)
            for index in self.partition_names:
                if not self.sess.has_partition(index):
                    self.sess.create_partition(index)
            self.partitions = [Partition(self.sess, index) for index in self.partition_names]
            print('partitions: %s', self.partition_names)
            self.sess.load()
        except Exception as e:
            print(e)

        
    def query_expr_sync(self, expr, output_fields=None, client_timeout=None):
        if client_timeout is None:
            client_timeout = self.client_timeout
        if not output_fields:
            output_fields = self.output_fields

        print(f"MilvusAvatar query_expr_sync:{expr},output_fields:{output_fields}")
        print(f"MilvusAvatar num_entities:{self.sess.num_entities}")
        if self.sess.num_entities == 0:
            return []
            
        return  self.sess.query(partition_names=self.partition_names, 
                                output_fields=output_fields,
                                expr=expr,
                                _async= False,
                                offset=0, 
                                limit=1000)
        
    # emb 为一个人脸特征向量
    def insert_avatar_sync(self, face_id, media_id, file_path, name, save_path, embedding):
        print(f'now insert_avatar {file_path}')
        print(f'now insert_avatar {file_path}')
                       
        data = [[] for _ in range(len(self.sess.schema))] 
        data[0].append(face_id)
        data[1].append(media_id)
        data[2].append(file_path)
        data[3].append(name)
        data[4].append(1)
        data[5].append(save_path)
        data[6].append(embedding)

        # 执行插入操作
        try:
            print('Inserting into Milvus...')
            self.partitions[0].insert(data=data)
            print(f'{file_path}')
            
            print(f"MilvusAvatar insert_avatar num_entities:{self.sess.num_entities}")
        except Exception as e:
            print(f'Milvus insert media_id:{media_id}, file_path:{file_path} failed: {e}')
            print(f'Milvus insert media_id:{media_id}, file_path:{file_path} failed: {e}')
            return False

        return True    
        
  
    # embs是一个数组
    def search_emb_sync(self, embs, expr='', top_k=None, client_timeout=None):
        if self.sess is None:
            return None
    
        if not top_k:
            top_k = self.top_k
        milvus_records = self.sess.search(data=embs, partition_names=self.kb_ids, anns_field="embedding",
                                          param=self.search_params, limit=top_k,
                                          output_fields=self.output_fields, expr=expr, timeout=client_timeout)
        print(f"milvus_records:{milvus_records}")
        return milvus_records   
        
       
    def exists(self,face_id):
        print(f"exists:{face_id},{type(face_id)}")
        # 记住,字符串查询,构建表达式时要加上单引号,比如下面的'{face_id}',其实face_id本来就是一个字符串类型了,如果不加会出现如下的异常:
        # pymilvus.exceptions.MilvusException: <MilvusException: (code=65535, message=cannot parse expression: face_id == 2_0, error: invalid expression: face_id == 2_0)>
        res = self.query_expr_sync(expr=f"face_id == '{face_id}'", output_fields=self.output_fields)
        #print(f"exists:{res},{len(res)}")
        if len(res) > 0: 
            return True
        
        return False
    
    
    # 修改照片数    
    def add_count(self, face_id):
        res = self.query_expr_sync(expr=f"face_id == '{face_id}'", output_fields=self.output_fields)
        self.sess.delete(expr=f"face_id == '{face_id}'")
        for result in res:
            media_id = result['media_id']
            file_path = result['file_path']
            name = result['name']
            count = int(result['count'])
            save_path = result['save_path']
            embedding = result['embedding']
            
            data = [[] for _ in range(len(self.sess.schema))] 
            data[0].append(face_id)
            data[1].append(media_id)
            data[2].append(file_path)
            data[3].append(name)
            data[4].append(count + 1)
            data[5].append(save_path)
            data[6].append(embedding)    
            print(f"add_count face_id:{face_id},file_path:{file_path}, count:{count}")
            
            # 执行插入操作
            try:
                print('Inserting into Milvus...')
                self.partitions[0].insert(data=data)
            except Exception as e:
                print(f'Milvus insert media_id:{media_id}, file_path:{file_path} failed: {e}')
                return False                
        

    def delete_collection(self):
        print("delete_collection")
        self.sess.release()
        utility.drop_collection(self.table_name)

    def delete_partition(self, partition_name):
        print("delete_partition")
        part = Partition(self.sess, partition_name)
        part.release()
        self.sess.drop_partition(partition_name)
        
    def query_all(self,limit=None):
        res = self.sess.query(partition_names = self.partition_names, 
                                output_fields = ["face_id","media_id", "name", "count", "save_path"],
                                expr= f"face_id != ''",
                                _async = False,
                                offset = 0, 
                                limit = None)
                                
        print(res)
        return res

if __name__ == "__main__":
    milvus_avatar= MilvusAvatar("local", "avatar", partition_names=["avatar"])
    media_id = 2
    index = 0
    face_id = f"{media_id}_{index}"
    file_path = "/home/data/bbh.jpg"
    save_path = "/home/data/bbh_avatar.jpg"
    embedding = [i/1000 for i in range(512)]
    milvus_avatar.insert_avatar_sync(face_id, media_id, file_path, "bbh", save_path, embedding)
    #result = milvus_avatar.query_all()
    #print(result)
    print(milvus_avatar.exists(face_id))
    

执行:python milvus_demo.py

如果是针对非字符串字段进行查询,则无需做上面的特殊处理。

  • 9
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值