【KG+RAG 论文】医学知识图谱检索增强 LLM 的框架 —— KG-RAG

论文:Biomedical knowledge graph-enhanced prompt generation for large language models
⭐⭐⭐
Code:github.com/BaranziniLab/KG_RAG

论文速读

这篇论文提出了 KG-RAG 的框架,使用医学知识图谱(SPOKE)来对 LLM 进行检索增强。

该框架的运行效果如下图:

运行示例
上图中,黄色部分是用户问题,蓝色部分是 GPT-4 的原生回答,绿色部分是经过 KG-RAG 框架处理后生成的回答。左边的 (A) 是一个关于一跳推理的问题,右边的 (B) 是一个关于两条推理的问题。

可以观察到,KG-RAG 可以解决这个单跳和双跳的问题,并且相比于 GPT-4,可以提供更加简单明了的答案。

工作过程:KG-RAG 框架的基本工作原理如下:

KG-RAG 基本流程

  1. 实体识别与实体链接:根据用户的问题,使用 LLM 做问句中的疾病实体识别,再对识别的结果对 KG 进行实体链接的检索,得到 KG 中相应的节点(即疾病的节点)
  2. 上下文提取(Context pruning):从 KG 中召回与这个实体相关联子图,再基于 embedding 计算语义相似度从子图中过滤出有用的三元组,之后再将这些三元组将其转换为自然语言
  3. 提示组装与文本生成:把上一步得到的自然语言,与 question 拼在一起,组合为 prompt,再加上 SYSTEM_PROMPT,送给 LLM 来回答,从而获得最终答案

模型效果

效果对比

可以看到,在 KG-RAG 框架下,各 LLM 的表现都有提升。

总结

这篇文章提出的框架是一个结合 KG 来做 RAG 的有效方案,但当用于工业落地时,仍会存在很多问题:

  • 实体识别使用了 LLM,之后又做了 entity link,这样的效率肯定不太高。
  • 为了从召回子图过滤出有用的三元组,这里需要专门的 embedding 模型去做
  • 从关联子图 -> 自然语言这一步,也存在很多坑

这篇文章的工作主要是在医学领域结合 KG 来实现 RAG,但在其他领域,需要结合实际的场景去定制具体的策略。

  • 5
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
根据提供的引用内容,可以得知prompt+RAG的流程如下: 1. 首先,使用Retriever部分在知识库中检索出top-k个匹配的文档zi。 2. 然后,将query和k个文档拼接起来作为QA的prompt,送入seq2seq模型。 3. seq2seq模型生成回复y。 4. 如果需要进行Re-rank,可以使用LLM来rerank,给LLM写好prompt即可。 下面是一个简单的示例代码,演示如何使用prompt+RAG: ```python from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration # 初始化tokenizer、retriever和seq2seq模型 tokenizer = RagTokenizer.from_pretrained('facebook/rag-token-base') retriever = RagRetriever.from_pretrained('facebook/rag-token-base', index_name='exact', use_dummy_dataset=True) model = RagSequenceForGeneration.from_pretrained('facebook/rag-token-base') # 设置query和context query = "What is the capital of France?" context = "France is a country located in Western Europe. Paris, the capital city of France, is known for its romantic ambiance and iconic landmarks such as the Eiffel Tower." # 使用Retriever部分检索top-k个匹配的文档 retrieved_docs = retriever(query) # 将query和k个文档拼接起来作为QA的prompt input_dict = tokenizer.prepare_seq2seq_batch(query, retrieved_docs[:2], return_tensors='pt') generated = model.generate(input_ids=input_dict['input_ids'], attention_mask=input_dict['attention_mask']) # 输出生成的回复 generated_text = tokenizer.batch_decode(generated, skip_special_tokens=True)[0] print(generated_text) ```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值