训练Faiss
加载调用
bge_eval_main.py
"""
BGE 评估的主函数:
输入一个query
输出两段各10个候选集
"""
from embeddings_service import question_embedding
from utils.milvusConnectionManager import milvus_handler
import sys
sys.path.append('storage/wangyongpeng/FlagEmbedding/')
import re
import gradio as gr
import uvicorn
from fastapi import Body, FastAPI,BackgroundTasks,Request
from fastapi.middleware.cors import CORSMiddleware
import pydantic
from pydantic import BaseModel
from tqdm import tqdm
# # from langchain.embeddings.huggingface import HuggingFaceEmbeddings
# from langchain.embeddings import HuggingFaceEmbeddings
from langchain_community.embeddings import HuggingFaceEmbeddings
import torch
from langchain_core.documents import Document
import json
import faiss
from langchain_community.vectorstores import FAISS
from langchain_community.docstore.in_memory import InMemoryDocstore
from fastapi import FastAPI, Query
app = FastAPI()

# Initialise the embedding model; inference is pinned to the CPU.
embedding_function = HuggingFaceEmbeddings(
    model_name='/home/jovyan/storage/wangyongpeng/FlagEmbedding/model_output_V1_release',
    model_kwargs={'device': "cpu"},
)

# Load the persisted FAISS vector store from disk.
# NOTE(review): allow_dangerous_deserialization=True opts in to loading a
# serialized index from disk; only point this at an index directory you trust.
vector_store = FAISS.load_local(
    "/home/jovyan/storage/wangyongpeng/FlagEmbedding/faiss_db",
    embedding_function,
    allow_dangerous_deserialization=True,
)
# vector_store.deserialize_from_bytes(allow_dangerous_deserialization=True)
def old_bge_result(query):
    """Recall candidates through the original BGE embedding service.

    Embeds ``query`` with ``question_embedding`` and searches the
    ``saic_lz_database`` Milvus collection, restricted to two store names,
    for at most 10 hits.

    Args:
        query: Query text to embed and search for.

    Returns:
        Whatever ``milvus_handler.client.search`` returns for the search.
    """
    # Embed the query with the original service.
    query_vector = question_embedding(query)

    # Only search the two stores under evaluation.
    store_filter = "store_name in ['saic_chanpinxiangmuhuihuiyiziliaoceshichongfuaigc_test1','saic_chanpinxiangmuhuihuiyijiyaoaigc_test1']"

    # `doc` is intended as the page content; the remaining fields are
    # intended to be carried along as metadata.
    wanted_fields = ["doc", 'source', 'paragraph', 'pre_paragraph', 'next_paragraph']

    hits = milvus_handler.client.search(
        collection_name="saic_lz_database",
        data=query_vector,
        output_fields=wanted_fields,
        limit=10,
        filter=store_filter,
    )
    print(len(hits))
    return hits
def new_bge_result(query):
    """Recall candidates from the FAISS store built with the new BGE model.

    Args:
        query: Query text to search for.

    Returns:
        The result of ``vector_store.similarity_search_with_score`` for the
        10 best matches.
    """
    # The vector store receives the raw query text and embeds it itself, so
    # no separate embedding call is needed here. The earlier
    # embed_documents(query) call passed a bare string where a list of texts
    # is expected and its result was never used.
    return vector_store.similarity_search_with_score(query, k=10)
"""
======================================================================================================================
# Fetch every stored chunk that belongs to one source document.
# `s` is matched against the `source` field, within the two configured
# stores; at most 16384 rows are requested.
def query_milvus(s):
    collection = "saic_lz_database"
    store_clause = "store_name in ['saic_chanpinxiangmuhuihuiyiziliaoceshichongfuaigc_test1','saic_chanpinxiangmuhuihuiyijiyaoaigc_test1']"
    condition = f"{store_clause} and source == '{s}'"
    return milvus_handler.query(
        collection_name=collection,
        filter=condition,
        output_fields=['doc', 'source', 'paragraph', 'pre_paragraph', 'next_paragraph'],
        limit=16384,
    )
"""
[{'source': '1715333131835.2021年03月12日第06次产品项目会会议纪要.pdf', 'id': 451024891781748623},
{'source': '1715333131847.2021年04月30日第09次产品项目会会议纪要.pdf', 'id': 451024891781748624},
{'source': '1715333131863.2021年03月24日第07次产品项目会会议纪要.pdf', 'id': 451024891781748625}
]
"""
# Collect every chunk of every source document in the two configured stores.
# Source names are listed from the `saic_source` collection, then the chunks
# of each source are fetched through query_milvus. Returns one flat list.
# (Writing the chunks to a local txt file is currently disabled, see below.)
def get_data_from_milvus():
    source_rows = milvus_handler.query(
        collection_name="saic_source",
        filter="store_name in ['saic_chanpinxiangmuhuihuiyiziliaoceshichongfuaigc_test1','saic_chanpinxiangmuhuihuiyijiyaoaigc_test1']",
        output_fields=['source'],
        limit=16384,
    )
    print(len(source_rows))
    print(source_rows[:3])

    source_names = [row.get("source") for row in source_rows]
    print(source_names[:3])

    all_chunks = []
    # all_chunks = query_milvus(source_names[0])
    for name in tqdm(source_names):
        all_chunks.extend(query_milvus(name))
    # with open("/home/jovyan/storage/wangyongpeng/FlagEmbedding/BGE-eval/row.txt", 'a', encoding='utf-8') as f:
    #     for s in tqdm(source_list):
    #         chunks = query_milvus(s)
    #         f.write("\n".join(str(i) for i in chunks))
    return all_chunks
# Turn Milvus rows into Document objects.
# Each row's "doc" value becomes page_content and every other key goes into
# metadata. The metadata dict is a fresh copy, so the rows passed in are not
# modified. Raises KeyError when a row has no "doc" key.
def convert_chunks2_documents(milvus_list):
    documents = []
    for row in milvus_list:
        # Index (rather than .get) so a row without "doc" fails with KeyError.
        page_content = row["doc"]
        metadata = {key: value for key, value in row.items() if key != "doc"}
        documents.append(Document(page_content=page_content, metadata=metadata))
    return documents
#=======================================================================================
# def put_txt_to_faiss():
# chunks = get_data_from_milvus()
# documents = convert_chunks2_documents(chunks)
# db_filepath = "/home/jovyan/storage/wangyongpeng/FlagEmbedding/faiss_db"
# # embedding_function = HuggingFaceEmbeddings(model_name='/home/jovyan/storage/wangyongpeng/FlagEmbedding/model_output_V1_release')
# # model_kwargs={'device': "cuda" }
# print(documents[:2])
# # 初始化 # 添加文档
# vector_store = FAISS.from_documents(documents, embedding_function)
# vector_store.save_local(db_filepath)
# # related_docs_with_score = vector_store.similarity_search_with_score(query, k=self.top_k)
# Dump every chunk of the two configured stores to row-j.json.
# The chunks are gathered by get_data_from_milvus() rather than repeating its
# source listing and per-source loop here, then written as indented UTF-8
# JSON with non-ASCII characters kept as-is. Returns None.
def get_data_from_milvus2():
    chunks_list = get_data_from_milvus()
    with open("/home/jovyan/storage/wangyongpeng/FlagEmbedding/BGE-eval/row-j.json", 'w', encoding='utf-8') as f:
        json.dump(chunks_list, f, ensure_ascii=False, indent=4)
======================================================================================================================
"""
# Route that reads the search text from the "query" query-string parameter
# and returns the old and new recall results side by side in one string.
# NOTE(review): both recall helpers are called synchronously from this async
# handler, so they run on the event loop while they work.
@app.get("/hello")
async def hello(query: str = Query(default="", title="Name")):
    # Recall with the original embedding model.
    legacy_hits = old_bge_result(query)
    # Recall with the new embedding model.
    fresh_hits = new_bge_result(query)

    print(f"old:\n{len(legacy_hits)}\n")
    print(f"new:\n{len(fresh_hits)}\n")

    return f"\n=================================================旧的==========================================================================={legacy_hits}\n\n\n\n\n\n\n==================================================新的============================================================================{fresh_hits}"
if __name__ == "__main__":
    # Serve the FastAPI app on all interfaces, port 9999.
    uvicorn.run(
        'bge_eval_main:app',
        host="0.0.0.0",
        port=9999,
    )
# query = "AS32"
embeddings_service.py
from typing import List
import json
import requests
from configs.env_config import EMBEDDING_SERVICE_URL,EMBEDDING_URL
# Endpoint URLs built on the service base URL imported from configs.env_config.
QUESTION_EMBEDDING_URL = EMBEDDING_SERVICE_URL + "/question_vectorize"
ANSWER_SOURCE_URL = EMBEDDING_SERVICE_URL + "/answer_source"
# EMBEDDING_URL is imported from configs.env_config; the derived form below is disabled.
#EMBEDDING_URL = EMBEDDING_SERVICE_URL + "/vectorize"
class EmbeddingsService():
    """HTTP client for the remote embedding service.

    Provides ``embed_documents`` and ``embed_query`` for vectorising text,
    plus the static ``answer_source`` helper that posts answers and sources
    to ``ANSWER_SOURCE_URL``.
    """

    # How many times embed_query tries before giving up.
    _MAX_QUERY_ATTEMPTS = 10

    def __init__(self, url):
        """Create a client.

        Args:
            url: Accepted but not used; requests always go to
                ``EMBEDDING_URL``.
        """
        # NOTE(review): `url` is ignored in favour of the configured
        # EMBEDDING_URL - confirm whether that is intentional.
        self.url = EMBEDDING_URL

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Return the embeddings of ``texts`` from the embedding endpoint.

        Newlines in each text are replaced by spaces before the texts are
        POSTed as a JSON list to ``self.url``.

        Args:
            texts: Texts to embed.

        Returns:
            The ``embeddings`` field of the JSON response.
        """
        text_list = [text.replace("\n", " ") for text in texts]
        headers = {'Content-Type': 'application/json'}
        response = requests.request(
            "POST",
            self.url,
            headers=headers,
            data=json.dumps(text_list),
            timeout=3600,
        )
        return response.json()['embeddings']

    @staticmethod
    def answer_source(answer_list: List[str], source_list: List[str]) -> List[List[float]]:
        """POST answers and sources to ``ANSWER_SOURCE_URL``.

        Args:
            answer_list: Answer texts.
            source_list: Source texts.

        Returns:
            The ``indices`` field of the JSON response.
        """
        # NOTE(review): the annotated return type mirrors embed_documents;
        # confirm it matches what the service puts in `indices`.
        headers = {'Content-Type': 'application/json'}
        data = {
            "answer_list": answer_list,
            "source_list": source_list,
        }
        response = requests.request(
            "POST",
            ANSWER_SOURCE_URL,
            headers=headers,
            data=json.dumps(data),
            timeout=3600,
        )
        return response.json()['indices']

    def embed_query(self, text: str) -> List[float]:
        """Return the embedding of a single text, retrying on failure.

        ``embed_documents`` is attempted up to 10 times.

        Args:
            text: Text to embed.

        Returns:
            The first embedding returned for ``[text]``.

        Raises:
            RuntimeError: If every attempt failed; the last underlying error
                is attached as the cause.
        """
        last_error = None
        for _ in range(self._MAX_QUERY_ATTEMPTS):
            try:
                return self.embed_documents([text])[0]
            except Exception as err:
                # Remember the failure so it can be reported if no attempt
                # succeeds, instead of surfacing an unrelated error later.
                last_error = err
        raise RuntimeError(
            f"embedding request failed after {self._MAX_QUERY_ATTEMPTS} attempts"
        ) from last_error
def question_embedding(text):
    """Return the question embedding of ``text`` from the embedding service.

    ``str(text)`` is JSON-encoded and POSTed to ``QUESTION_EMBEDDING_URL``.

    Args:
        text: The question to embed; converted with ``str`` before sending.

    Returns:
        The ``embeddings`` field of the JSON response.
    """
    payload = json.dumps(str(text))
    reply = requests.request(
        "POST",
        QUESTION_EMBEDDING_URL,
        headers={'Content-Type': 'application/json'},
        data=payload,
        timeout=3600,
    )
    return reply.json()['embeddings']
configs.model_config.py
私密内容,不对外部可见