Document-level RE中,关键词DocRED,Re-DocRED,Dev,Test,Ign F1,F1分别都是什么意思(附代码)

在文档级关系抽取(Document-level Relation Extraction, DocRE)领域,以下是一些常见的术语及其含义:

  1. DocRED: Document-level Relation Extraction Dataset。DocRED是一个广泛使用的文档级关系抽取数据集,用于训练和评估模型在文档级关系抽取任务上的性能。

  2. Re-DocRED: Revised DocRED。通常指DocRED数据集的修订版,可能包含改进的数据注释或修订的标注以提高数据质量。

  3. Dev: Development set。开发集用于模型调参和验证模型性能,通常不用于最终评估。

  4. Test: Test set。测试集用于评估模型的最终性能,通常在模型训练和调参后使用,且其标签对模型不可见。

  5. Ign F1: Ignored F1 score。F1分数的一种变体,计算时忽略某些不确定的预测,例如DocRED中忽略NA (Not Available)类别的预测。

  6. F1: F1 score。是模型性能的综合指标,结合了精确率(Precision)和召回率(Recall)。F1分数越高,模型的整体性能越好。

示例代码

下面是一个简单的示例代码,展示如何在DocRED数据集上训练和评估一个关系抽取模型,并计算F1和Ign F1分数。

import torch
import torch.nn as nn
import torch.optim as optim
from transformers import BertTokenizer, BertModel

class RelationExtractionModel(nn.Module):
    def __init__(self, bert_model_name='bert-base-uncased'):
        super(RelationExtractionModel, self).__init__()
        self.bert = BertModel.from_pretrained(bert_model_name)
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_relations)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids, attention_mask=attention_mask)
        cls_output = outputs[1]  # CLS token's output
        logits = self.classifier(cls_output)
        return logits

def compute_f1(predictions, labels, ignore_na=True):
    true_positive = false_positive = false_negative = 0
    for pred, label in zip(predictions, labels):
        if ignore_na and label == 0:
            continue
        if pred == label:
            true_positive += 1
        else:
            false_positive += 1
            false_negative += 1
    precision = true_positive / (true_positive + false_positive) if (true_positive + false_positive) > 0 else 0
    recall = true_positive / (true_positive + false_negative) if (true_positive + false_negative) > 0 else 0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
    return f1

# 示例数据
input_ids = torch.tensor([[101, 2003, 1037, 3185, 102], [101, 2003, 1037, 2309, 102]])
attention_mask = torch.tensor([[1, 1, 1, 1, 1], [1, 1, 1, 1, 1]])
labels = torch.tensor([1, 2])  # 假设的标签

num_relations = 3  # 假设关系类别数
model = RelationExtractionModel()
optimizer = optim.Adam(model.parameters(), lr=5e-5)
criterion = nn.CrossEntropyLoss()

# 模型训练
model.train()
optimizer.zero_grad()
logits = model(input_ids, attention_mask)
loss = criterion(logits, labels)
loss.backward()
optimizer.step()

# 模型评估
model.eval()
with torch.no_grad():
    logits = model(input_ids, attention_mask)
    predictions = torch.argmax(logits, dim=1).tolist()

# 计算F1和Ign F1
f1 = compute_f1(predictions, labels, ignore_na=False)
ign_f1 = compute_f1(predictions, labels, ignore_na=True)

print(f"F1 score: {f1:.4f}")
print(f"Ign F1 score: {ign_f1:.4f}")

在这个示例中:

  • 使用BERT作为基础模型进行关系抽取。
  • compute_f1函数计算F1和Ign F1分数,参数ignore_na决定是否忽略NA类别。
  • 假设有3个关系类别进行训练和评估。

这个代码提供了一个基础框架,可以根据具体数据集和任务需求进行扩展和修改。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值