CasRel(Cascade and Residual Learning for Relational Triple Extraction)模型
1. 什么是 CasRel 模型?
CasRel(Cascade and Residual Learning for Relational Triple Extraction) 是由 清华大学的 NLP 研究团队 提出的 级联与残差学习模型,用于 关系三元组抽取(Relational Triple Extraction)。该模型发表在 ACL 2020,通过 级联解耦(Cascade Decoupling) 和 残差学习(Residual Learning) 的机制来 提取文本中的实体关系三元组,解决了 重叠关系(Overlapping Relations) 和 复杂关系 的问题。
2. 关系三元组抽取的背景
2.1 什么是关系三元组抽取?
关系三元组抽取是信息抽取中的一个重要任务,其目标是:
- 从 非结构化文本 中提取出 (主语,关系,宾语) 形式的三元组。
- 例如:
- 输入:“Elon Musk is the CEO of Tesla.”
- 输出:(Elon Musk, CEO of, Tesla)
2.2 关系三元组抽取的挑战
- 重叠三元组问题(Overlapping Triple Problem):同一实体可能在同一个句子中有多个关系。
- 多重关系问题(Multiple Relations):同一实体对之间可能有多个不同的关系。
- 长距离依赖问题(Long Dependency):当主语和宾语距离较远时,提取关系更具挑战性。
3. CasRel 模型的核心思想
CasRel 通过 级联解耦(Cascade Decoupling) 和 残差学习(Residual Learning) 来解决关系三元组抽取问题:
3.1 级联解耦(Cascade Decoupling)
- CasRel 使用 两个子任务 来解耦关系三元组的抽取:
- 主语识别(Subject Extraction):首先识别句子中的主语。
- 关系-宾语联合抽取(Relation-Object Pair Extraction):根据主语预测关系和相应的宾语。
3.2 残差学习(Residual Learning)
- 通过残差机制,CasRel 能够在主语识别和关系-宾语联合抽取之间进行 信息流的有效传递,从而提高模型的学习能力。
4. CasRel 的模型结构
CasRel 主要包括以下几个模块:
4.1 基础编码层(Embedding Layer)
- 使用 BERT/Transformer 作为文本的编码器,将输入的句子转换成上下文表示。
4.2 主语识别模块(Subject Extraction Layer)
- 通过 BERT + BiLSTM + CRF 识别句子中的 主语(subject)。
- 使用 Span-based(起始和结束位置) 机制定位主语的起始位置和结束位置。
4.3 关系-宾语联合抽取模块(Relation-Object Pair Extraction Layer)
- 以识别出的 主语 为条件,对句子进行 关系和宾语预测。
- 通过 BERT + 关系解码器(Relation Decoder) 预测所有可能的关系,并提取对应的宾语。
4.4 级联机制(Cascade Learning)
- 关系和宾语的抽取是 级联进行的,即先识别主语,然后再进行关系-宾语的联合预测。
5. CasRel 的训练和损失函数
CasRel 采用 二元交叉熵损失(Binary Cross-Entropy, BCE) 和 解耦损失(Decoupling Loss) 进行训练。
5.1 主语识别损失
L s u b j e c t = BCE ( y s u b j e c t , y ^ s u b j e c t ) L_{subject} = \text{BCE}(y_{subject}, \hat{y}_{subject}) Lsubject=BCE(ysubject,y^subject)
- y s u b j e c t y_{subject} ysubject 是真实的主语位置。
- y ^ s u b j e c t \hat{y}_{subject} y^subject 是模型预测的主语位置。
5.2 关系-宾语联合预测损失
L r e l a t i o n − o b j e c t = BCE ( y r e l a t i o n − o b j e c t , y ^ r e l a t i o n − o b j e c t ) L_{relation-object} = \text{BCE}(y_{relation-object}, \hat{y}_{relation-object}) Lrelation−object=BCE(yrelation−object,y^relation−object)
- y r e l a t i o n − o b j e c t y_{relation-object} yrelation−object 是真实的关系-宾语对。
- y ^ r e l a t i o n − o b j e c t \hat{y}_{relation-object} y^relation−object 是模型预测的关系-宾语对。
5.3 总损失函数
L
t
o
t
a
l
=
L
s
u
b
j
e
c
t
+
L
r
e
l
a
t
i
o
n
−
o
b
j
e
c
t
L_{total} = L_{subject} + L_{relation-object}
Ltotal=Lsubject+Lrelation−object
模型通过联合训练主语识别和关系-宾语联合抽取来最小化损失函数。
6. CasRel 的推理流程
6.1 输入句子
- 给定一个输入句子,首先使用 BERT/Transformer 进行编码。
- 例如:
"Elon Musk is the CEO of Tesla."
6.2 主语识别
- 识别主语:Elon Musk
- 模型预测主语的起始和结束位置。
6.3 关系-宾语联合抽取
- 针对 Elon Musk 预测可能的关系和宾语:
- (Elon Musk, CEO of, Tesla)
7. Casrel 代码实现示例
7.1 数据格式
Casrel 需要的数据格式为:
sample_data = [
{
"text": "Elon Musk is the CEO of Tesla.",
"spo_list": [
{"subject": "Elon Musk", "predicate": "CEO", "object": "Tesla"}
]
},
{
"text": "Steve Jobs founded Apple.",
"spo_list": [
{"subject": "Steve Jobs", "predicate": "founded", "object": "Apple"}
]
}
]
7.2 PyTorch 实现 Casrel 模型
import torch
import torch.nn as nn
from transformers import BertModel
class CasRel(nn.Module):
def __init__(self, encoder_name='bert-base-uncased', rel_num=10, hidden_size=768):
super(CasRel, self).__init__()
self.bert = BertModel.from_pretrained(encoder_name)
self.rel_num = rel_num
# 主语头尾预测
self.sub_head_extractor = nn.Linear(hidden_size, 1)
self.sub_tail_extractor = nn.Linear(hidden_size, 1)
# 对每种关系预测对应的宾语头尾
self.obj_head_extractor = nn.Linear(hidden_size, rel_num)
self.obj_tail_extractor = nn.Linear(hidden_size, rel_num)
def forward(self, input_ids, attention_mask, subject_positions=None):
encoder_outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
sub_heads = torch.sigmoid(self.sub_head_extractor(encoder_outputs)).squeeze(-1)
sub_tails = torch.sigmoid(self.sub_tail_extractor(encoder_outputs)).squeeze(-1)
if subject_positions is not None:
# subject_positions: (batch, 2) -> start, end
# 提取 subject 向量作为条件,用于条件关系抽取
sub_start = subject_positions[:, 0]
sub_end = subject_positions[:, 1]
batch_size = encoder_outputs.size(0)
sub_start_vec = torch.stack([encoder_outputs[i, sub_start[i]] for i in range(batch_size)])
sub_end_vec = torch.stack([encoder_outputs[i, sub_end[i]] for i in range(batch_size)])
subject_vec = (sub_start_vec + sub_end_vec) / 2
subject_vec = subject_vec.unsqueeze(1)
conditioned_output = encoder_outputs + subject_vec
obj_heads = torch.sigmoid(self.obj_head_extractor(conditioned_output))
obj_tails = torch.sigmoid(self.obj_tail_extractor(conditioned_output))
else:
obj_heads = None
obj_tails = None
return sub_heads, sub_tails, obj_heads, obj_tails
7.3 Tokenizer 与数据预处理
from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
def encode_example(example, max_len=64):
tokens = tokenizer(example['text'], return_tensors='pt', truncation=True, max_length=max_len, padding='max_length')
input_ids = tokens['input_ids'].squeeze(0)
attention_mask = tokens['attention_mask'].squeeze(0)
return input_ids, attention_mask
7.4 模拟训练过程
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CasRel(rel_num=5).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)
loss_fn = nn.BCELoss()
# 模拟训练
for epoch in range(3):
for sample in sample_data:
input_ids, attention_mask = encode_example(sample)
input_ids = input_ids.unsqueeze(0).to(device)
attention_mask = attention_mask.unsqueeze(0).to(device)
sub_heads_pred, sub_tails_pred, _, _ = model(input_ids, attention_mask)
# 模拟 subject 头尾标签(这里只是示例,真实训练需构造真实标签)
sub_head_label = torch.zeros_like(sub_heads_pred)
sub_tail_label = torch.zeros_like(sub_tails_pred)
sub_head_label[0][1] = 1 # 假设 subject start 在第 1 个 token
sub_tail_label[0][3] = 1 # 假设 subject end 在第 3 个 token
loss = loss_fn(sub_heads_pred, sub_head_label) + loss_fn(sub_tails_pred, sub_tail_label)
loss.backward()
optimizer.step()
optimizer.zero_grad()
print(f"Epoch {epoch + 1} Loss: {loss.item():.4f}")
7.5 预测代码(推理阶段)
def predict_subjects(model, text):
model.eval()
input_ids, attention_mask = encode_example({'text': text})
input_ids = input_ids.unsqueeze(0).to(device)
attention_mask = attention_mask.unsqueeze(0).to(device)
with torch.no_grad():
sub_heads_pred, sub_tails_pred, _, _ = model(input_ids, attention_mask)
sub_heads_pred = sub_heads_pred[0].cpu().numpy()
sub_tails_pred = sub_tails_pred[0].cpu().numpy()
# 阈值选取
subject_list = []
for i, sh in enumerate(sub_heads_pred):
if sh > 0.5:
for j, st in enumerate(sub_tails_pred[i:]):
if st > 0.5:
start, end = i, i + j
tokens = tokenizer.convert_ids_to_tokens(input_ids[0][start:end+1])
subject_text = tokenizer.convert_tokens_to_string(tokens)
subject_list.append((subject_text.strip(), start, end))
break
return subject_list
8. CasRel 的优势
8.1 解决重叠三元组问题
- CasRel 通过 级联解耦 解决了 多重关系和重叠三元组问题,保证模型在复杂句子中也能提取出完整的关系。
8.2 提高模型的可解释性
- 由于 CasRel 是 逐步解耦 的关系提取方法,模型可以 明确解释每一步的推理过程,提高了结果的透明性。
8.3 兼容性强
- CasRel 可以结合 BERT、RoBERTa、ALBERT 等预训练语言模型,提高了模型的泛化能力。
9. CasRel 的应用场景
应用场景 | CasRel 适用性 |
---|---|
关系抽取 | 从新闻、法律文档、科研论文中提取实体关系 |
知识图谱构建 | 从非结构化文本中自动构建知识图谱 |
医学信息抽取 | 提取患者病史、症状、治疗关系 |
金融信息抽取 | 提取公司、股市、金融事件之间的关系 |
推荐系统 | 基于用户和物品的关系进行个性化推荐 |
10. 结论
- CasRel(Cascade and Residual Learning) 是一种 级联与残差学习模型,解决了 重叠关系、复杂关系抽取 的问题。
- 模型的核心思想 是将关系三元组的提取解耦成 主语识别 和 关系-宾语联合预测 两个子任务。
- CasRel 在多个数据集(如 NYT、WebNLG、SciERC)中表现出色,是目前关系抽取任务中的主流模型之一。
- 由于其 级联解耦和残差机制,CasRel 能够在大规模文本数据中 高效、准确地提取关系三元组,广泛应用于 信息抽取、知识图谱构建、金融分析等领域。