废话不多说,直接上代码(写的比较糙,勿喷勿喷)
# coding:utf-8
import os
import json
import time
import faiss
import numpy as np
import pandas as pd
from text2vec import SentenceModel
class VecRetrieval(object):
"""使用faiss进行向量检索"""
def __init__(self, model_path, target_vec_file, dim, num, embedding_store):
"""
:param model_path:向量模型路径(建议使用bge模型)
:param target_vec_file:需要写入向量库文本
:param dim:向量维度
:param num:返回前num个相似文本
:param embedding_store:向量存储路径
"""
self.model_path = model_path
self.target_vec_file = target_vec_file
self.dim = dim
self.num = num
self.model = SentenceModel(self.model_path)
self.embedding_store = embedding_store
def read_txt(self):
"""
读取需要存入向量库的文件
:return:文本列表
"""
with open(self.target_vec_file, "r", encoding="utf-8") as f:
texts = f.readlines()
texts = [text.strip() for text in texts]
return texts
def vec_storage(self):
"""
文本转为向量工具
:return: 向量库对象
"""
sentence_embedding = self.model.encode(self.read_txt())
index = faiss.IndexFlatL2(self.dim)
index.add(sentence_embedding)
os.makedirs(self.embedding_store, exist_ok=True)
faiss.write_index(index, os.path.join(self.embedding_store, "embedding_index.index"))
return index
def vec_retrieval(self, query, rate=0.5):
"""
向量检索;
由于faiss向量数据库检索返回结果不为空,因此设置rate参数,计算规则比较暴力,直接取前百分之rate*100有结果,剩余的返回NULL
:param query: 需要检索的文本,可以是字符串,也可以是列表
:param rate:取值0-1之间,值越大,匹配到的结果越多,反之越少。匹配不到的返回NULL,自定义设置
:return:检索到的相似文本
"""
if isinstance(query, str):
query_vec = self.model.encode([query], max_seq_length=64)
if os.path.exists(os.path.join(self.embedding_store, "embedding_index.index")):
index = faiss.read_index(os.path.join(self.embedding_store, "embedding_index.index"))
else:
index = self.vec_storage()
_, I = index.search(query_vec, self.num)
similarity_text = self.read_txt()[I[0][0]]
elif isinstance(query, list):
query_vec = self.model.encode(query, max_seq_length=64)
if os.path.exists(os.path.join(self.embedding_store, "embedding_index.index")):
start_time = time.time()
index = faiss.read_index(os.path.join(self.embedding_store, "embedding_index.index"))
spend_time = time.time() - start_time
print(f'加载向量数据库耗时:{spend_time} s')
else:
start_time = time.time()
index = self.vec_storage()
spend_time = time.time() - start_time
print(f'文本转向量耗时:{spend_time} s')
D, I = index.search(query_vec, self.num)
similarity_text = []
texts = self.read_txt()
scores = [i for d in D for i in d]
scores.sort()
for ids, i in enumerate(I):
if D[ids][0] < scores[int(len(scores) * rate)]:
similarity_text.append(texts[i[0]])
else:
similarity_text.append("NULL")
else:
similarity_text = None
return similarity_text
最有用的是可以将结构化的数据直接存在向量库中进行查询,不依托于langchain框架,简单高效