Extract semantic vectors for the corpus and build the Milvus index

import time

import os
import sys
from tqdm import tqdm
import numpy as np
import paddle
from paddle import inference
from paddlenlp.transformers import AutoTokenizer
from paddlenlp.data import Tuple, Pad

sys.path.append('.')

def convert_example(example,
                    tokenizer,
                    max_seq_length=512,
                    pad_to_max_seq_len=False):
    """Tokenize each {id: text} entry and flatten the encodings into a
    [input_ids, token_type_ids, ...] list, in key order."""
    result = []
    for key, text in example.items():
        encoded_inputs = tokenizer(text=text,
                                   max_seq_len=max_seq_length,
                                   pad_to_max_seq_len=pad_to_max_seq_len)
        input_ids = encoded_inputs["input_ids"]
        token_type_ids = encoded_inputs["token_type_ids"]
        result += [input_ids, token_type_ids]
    return result
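
As a quick smoke test (a minimal sketch; the sample sentence and the `sample_*` names are illustrative, not part of the original pipeline), a single {id: text} dict flattens into one padded [input_ids, token_type_ids] pair:

# Illustrative check only; sample text and variable names are invented.
sample_tokenizer = AutoTokenizer.from_pretrained('rocketqa-zh-base-query-encoder')
sample_ids, sample_segments = convert_example({0: '深度学习语义检索综述'},
                                              sample_tokenizer,
                                              max_seq_length=64,
                                              pad_to_max_seq_len=True)
print(len(sample_ids), len(sample_segments))  # both padded to 64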

model_dir='./output/yysy/'
corpus_file='./datasets/yysy/milvus/milvus_data_s.csv'
max_seq_length=64
batch_size=64
device='gpu'
cpu_threads=8
model_name_or_path='rocketqa-zh-base-query-encoder'

class Predictor(object):
    def __init__(self,
                 model_dir,
                 device="gpu",
                 max_seq_length=128,
                 batch_size=32,
                 use_tensorrt=False,
                 precision="fp32",
                 cpu_threads=10,
                 enable_mkldnn=False):
        self.max_seq_length = max_seq_length
        self.batch_size = batch_size
        model_file = os.path.join(model_dir, "inference.get_pooled_embedding.pdmodel")
        params_file = os.path.join(model_dir, "inference.get_pooled_embedding.pdiparams")
        if not os.path.exists(model_file):
            raise ValueError("cannot find model file: {}".format(model_file))
        if not os.path.exists(params_file):
            raise ValueError("cannot find params file: {}".format(params_file))
        config = paddle.inference.Config(model_file, params_file)
        if device == "gpu":
            config.enable_use_gpu(100, 0)
            precision_map = {
                "fp16": inference.PrecisionType.Half,
                "fp32": inference.PrecisionType.Float32,
                "int8": inference.PrecisionType.Int8
            }
            precision_mode = precision_map[precision]
            if use_tensorrt:
                config.enable_tensorrt_engine(max_batch_size=batch_size,
                                              min_subgraph_size=30,
                                              precision_mode=precision_mode)
        elif device == "cpu":
            config.disable_gpu()
            if enable_mkldnn:
                config.set_mkldnn_cache_capacity(10)
                config.enable_mkldnn()
            config.set_cpu_math_library_num_threads(cpu_threads)
        elif device == "xpu":
            config.enable_xpu(100)
        config.switch_use_feed_fetch_ops(False)
        self.predictor = paddle.inference.create_predictor(config)
        self.input_handles = [
            self.predictor.get_input_handle(name)
            for name in self.predictor.get_input_names()
        ]
        self.output_handle = self.predictor.get_output_handle(
            self.predictor.get_output_names()[0])
    def predict(self, data, tokenizer):
        # Pad every sample in a batch to the same length per field.
        batchify_fn = lambda samples, fn=Tuple(
            Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype="int64"),  # input_ids
            Pad(axis=0, pad_val=tokenizer.pad_token_type_id, dtype="int64"
                ),  # token_type_ids
        ): fn(samples)
        all_embeddings = []
        examples = []
        for idx, text in enumerate(tqdm(data)):
            input_ids, segment_ids = convert_example(
                text,
                tokenizer,
                max_seq_length=self.max_seq_length,
                pad_to_max_seq_len=True)
            examples.append((input_ids, segment_ids))
            if len(examples) >= self.batch_size:
                # Run the exported get_pooled_embedding graph on a full batch.
                input_ids, segment_ids = batchify_fn(examples)
                self.input_handles[0].copy_from_cpu(input_ids)
                self.input_handles[1].copy_from_cpu(segment_ids)
                self.predictor.run()
                logits = self.output_handle.copy_to_cpu()
                all_embeddings.append(logits)
                examples = []
        if len(examples) > 0:
            # Flush the final, possibly smaller, batch.
            input_ids, segment_ids = batchify_fn(examples)
            self.input_handles[0].copy_from_cpu(input_ids)
            self.input_handles[1].copy_from_cpu(segment_ids)
            self.predictor.run()
            logits = self.output_handle.copy_to_cpu()
            all_embeddings.append(logits)
        all_embeddings = np.concatenate(all_embeddings, axis=0)
        np.save('yysy_corpus_embedding', all_embeddings)

def read_text(file_path):
    # Map line index -> passage text; line order defines the corpus ids.
    id2corpus = {}
    with open(file_path) as file:
        for idx, data in enumerate(file.readlines()):
            id2corpus[idx] = data.strip()
    return id2corpus

predictor = Predictor(model_dir, device, max_seq_length, batch_size)

tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
id2corpus = read_text(corpus_file)

corpus_list = [{idx: text} for idx, text in id2corpus.items()]  # texts used to build the index

predictor.predict(corpus_list, tokenizer)
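
predict saves the pooled embeddings to yysy_corpus_embedding.npy. A small sanity check (a sketch; it assumes the exported model emits 256-dimensional vectors, matching the data_dim set below):

corpus_emb = np.load('yysy_corpus_embedding.npy')
print(corpus_emb.shape)  # expected: (len(corpus_list), 256)
assert corpus_emb.shape[0] == len(corpus_list)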

MILVUS_HOST = '127.0.0.1'
MILVUS_PORT = 19530
data_dim = 256
top_k = 20
collection_name = 'literature_search'
partition_tag = 'partition_1'
embedding_name = 'embeddings'
index_config = {
    "index_type": "IVF_FLAT",
    "metric_type": "L2",
    "params": {
        "nlist": 1000
    },
}
search_params = {
    "metric_type": "L2",
    "params": {
        "nprobe": top_k
    },
}
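
With IVF_FLAT, nlist sets how many clusters the vectors are partitioned into at build time, and nprobe sets how many of those clusters are scanned per query, so a larger nprobe trades latency for recall. For a small corpus, or when exact recall matters more than speed, a FLAT or HNSW index is a common alternative; the parameter values below are illustrative only (check the docs for your Milvus version):

# Illustrative alternatives to IVF_FLAT (parameter values are examples, not tuned):
flat_index_config = {"index_type": "FLAT", "metric_type": "L2", "params": {}}
hnsw_index_config = {
    "index_type": "HNSW",
    "metric_type": "L2",
    "params": {"M": 8, "efConstruction": 64},
}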

from pymilvus import (
    connections,
    utility,
    FieldSchema,
    CollectionSchema,
    DataType,
    Collection,
)

fmt = "\n=== {:30} ===\n"
text_max_len = 1000
fields = [
    FieldSchema(name="pk",
                dtype=DataType.INT64,
                is_primary=True,  # primary key
                auto_id=False),   # ids are supplied by the caller
    FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=text_max_len),  # raw passage
    FieldSchema(name="embeddings", dtype=DataType.FLOAT_VECTOR, dim=data_dim)  # 256-d vector
]
schema = CollectionSchema(fields, "Neural Search Index")

class VecToMilvus():  # write semantic vectors into Milvus
    def __init__(self):
        print(fmt.format("start connecting to Milvus"))
        connections.connect("default", host=MILVUS_HOST, port=MILVUS_PORT)
        self.collection = None
    def has_collection(self, collection_name):
        try:
            has = utility.has_collection(collection_name)
            print(f"Does collection {collection_name} exist in Milvus: {has}")
            return has
        except Exception as e:
            print("Milvus has_table error:", e)
    def create_collection(self, collection_name):
        try:
            print(fmt.format("Create collection {}".format(collection_name)))
            self.collection = Collection(collection_name,
                                         schema,
                                         consistency_level="Strong")
        except Exception as e:
            print("Milvus create collection error:", e)
    def drop_collection(self, collection_name):
        try:
            utility.drop_collection(collection_name)
        except Exception as e:
            print("Milvus delete collection error:", e)
    def create_index(self, index_name):
        try:
            print(fmt.format("Start Creating index"))
            self.collection.create_index(index_name, index_config)
            print(fmt.format("Start loading"))
            self.collection.load()
        except Exception as e:
            print("Milvus create index error:", e)
    def has_partition(self, partition_tag):
        try:
            result = self.collection.has_partition(partition_tag)
            return result
        except Exception as e:
            print("Milvus has partition error: ", e)
    def create_partition(self, partition_tag):
        try:
            self.collection.create_partition(partition_tag)
            print('create partition {} successfully'.format(partition_tag))
        except Exception as e:
            print('Milvus create partition error: ', e)
    def insert(self, entities, collection_name, index_name, partition_tag=None):
        try:
            if not self.has_collection(collection_name):
                self.create_collection(collection_name)
                self.create_index(index_name)
            else:
                self.collection = Collection(collection_name)
            if (partition_tag
                    is not None) and (not self.has_partition(partition_tag)):
                self.create_partition(partition_tag)
            self.collection.insert(entities, partition_name=partition_tag)
            print(
                f"Number of entities in Milvus: {self.collection.num_entities}"
            )  # check num_entities
        except Exception as e:
            print("Milvus insert error:", e)

class RecallByMilvus():  # recall nearest vectors from Milvus
    def __init__(self):
        print(fmt.format("start connecting to Milvus"))
        connections.connect("default", host=MILVUS_HOST, port=MILVUS_PORT)
        self.collection = None
    def get_collection(self, collection_name):
        try:
            print(fmt.format("Connect collection {}".format(collection_name)))
            self.collection = Collection(collection_name)
        except Exception as e:
            print("Milvus create collection error:", e)
    def search(self,
               vectors,
               embedding_name,
               collection_name,
               partition_names=[],
               output_fields=[]):
        try:
            self.get_collection(collection_name)
            result = self.collection.search(vectors,
                                            embedding_name,
                                            search_params,
                                            limit=top_k,
                                            partition_names=partition_names,
                                            output_fields=output_fields)
            return result
        except Exception as e:
            print('Milvus recall error: ', e)

data_path='./datasets/yysy/milvus/milvus_data_s.csv'

embedding_path='./yysy_corpus_embedding.npy'
index=18
batch_size=5000

def read_text(file_path):
    # Redefined here to return a plain list of passages (one per line),
    # rather than the id -> text dict used above.
    corpus = []
    with open(file_path) as file:
        for data in file.readlines():
            corpus.append(data.strip())
    return corpus

corpus_list_embed=read_text(data_path)

print(corpus_list_embed[:5])  # peek at the first few passages

embeddings = np.load(embedding_path)

embedding_ids = list(range(embeddings.shape[0]))  # one id per embedding row

client = VecToMilvus()

client.has_collection(collection_name)

client.drop_collection(collection_name)

data_size = len(embedding_ids)

x = [corpus_list_embed[j][:1000] for j in range(10000, 15000)]  # truncate a sample of passages to 1000 chars

print(max(len(i) for i in x))  # longest truncated passage; must fit the VARCHAR limit


for i in range(0, data_size, batch_size):  # i: 0, 5000, 10000, ...
    cur_end = i + batch_size
    if cur_end > data_size:  # keep the last slice in bounds
        cur_end = data_size
    batch_emb = embeddings[np.arange(i, cur_end)]  # one batch of embedding vectors
    entities = [
        [j for j in range(i, cur_end, 1)],  # primary keys
        [corpus_list_embed[j][:text_max_len - 1] for j in range(i, cur_end, 1)],  # passages, truncated to the VARCHAR limit
        batch_emb  # embeddings for this batch
    ]
    client.insert(collection_name=collection_name,
                  entities=entities,
                  index_name=embedding_name,
                  partition_tag=partition_tag)
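
Depending on the Milvus/pymilvus version, freshly inserted rows can sit in a growing segment for a while; an optional, version-dependent step (a sketch, not part of the original code) is to flush explicitly before searching:

client.collection.flush()  # seal growing segments so counts and searches see the new rows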

recall_client = RecallByMilvus()

embeddings = embeddings[np.arange(index, index + 1)]  # single query vector, kept 2-D for search

time_start = time.time()  # start timing
result = recall_client.search(embeddings,
                              embedding_name,
                              collection_name,
                              partition_names=[partition_tag],
                              output_fields=['pk', 'text'])
time_end = time.time()  # stop timing

sum_t = time_end - time_start
print('time cost', sum_t, 's')

for hits in result:
    for hit in hits:
        print(f"hit: {hit}, text field: {hit.entity.get('text')}")
