在文档级关系抽取(Document-level Relation Extraction, DocRE)领域,以下是一些常见的术语及其含义:
-
DocRED: Document-level Relation Extraction Dataset。DocRED是一个广泛使用的文档级关系抽取数据集,用于训练和评估模型在文档级关系抽取任务上的性能。
-
Re-DocRED: Revised DocRED。通常指DocRED数据集的修订版,可能包含改进的数据注释或修订的标注以提高数据质量。
-
Dev: Development set。开发集用于模型调参和验证模型性能,通常不用于最终评估。
-
Test: Test set。测试集用于评估模型的最终性能,通常在模型训练和调参后使用,且其标签对模型不可见。
-
Ign F1: Ignored F1 score。F1分数的一种变体,计算时忽略某些不确定的预测,例如DocRED中忽略NA (Not Available)类别的预测。
-
F1: F1 score。是模型性能的综合指标,结合了精确率(Precision)和召回率(Recall)。F1分数越高,模型的整体性能越好。
示例代码
下面是一个简单的示例代码,展示如何在DocRED数据集上训练和评估一个关系抽取模型,并计算F1和Ign F1分数。
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import BertTokenizer, BertModel
class RelationExtractionModel(nn.Module):
def __init__(self, bert_model_name='bert-base-uncased'):
super(RelationExtractionModel, self).__init__()
self.bert = BertModel.from_pretrained(bert_model_name)
self.classifier = nn.Linear(self.bert.config.hidden_size, num_relations)
def forward(self, input_ids, attention_mask):
outputs = self.bert(input_ids, attention_mask=attention_mask)
cls_output = outputs[1] # CLS token's output
logits = self.classifier(cls_output)
return logits
def compute_f1(predictions, labels, ignore_na=True):
true_positive = false_positive = false_negative = 0
for pred, label in zip(predictions, labels):
if ignore_na and label == 0:
continue
if pred == label:
true_positive += 1
else:
false_positive += 1
false_negative += 1
precision = true_positive / (true_positive + false_positive) if (true_positive + false_positive) > 0 else 0
recall = true_positive / (true_positive + false_negative) if (true_positive + false_negative) > 0 else 0
f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
return f1
# 示例数据
input_ids = torch.tensor([[101, 2003, 1037, 3185, 102], [101, 2003, 1037, 2309, 102]])
attention_mask = torch.tensor([[1, 1, 1, 1, 1], [1, 1, 1, 1, 1]])
labels = torch.tensor([1, 2]) # 假设的标签
num_relations = 3 # 假设关系类别数
model = RelationExtractionModel()
optimizer = optim.Adam(model.parameters(), lr=5e-5)
criterion = nn.CrossEntropyLoss()
# 模型训练
model.train()
optimizer.zero_grad()
logits = model(input_ids, attention_mask)
loss = criterion(logits, labels)
loss.backward()
optimizer.step()
# 模型评估
model.eval()
with torch.no_grad():
logits = model(input_ids, attention_mask)
predictions = torch.argmax(logits, dim=1).tolist()
# 计算F1和Ign F1
f1 = compute_f1(predictions, labels, ignore_na=False)
ign_f1 = compute_f1(predictions, labels, ignore_na=True)
print(f"F1 score: {f1:.4f}")
print(f"Ign F1 score: {ign_f1:.4f}")
在这个示例中:
- 使用BERT作为基础模型进行关系抽取。
compute_f1
函数计算F1和Ign F1分数,参数ignore_na
决定是否忽略NA类别。- 假设有3个关系类别进行训练和评估。
这个代码提供了一个基础框架,可以根据具体数据集和任务需求进行扩展和修改。