清华基于混合注意力的原型网络 解决关系抽取问题
清华在AAAI2019上提出的论文,
主要创新:
- 利用两个注意力的方式来证明自己的模型在有噪声的环境下可以取得很好的效果。
- 将小样本算法引入到NLP领域,并实现不错的效果。
- 加速了模型的收敛速度。
论文地址:https://ojs.aaai.org//index.php/AAAI/article/view/4604.
git地址:https://github.com/thunlp/HATT-Proto.
————————————————————————
摘要:
关系分类当前主要的方法是远程监督(Distant Supervision,DS),但是并没有大规模的数据来让模型训练,同时长尾分布的关系也面临数据稀缺。而人们可以通过少量实例来学习新的知识,所以利用小样本学习(FewShot Learning, FSL来解决关系分类)。FSL之前在图像领域实现,在NLP领域中很少见。
介绍:
关系分类是信息抽取中的一个重要任务,旨在两个给定上下文语句的实体中分类出对应的关系。
由于文本的多样性和噪声,在CV领域表现很好的FSL很难迁移到NLP领域。为了解决这个问题,提出了混合注意力机制的原型网络。使用神经网络来给支持集中每个句子进行编码,为每个类计算出一个特征向量(类原型),进而利用查询集中的数据和类原型进行距离计算,从而进行分类。对于有噪声的模型,数据和特征都很稀缺,支持集中很小的噪声就可能会导致很大误差,并且,关系特征维度并不是特别有区分度的来帮助支持集做最终的分类。
相关工作
提到监督学习是利用大规模数据集来训练,矩阵学习可以计算类之间的距离,元学习可以快速学习参数,小样本学习可以在很少的数据集上取得不错的效果。
model
基本模型
任务定义
x是句子 h是头实体,t是尾实体,r是头实体和尾实体之间的关系。
特征级注意力和实例级注意力:
特征级注意力就是在300维的词向量上进行卷积操作,从而抽取有价值的信息。
实例级注意力:包括一个embedding和一个encoding,通过softmax和tanh结合来计算最终的注意力得分。
我们的混合注意力由两个模块组成,即实例级注意力模块可在支持集中选择更多的信息,以及特征级注意模块,以突出距离功能中的重要尺寸。
具体注意力到代码中研究
初始定义:
class ProtoHATT(fewshot_re_kit.framework.FewShotREModel):
def __init__(self, sentence_encoder, shots, hidden_size=230):
fewshot_re_kit.framework.FewShotREModel.__init__(self, sentence_encoder)
self.hidden_size = hidden_size
self.drop = nn.Dropout()
# for instance-level attention
self.fc = nn.Linear(hidden_size, hidden_size, bias=True)
# for feature-level attention
self.conv1 = nn.Conv2d(1, 32, (shots, 1), padding=(shots // 2, 0)) # 20
self.conv2 = nn.Conv2d(32, 64, (shots, 1), padding=(shots // 2, 0))
self.conv_final = nn.Conv2d(64, 1, (shots, 1), stride=(shots, 1))
def forward(self, support, query, N, K, Q):
...
# feature-level attention
fea_att_score = support.view(B * N, 1, K, self.hidden_size) # (B * N, 1, K, D)# 80 1 5 230
fea_att_score = F.relu(self.conv1(fea_att_score)) # (B * N, 32, K, D) # 80 32 5 230
fea_att_score = F.relu(self.conv2(fea_att_score)) # (B * N, 64, K, D)# 80 64 5 230
fea_att_score = self.drop(fea_att_score)
fea_att_score = self.conv_final(fea_att_score) # (B * N, 1, 1, D)# 80 1 1 230
fea_att_score = F.relu(fea_att_score)
fea_att_score = fea_att_score.view(B, N, self.hidden_size).unsqueeze(1) # (B, 1, N, D)
# instance-level attention
support = support.unsqueeze(1).expand(-1, NQ, -1, -1, -1) # (B, NQ, N, K, D)
support_for_att = self.fc(support)
query_for_att = self.fc(query.unsqueeze(2).unsqueeze(3).expand(-1, -1, N, K, -1))
ins_att_score = F.softmax(torch.tanh(support_for_att * query_for_att).sum(-1), dim=-1) # (B, NQ, N, K)
support_proto = (support * ins_att_score.unsqueeze(4).expand(-1, -1, -1, -1, self.hidden_size)).sum(3) # (B, NQ, N, D)
# Prototypical Networks
logits = -self.__batch_dist__(support_proto, query, fea_att_score)
_, pred = torch.max(logits.view(-1, N), 1)
return logits, pred
结果分析:
1.加入了一个混乱因子的效果展示,支持集中数据10%,30%,50%概率出错,测试下模型的准确率。
2.按照迭代次数来计算训练集上的loss和验证集上的准确率。
3.通过聚类算法来查看词向量的分布情况。