一个简单的医疗问答RAG

https://github.com/wyf3/llm_related

下面格式是ipynb,具体参考原链接吧

导包

from langchain_community.retrievers import BM25Retriever
from typing import List
import jieba
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain.document_loaders import TextLoader
from langchain_huggingface import HuggingFaceEmbeddings

加载知识库,按照换行符切分

loader = TextLoader('medical_data.txt', encoding="utf-8")
documents = loader.load()
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size = 500,
    chunk_overlap  = 0,
    length_function = len,
    separators=['\n']
)
docs = text_splitter.split_documents(documents)

打印前10条数据

docs[:10]

构建bm25需要的处理函数,jieba分词即可

def preprocessing_func(text: str) -> List[str]:
    return list(jieba.cut(text))
bm25 = BM25Retriever(docs=docs,k=10)
bm25
# print(bm25.k)
retriever = bm25.from_documents(docs,preprocess_func=preprocessing_func)

根据query得到检索结果

retriever.invoke('骨折了应该怎么办')

这里虽然写的是vectorizer ,但是实际上是文本召回,也就是bm25

from rank_bm25 import BM25Okapi
texts = [i.page_content for i in docs]
texts_processed = [preprocessing_func(t) for t in texts]
vectorizer = BM25Okapi(texts_processed)

取前10个检索

vectorizer.get_top_n(preprocessing_func('骨折了应该怎么办'),texts, n=10)

向量召回
这里模型下载用的modelscope,默认模型会下载到~/.cache/modelscope/hub中,如果需要修改下载目录,可以手动指定环境变量:MODELSCOPE_CACHE,modelscope会将模型和数据集下载到该环境变量指定的目录中。

embeddings = HuggingFaceEmbeddings(model_name='bge-large-zh-v1___5', model_kwargs = {'device': 'cuda:0'})

构建向量数据库

db = FAISS.from_documents(docs, embeddings)

文本召回结果

bm25_res = vectorizer.get_top_n(preprocessing_func('骨折了应该怎么办'),texts, n=10)
bm25_res

向量召回结果

vector_res = db.similarity_search('骨折了应该怎么办', k=10)
vector_res

RRF实现融合排序,公式很简单,出现次数 + 1 / ( r a n k + m ) 1 / (rank+m) 1/(rank+m)

def rrf(vector_results: List[str], text_results: List[str], k: int=10, m: int=60):
        """
        使用RRF算法对两组检索结果进行重排序
        
        params:
        vector_results (list): 向量召回的结果列表,每个元素是str
        text_results (list): 文本召回的结果列表,每个元素是str
        k(int): 排序后返回前k个
        m (int): 超参数
        
        return:
        重排序后的结果列表,每个元素是(文档ID, 融合分数)
        """
        
        doc_scores = {}
        
        # 遍历两组结果,计算每个文档的融合分数
        for rank, doc_id in enumerate(vector_results):
            doc_scores[doc_id] = doc_scores.get(doc_id, 0) + 1 / (rank+m)
        for rank, doc_id in enumerate(text_results):
            doc_scores[doc_id] = doc_scores.get(doc_id, 0) + 1 / (rank+m)
        
        # 将结果按融合分数排序
        sorted_results = [d for d, _ in sorted(doc_scores.items(), key=lambda x: x[1], reverse=True)[:k]]

        return sorted_results

得到融合后的检索结果

vector_results = [i.page_content for i in vector_res]
text_results = [i for i in bm25_res]

# print(vector_results)

# print(text_results)

rrf_res = rrf(vector_results, text_results)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值