医疗问答检索任务的完整 Pipeline 示例

示例代码:

# 医疗问答检索任务完整 Pipeline 示例
# 包含训练数据、retrieval、评估三步

from typing import Dict, List
from collections import defaultdict
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch.nn.functional as F

# 模拟的语料库(corpus)
corpus = {
    "d1": "糖尿病是一种慢性病,需要控制饮食和规律运动。",
    "d2": "高血压与钠摄入量有关,应减少食盐摄入。",
    "d3": "血糖控制可以通过口服降糖药或胰岛素治疗。",
    "d4": "冠心病患者应避免剧烈运动。"
}

# 模拟的任务数据(包含训练和评估)
task_data = [
    {
        "query": "糖尿病患者如何控制血糖?",
        "positive": ["d1", "d3"],
        "negative": ["d2", "d4"]
    },
    {
        "query": "高血压如何治疗?",
        "positive": ["d2"],
        "negative": ["d1", "d3", "d4"]
    }
]

# 假设我们加载一个分类模型(判断是否相关)
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModelForSequenceClassification.from_pretrained("textattack/bert-base-uncased-snli")
model.eval()

# 检索函数:用模型判断每个 query 与文档是否相关
def query_retrieval(query: str, corpus: Dict[str, str]) -> Dict[str, float]:
    inputs = tokenizer([query] * len(corpus), list(corpus.values()), padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
        probs = F.softmax(outputs.logits, dim=-1)
        relevance_scores = probs[:, 1]  # 假设第1类是“相关”
    return {k: v.item() for k, v in zip(corpus.keys(), relevance_scores)}

# 构造 retrieved_results(模拟模型推理)
retrieved_results = {}
for item in task_data:
    query = item["query"]
    retrieved_results[query] = query_retrieval(query, corpus)

# 构造评估用的标准答案(Ground Truth)
relevant_docs = {item["query"]: set(item["positive"]) for item in task_data}

# 简单评估:计算 Recall@2
def evaluate_recall(retrieved_results: Dict[str, Dict[str, float]], relevant_docs: Dict[str, set], k: int = 2):
    recalls = []
    for query, scores in retrieved_results.items():
        top_k_docs = sorted(scores.items(), key=lambda x: x[1], reverse=True)[:k]
        top_k_doc_ids = [doc_id for doc_id, _ in top_k_docs]
        num_hits = len(set(top_k_doc_ids) & relevant_docs[query])
        recall = num_hits / len(relevant_docs[query])
        print(f"Query: {query}\nTop-{k}: {top_k_doc_ids}\nRecall@{k}: {recall:.2f}\n")
        recalls.append(recall)
    avg_recall = sum(recalls) / len(recalls)
    print(f"Average Recall@{k}: {avg_recall:.2f}")

# 评估
evaluate_recall(retrieved_results, relevant_docs, k=2)

输出结果:
在这里插入图片描述

这段代码是一个医疗问答检索任务的完整 Pipeline 示例,涵盖了以下三个核心功能:


功能一:构造模拟检索任务数据(corpus 与任务集)

作用:

  • 模拟一个简单的检索场景,用于测试问答匹配系统的效果。

内容:

  • corpus:构造一个小型文档库(4条医疗相关文本),每条文本用 d1, d2, … 编号。
  • task_data:模拟用户查询,标注其相关文档(positive)和不相关文档(negative)。

🔎 示例:

task_data = [
    {
        "query": "糖尿病患者如何控制血糖?",
        "positive": ["d1", "d3"],
        "negative": ["d2", "d4"]
    },

这表示:对这个 query,d1d3 是相关文档。


功能二:基于句对分类模型进行检索匹配

作用:

  • 使用一个现成的分类模型(比如 SNLI 领域的 BERT)来判断 query 和 每篇文档 的相关性打分。

核心函数:

def query_retrieval(query: str, corpus: Dict[str, str]) -> Dict[str, float]:

步骤详细解释:

  1. Tokenizer 编码: query 与每个 corpus 文档形成句对,用 BERT 的 tokenizer 编码。
  2. 模型推理: 输入模型做前向传播,得到每个句对的分类 logits。
  3. Softmax 得分: 用 softmax 转换为概率,假设“相关”是第1类(probs[:, 1])。
  4. 返回结果: 返回一个字典,每个文档对应一个相关性得分。

🔎 示例:

{'d1': 0.85, 'd2': 0.12, 'd3': 0.78, 'd4': 0.09}

表示该 query 与 d1, d3 最相关。


功能三:评估模型效果(Recall@k)

作用:

  • 检查模型检索的 top-k 结果中,有多少比例是标注为相关的文档。

评估函数:

def evaluate_recall(retrieved_results, relevant_docs, k=2)

步骤解释:

  1. 提取 top-k: 对每个 query 按得分排序,取前 k 个文档。
  2. 计算 Recall: Recall@k = (top-k 中相关文档数量) / (所有标注相关文档数量)
  3. 打印和平均: 每个 query 的 Recall@k 打印出来,同时统计平均值。

🔄 总结:整个流程实现了什么?

步骤描述
1️⃣ 数据准备构造 query、文档、标注对
2️⃣ 检索利用分类模型计算 query 与文档的相关性得分
3️⃣ 排序与评估取 top-k 文档,计算 recall,衡量模型效果

下面对代码部分进行详细解释

数据准备、检索到评估,我将按模块逐行详细解释


一、导入库和模型

from typing import Dict, List
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch.nn.functional as F

含义说明:

  • Dict, List: 类型注解,说明变量的数据结构。
  • torch: PyTorch 深度学习库。
  • transformers: HuggingFace 提供的 Transformers 模型接口。
  • F.softmax: 用于将模型输出转为“概率”。

二、模拟语料库(corpus)

corpus = {
    "d1": "糖尿病是一种慢性病,需要控制饮食和规律运动。",
    "d2": "高血压与钠摄入量有关,应减少食盐摄入。",
    "d3": "血糖控制可以通过口服降糖药或胰岛素治疗。",
    "d4": "冠心病患者应避免剧烈运动。"
}

含义:

  • 这是一个简单的文档库,每个文档有个 id(如 d1),和一段医学相关的文本内容。

三、任务数据(模拟的训练/评估样本)

task_data = [
    {
        "query": "糖尿病患者如何控制血糖?",
        "positive": ["d1", "d3"],
        "negative": ["d2", "d4"]
    },
    {
        "query": "高血压如何治疗?",
        "positive": ["d2"],
        "negative": ["d1", "d3", "d4"]
    }
]

含义:

  • 模拟用户提出的问题(query)。
  • positive: 该问题对应的正确文档 id。
  • negative: 与问题无关的文档 id。

例如:

  • 问题:“糖尿病患者如何控制血糖?”

    • 正确答案文档:d1(控制饮食)和 d3(药物治疗)。
    • 错误文档:d2(高血压)和 d4(冠心病)。

四、加载预训练模型

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModelForSequenceClassification.from_pretrained("textattack/bert-base-uncased-snli")
model.eval()

含义:

  • 加载 BERT 模型(用于句子对判断:是否相关)。
  • tokenizer: 把文本转成模型输入格式。
  • model.eval(): 设置模型为评估模式,不进行梯度更新。

注意:这里用的模型是 snli 任务(自然语言推理),它学会了判断两个句子是否有关系,非常适合做句子对的匹配任务。


五、检索函数:判断 Query 和每篇文档是否相关

def query_retrieval(query: str, corpus: Dict[str, str]) -> Dict[str, float]:
    inputs = tokenizer([query] * len(corpus), list(corpus.values()), padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
        probs = F.softmax(outputs.logits, dim=-1)
        relevance_scores = probs[:, 1]  # 假设第1类是“相关”
    return {k: v.item() for k, v in zip(corpus.keys(), relevance_scores)}

详细解释:

  • 输入:一个 query,比如“糖尿病怎么控制血糖?”
  • 针对每个文档,都和 query 配成一个句子对(比如“糖尿病控制”和 d1 文本)。
  • 使用 tokenizer 编码成批输入,送入模型。
  • 得到模型对每个句子对的预测 logits,softmax 转为概率。
  • probs[:, 1]: 取“相关”这个类别的概率,作为得分。

🔍 例子:

query = "糖尿病患者如何控制血糖?"
文档 d1 = "糖尿病是一种慢性病,需要控制饮食和规律运动。"
→ BERT 判断是否相关,输出一个概率,比如 0.88

六、生成每个 Query 的检索结果

retrieved_results = {}
for item in task_data:
    query = item["query"]
    retrieved_results[query] = query_retrieval(query, corpus)
  • 针对每个问题 query,调用上面定义的 query_retrieval,获取所有文档的相关性得分。

🔍 结果类似于:

{
  "糖尿病患者如何控制血糖?": {"d1": 0.88, "d2": 0.1, ...},
  ...
}

七、构造标准答案(Ground Truth)

relevant_docs = {item["query"]: set(item["positive"]) for item in task_data}
  • 把每个 query 的 positive 列表转成集合,作为真实相关文档。

八、评估:Recall@K

def evaluate_recall(retrieved_results: Dict[str, Dict[str, float]], relevant_docs: Dict[str, set], k: int = 2):
    recalls = []
    for query, scores in retrieved_results.items():
        top_k_docs = sorted(scores.items(), key=lambda x: x[1], reverse=True)[:k]
        top_k_doc_ids = [doc_id for doc_id, _ in top_k_docs]
        num_hits = len(set(top_k_doc_ids) & relevant_docs[query])
        recall = num_hits / len(relevant_docs[query])
        print(f"Query: {query}\nTop-{k}: {top_k_doc_ids}\nRecall@{k}: {recall:.2f}\n")
        recalls.append(recall)
    avg_recall = sum(recalls) / len(recalls)
    print(f"Average Recall@{k}: {avg_recall:.2f}")

详细解释:

  • 对每个 query,从检索结果中取出得分最高的前 k 篇文档。
  • 看这 k 篇文档里,有多少是正确答案(交集)。
  • Recall 计算方式:命中的正例数 / 实际正例数
  • 最后输出平均 Recall。

🔍 例子:

Query: 糖尿病如何控制血糖?
Top-2: ["d3", "d1"]  # 全部命中
Recall@2 = 2 / 2 = 1.00

九、运行评估

evaluate_recall(retrieved_results, relevant_docs, k=2)

会输出类似:

Query: 糖尿病患者如何控制血糖?
Top-2: ['d3', 'd1']
Recall@2: 1.00

Query: 高血压如何治疗?
Top-2: ['d2', 'd4']
Recall@2: 1.00

Average Recall@2: 1.00

总结:整个流程干了什么?

  1. 定义语料库和训练样本
  2. 用 BERT 模型对 query 和文档进行配对判断
  3. 打分并选出相关文档
  4. 根据真实正例,计算 Recall@K 评估效果
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值