BGE微调评测

训练Faiss

加载调用

bge_eval_main.py

"""
BGE 评估的主函数:
输入一个query
输出两段各10个候选集
"""
from embeddings_service import question_embedding
from  utils.milvusConnectionManager import  milvus_handler
import sys
sys.path.append('storage/wangyongpeng/FlagEmbedding/')

import re
import gradio as gr
import uvicorn
from fastapi import Body, FastAPI,BackgroundTasks,Request
from fastapi.middleware.cors import CORSMiddleware
import pydantic
from pydantic import BaseModel
from tqdm import tqdm
# # from langchain.embeddings.huggingface import HuggingFaceEmbeddings
# from langchain.embeddings import HuggingFaceEmbeddings
from langchain_community.embeddings import HuggingFaceEmbeddings
import torch    


from langchain_core.documents import Document
import json
import faiss
from langchain_community.vectorstores import FAISS
from langchain_community.docstore.in_memory import InMemoryDocstore
from fastapi import FastAPI, Query

app = FastAPI()

# Local directory of the BGE model evaluated by this service; it runs on CPU.
_BGE_MODEL_DIR = '/home/jovyan/storage/wangyongpeng/FlagEmbedding/model_output_V1_release'
# Local FAISS index directory, opened with the model above as its embedder.
_FAISS_DB_DIR = "/home/jovyan/storage/wangyongpeng/FlagEmbedding/faiss_db"

embedding_function = HuggingFaceEmbeddings(
    model_name=_BGE_MODEL_DIR,
    model_kwargs={'device': "cpu"},
)

# load_local unpickles the stored docstore, hence the explicit opt-in below;
# only point this at an index you built yourself.
vector_store = FAISS.load_local(
    _FAISS_DB_DIR,
    embedding_function,
    allow_dangerous_deserialization=True,
)


def old_bge_result(query, top_k=10):
    """Retrieve candidates for *query* with the original BGE embedding service.

    The query is embedded by the remote ``question_embedding`` service and
    searched in the Milvus collection ``saic_lz_database``, restricted to the
    two ``store_name`` values listed in the filter below.

    Args:
        query: Query text.
        top_k: Maximum number of hits requested from Milvus (passed as the
            search ``limit``). Defaults to 10.

    Returns:
        The raw result object of ``milvus_handler.client.search``.
    """
    # Encode the query with the original embedding service.
    emb = question_embedding(query)
    # Search Milvus, limited to the two evaluation stores.
    expr_f = "store_name in ['saic_chanpinxiangmuhuihuiyiziliaoceshichongfuaigc_test1','saic_chanpinxiangmuhuihuiyijiyaoaigc_test1']"
    res = milvus_handler.client.search(
        collection_name="saic_lz_database",
        data=emb,
        # "doc" becomes the page content; every other field goes into the
        # metadata dict.
        output_fields=["doc", 'source', 'paragraph', 'pre_paragraph', 'next_paragraph'],
        limit=top_k,
        filter=expr_f,
    )
    # Debug output of the result length.
    print(len(res))
    return res
    
def new_bge_result(query, top_k=10):
    """Retrieve candidates for *query* from the local FAISS store (new BGE model).

    Args:
        query: Query text.
        top_k: Number of nearest chunks to return. Defaults to 10.

    Returns:
        List of ``(Document, score)`` pairs from
        ``vector_store.similarity_search_with_score``.
        NOTE(review): with FAISS's default L2 index a lower score means a
        closer match; confirm against how the index was built.
    """
    # similarity_search_with_score embeds the query with the store's own
    # embedding function, so no separate encoding step is needed here
    # (calling embed_documents on a bare string would embed each character).
    return vector_store.similarity_search_with_score(query, k=top_k)
    

"""
======================================================================================================================


# 查询所有的文档片段
def query_milvus(s):
    saic_vector_name = "saic_lz_database"
    expr_f = f"store_name in ['saic_chanpinxiangmuhuihuiyiziliaoceshichongfuaigc_test1','saic_chanpinxiangmuhuihuiyijiyaoaigc_test1'] and source == '{s}'"
    milvus_all = milvus_handler.query(collection_name=saic_vector_name,
                      filter=f'{expr_f}',
                      output_fields=['doc','source','paragraph','pre_paragraph','next_paragraph']
                      ,limit=16384)
    return milvus_all


    
"""
[{'source': '1715333131835.2021年03月12日第06次产品项目会会议纪要.pdf', 'id': 451024891781748623}, 
{'source': '1715333131847.2021年04月30日第09次产品项目会会议纪要.pdf', 'id': 451024891781748624}, 
{'source': '1715333131863.2021年03月24日第07次产品项目会会议纪要.pdf', 'id': 451024891781748625}
]
"""  
# 从milvus中读取数据,存储在本地txt文件中
def get_data_from_milvus():
    saic_vector_name = "saic_source"
    expr_f = "store_name in ['saic_chanpinxiangmuhuihuiyiziliaoceshichongfuaigc_test1','saic_chanpinxiangmuhuihuiyijiyaoaigc_test1']"
    sources = milvus_handler.query(collection_name=saic_vector_name,
                      filter=f'{expr_f}',
                      output_fields=['source']
                      ,limit=16384)
    print(len(sources))
    print(sources[:3])
    source_list = [i.get("source") for i in sources]
    print(source_list[:3])
    chunks_list = []
#     chunks_list = query_milvus(source_list[0])
    for s in tqdm(source_list):
        chunks = query_milvus(s)
        chunks_list += chunks
#     with open("/home/jovyan/storage/wangyongpeng/FlagEmbedding/BGE-eval/row.txt", 'a', encoding='utf-8') as f:
#         for s in tqdm(source_list):
#             chunks = query_milvus(s)
#             f.write("\n".join(str(i) for i in chunks))
    return chunks_list
        
def convert_chunks2_documents(milvus_list):
    documents = []
    for item_json in milvus_list:
        doc = item_json.get("doc")
        item_json.pop("doc")
        document = Document(page_content= doc, metadata=item_json)
        documents.append(document)
    return documents
 
#=======================================================================================



    
# def put_txt_to_faiss():
#     chunks = get_data_from_milvus()
#     documents = convert_chunks2_documents(chunks)
#     db_filepath = "/home/jovyan/storage/wangyongpeng/FlagEmbedding/faiss_db"
# #     embedding_function = HuggingFaceEmbeddings(model_name='/home/jovyan/storage/wangyongpeng/FlagEmbedding/model_output_V1_release') 
#     # model_kwargs={'device': "cuda" }
#     print(documents[:2])
#     # 初始化 # 添加文档
#     vector_store = FAISS.from_documents(documents, embedding_function) 
#     vector_store.save_local(db_filepath)
#     # related_docs_with_score = vector_store.similarity_search_with_score(query, k=self.top_k)

def get_data_from_milvus2():
    saic_vector_name = "saic_source"
    expr_f = "store_name in ['saic_chanpinxiangmuhuihuiyiziliaoceshichongfuaigc_test1','saic_chanpinxiangmuhuihuiyijiyaoaigc_test1']"
    sources = milvus_handler.query(collection_name=saic_vector_name,
                      filter=f'{expr_f}',
                      output_fields=['source']
                      ,limit=16384)
    print(len(sources))
    print(sources[:3])
    source_list = [i.get("source") for i in sources]
    print(source_list[:3])
    chunks_list = []
#     chunks_list = query_milvus(source_list[0])
    for s in tqdm(source_list):
        chunks = query_milvus(s)
        chunks_list += chunks
    with open("/home/jovyan/storage/wangyongpeng/FlagEmbedding/BGE-eval/row-j.json", 'w', encoding='utf-8') as f:
        json.dump(chunks_list, f, ensure_ascii=False, indent=4)
#     return chunks_list
======================================================================================================================
"""



# Comparison route: GET /hello?query=<text> returns the old (Milvus) and new
# (FAISS) retrieval results for the same query as one text report.
@app.get("/hello")
def hello(query: str = Query(default="", title="Query")):
    """Run *query* through both retrieval pipelines and return a text report.

    Declared with plain ``def`` on purpose: both pipelines make blocking
    calls (an HTTP request to the embedding service, the Milvus client, and
    CPU model inference plus FAISS search). FastAPI runs ``def`` endpoints in
    its threadpool. Inside an ``async def`` these calls would block the event
    loop and serialize every other request.

    Args:
        query: Query text taken from the ``query`` query-string parameter;
            defaults to the empty string.

    Returns:
        One string with the stringified old results followed by the new
        results, each preceded by a banner line.
    """
    # Recall with the original embedding model.
    old_res = old_bge_result(query)
    # Recall with the new embedding model.
    new_res = new_bge_result(query)
    print(f"old:\n{len(old_res)}\n")
    print(f"new:\n{len(new_res)}\n")
    return f"\n=================================================旧的==========================================================================={old_res}\n\n\n\n\n\n\n==================================================新的============================================================================{new_res}"



if __name__ == "__main__":
    # Pass the already-imported app object rather than the import string
    # 'bge_eval_main:app'. With a string, uvicorn imports this file a second
    # time under the name bge_eval_main (the running script is __main__),
    # which loads the embedding model and the FAISS index all over again.
    # The import-string form is only needed for reload=True or workers > 1,
    # neither of which is used here.
    uvicorn.run(app, host="0.0.0.0", port=9999)

embeddings_service.py

from typing import List
import json
import requests
from configs.env_config import EMBEDDING_SERVICE_URL,EMBEDDING_URL

# Path suffixes of the embedding-service endpoints built below.
_QUESTION_VECTORIZE_PATH = "/question_vectorize"
_ANSWER_SOURCE_PATH = "/answer_source"

# Full endpoint URLs, derived from the service base URL.
QUESTION_EMBEDDING_URL = EMBEDDING_SERVICE_URL + _QUESTION_VECTORIZE_PATH
ANSWER_SOURCE_URL = EMBEDDING_SERVICE_URL + _ANSWER_SOURCE_PATH
# The document-vectorize endpoint is not derived here; EmbeddingsService
# takes EMBEDDING_URL from configs.env_config instead.

class EmbeddingsService:
    """HTTP client for the remote embedding service.

    Provides ``embed_documents`` / ``embed_query`` (the method names used by
    LangChain embedding classes) plus a static ``answer_source`` helper.
    Every request is a JSON POST with a 3600-second timeout.
    """

    def __init__(self, url=None):
        """Create a client.

        Args:
            url: Endpoint that ``embed_documents`` posts to. When omitted or
                ``None``, ``EMBEDDING_URL`` from ``configs.env_config`` is used.
        """
        # An explicitly passed URL takes precedence over the configured default.
        self.url = EMBEDDING_URL if url is None else url

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a batch of texts with a single POST request.

        Newlines in each text are replaced with spaces before sending. The
        request body is the JSON-encoded list of texts.

        Returns:
            The ``embeddings`` field of the JSON response.

        Raises:
            requests.HTTPError: if the service answers with an error status.
            KeyError: if the response JSON has no ``embeddings`` key.
        """
        cleaned = [text.replace("\n", " ") for text in texts]
        response = requests.request(
            "POST",
            self.url,
            headers={'Content-Type': 'application/json'},
            data=json.dumps(cleaned),
            timeout=3600,
        )
        # Fail with the HTTP error itself rather than a confusing
        # KeyError / JSON decode error when the service is unhealthy.
        response.raise_for_status()
        return response.json()['embeddings']

    @staticmethod
    def answer_source(answer_list: List[str], source_list: List[str]) -> List[List[float]]:
        """POST answers and candidate sources to ``ANSWER_SOURCE_URL``.

        Returns:
            The ``indices`` field of the JSON response.
            NOTE(review): the return annotation says List[List[float]], but the
            value is the service's "indices"; confirm its real type.

        Raises:
            requests.HTTPError: if the service answers with an error status.
            KeyError: if the response JSON has no ``indices`` key.
        """
        payload = {
            "answer_list": answer_list,
            "source_list": source_list,
        }
        response = requests.request(
            "POST",
            ANSWER_SOURCE_URL,
            headers={'Content-Type': 'application/json'},
            data=json.dumps(payload),
            timeout=3600,
        )
        response.raise_for_status()
        return response.json()['indices']

    def embed_query(self, text: str, max_retries: int = 10) -> List[float]:
        """Embed a single text, retrying on failure.

        Attempts are made back to back (no delay) up to *max_retries* times.

        Args:
            text: Text to embed.
            max_retries: Total number of attempts; defaults to 10.

        Returns:
            The first embedding returned by ``embed_documents([text])``.

        Raises:
            ValueError: if ``max_retries`` is less than 1.
            Exception: the error from the last attempt once all attempts fail.
        """
        if max_retries < 1:
            raise ValueError(f"max_retries must be >= 1, got {max_retries}")
        last_error = None
        for _attempt in range(max_retries):
            try:
                return self.embed_documents([text])[0]
            except Exception as err:
                # Retry ordinary failures (network, bad payload, empty result);
                # unlike a bare except, Ctrl-C / SystemExit still propagate.
                last_error = err
        raise last_error


def question_embedding(text):
    """Embed a query through the service's question-vectorize endpoint.

    ``text`` is converted with ``str()`` and sent as a JSON string body to
    ``QUESTION_EMBEDDING_URL``.

    Returns:
        The ``embeddings`` field of the JSON response, unchanged.
        NOTE(review): its exact shape (one vector vs. a list of vectors) is
        defined by the service; confirm before relying on it.

    Raises:
        requests.HTTPError: if the service answers with an error status.
        KeyError: if the response JSON has no ``embeddings`` key.
    """
    response = requests.request(
        "POST",
        QUESTION_EMBEDDING_URL,
        headers={'Content-Type': 'application/json'},
        data=json.dumps(str(text)),
        timeout=3600,
    )
    # Surface service errors directly instead of an opaque KeyError or
    # JSON decode error on an error page.
    response.raise_for_status()
    return response.json()['embeddings']

configs.model_config.py

私密内容,不对外部可见

  • 2
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值