开宗明义:attention就是一个加权机制,但是权重需要经过训练得到。感觉有些情况就是代替了concat然后过模型的场景,优势是更小、参数更少。
本文介绍一些我自己整理的常见attention类型。
本文不关注别的博文里已经写烂了的attention机制的重要性和直觉解释(又不是写论文还要写introduction)。
multi-head attention也不在本文赘述。
文章目录
1. attention基础概念讲解
待补。
2. attention结果计算方法
2.1 KQV版
K-key:输入,想从中找到权重高低的部分的那个东西
Q-query:拿来计算key各部权重的对象
KQ是拿来计算权重的,权重需要归一化(softmax)
V-value:是那个被用来乘权重的对象(也就是输入)
2.1.1 Dot-Product Attention
加权求和:Q是训练出来的context vector
K是通过输入通过MLP后得到,Q是通过训练得到的context vector,V是输入。相当于对输入进行一个加权求和
-
HAN1:两层attention+GRU
(图中h是单词表征, u i t u_it uit是单词表征经一层MLP后得到的隐表征, u w u_w uw是context vector, α \alpha α是注意力权重,s是句子表征)
(这是第二层attention,差不多) -
LeSICiN2里面聚合法条表征:
(论文中给出一项优化方法,是将context vector改为用其他嵌入经线性转换后得到的结果。但是没给出ablation study)
加权求和:权重是神经网络直接算出来的
GNMT:
ConvS2S
加权求和:QKV都是输入
- Transformer(实际用的是scaled版本,见2.1.2部分)
- RNN+翻译
Q是decoder的上一步输出(第一步就是输入),KV是当前输入
加权求和,local attention
定义见第4节
图场景:
(QKV都是输入,但仅在样本之间存在的关系/边上概率化attention)
- GAT3:用每个节点的所有邻居节点(有自环,所以包括该节点本身)信息attentively聚合
a就是那个MLP(对每个节点对表征(向量对),得到标量)
将上两式合并:
加权求和:
(GAT的多头注意力机制,前几层是concat,最后一层是average:)
- MAGNN4:延续GAT思想。
也是聚合两次,一次是将目标节点的metapath instances表征聚合为metapath表征:
(图中黄色是目标节点表征, a P \mathbf{a}_P aP是metapath-specific的权重 (attention vector) )
再将各个metapath表征聚合起来:
(这个应该算是context vector系了,图中 q A \mathbf{q}_A qA是节点类型的attention vector)
加权求和:Q是给定张量,K和V是输入
- code-wise attention
示例来自NeurJudge代码,query是用来从输入中提取相关信息的辅助矩阵(罪名表征,广播到输入的mini-batch上,维度是[batch_size,query_num,hidden_dim]
),context是输入(维度是[batch_size,sent_len,hidden_dim]
)
α是attention,query是G,输入是D
NeurJudge里这个输出([batch_size,1,hidden_dim]
)直接就用来预测了(原代码里分别用了2个query,得到两个attention输出,concat起来做预测)
class Code_Wise_Attention(nn.Module):
def __init__(self):
super(Code_Wise_Attention, self).__init__()
def forward(self,query,context):
S = torch.bmm(context, query.transpose(1, 2))
attention = torch.nn.functional.softmax(torch.max(S, 2)[0], dim=-1)
context_vec = torch.bmm(attention.unsqueeze(1), context)
return context_vec
这个attention机制,在NeurJudge原文中给出了两篇参考文献,还没看:Bidirectional attention flow for machine comprehension和Multi-Interactive Attention Network for Fine-grained Feature Learning in CTR Prediction
- 还是来自NeurJudge,参考文献是Sentence Similarity Learning by Lexical Decomposition and Composition,逻辑也来源自这篇,是想要样本表征(d)通过罪名表征(c)分割为两个部分,分别与c平行与正交。这一部分计算attention就是为了将c投影到d上:公式5是为了计算c和d之间token的点积相似度,公式6是用softmax来从c中选择与d最相似的token(softmax相当于是软版的max)
在代码里考虑了mask的情况,用了2个mask(一个在softmax之前,一个在×V之前)。两个输入矩阵和返回值的维度都是[batch_size,sent_len,hidden_dim]
class Mask_Attention(nn.Module):
def __init__(self):
super(Mask_Attention, self).__init__()
def forward(self, query, context):
attention = torch.bmm(context, query.transpose(1, 2))
mask = attention.new(attention.size()).zero_()
mask[:,:,:] = -np.inf
attention_mask = torch.where(attention==0, mask, attention)
attention_mask = torch.nn.functional.softmax(attention_mask, dim=-1)
mask_zero = attention.new(attention.size()).zero_()
final_attention = torch.where(attention_mask!=attention_mask, mask_zero, attention_mask)
context_vec = torch.bmm(final_attention, query)
return context_vec
2.1.2 Scaled Dot-Product Attention
解决了dot-product attention随维度增长而剧增、导致softmax取值集中、梯度变小的问题(其实我没看懂这是为啥)
- Transformer5
self-attention: KQV都由输入通过线性转换运算得到。这种做法可以用来计算出一组对象内部之间的关系。在LEMM6中这方面可能会体现得更全面:
decoder中则是KV通过encoder得到,Q通过decoder上一层输出得到。
使用self-attention的工作:GL-GIN: Fast and Accurate Non-Autoregressive Model for Joint Multiple Intent Detection and Slot Filling
2.1.3 加性attention
2.2 计算样本对之间的attention
- DVQA7:这篇论文既没有给公式,也没有给代码,只能看图了。但是看图感觉还挺清晰的。
图中attention distribution map得到样本对之间的attention(每一个方块里面是一个计算attention的模型)
DAN:在triplet loss部分直接使用attention之间的距离,classification loss部分则用类似加权求和的方式利用注attention
DCN:将attention转化为上下文结合进了原样本表征中,直接实现分类任务
- CTM8:与DVQA做法和在DAN中的用处都类似,直接给出了公式:
2.3 协同注意力
融合法律文本结构信息的刑事案件判决预测(下载自知网)
大致来说是输入两个矩阵(以下文本中是“法律文本”和“案情描述”),分别得到两个矩阵与对方交互后得到的结果,于是实现了各矩阵获得对方矩阵信息的效果
3. Soft/Hard Attention
soft attention:传统attention,可被嵌入到模型中去进行训练并传播梯度
hard attention:不计算所有输出,依据概率对encoder的输出采样,在反向传播时需采用蒙特卡洛进行梯度估计
4. Global/Local Attention
global attention:传统attention,对所有encoder输出进行计算
local attention:介于soft和hard之间,会预测一个位置并选取一个窗口进行计算
其他本文撰写过程中使用到的参考资料
- Transformer 模型详解
- 深度学习attention机制中的Q,K,V分别是从哪来的? - 知乎:只看了几个回答,感觉挺多讲得不错,待继续看
- NLP中的Attention原理和源码解析 - 知乎
- 还没看
Hierarchical Attention Networks for Document Classification ↩︎
LeSICiN: A Heterogeneous Graph-based Approach for Automatic Legal Statute Identification from Indian Legal Documents
可参考我写的博文:Re6:读论文 LeSICiN: A Heterogeneous Graph-based Approach for Automatic Legal Statute Identification fro ↩︎MAGNN: Metapath Aggregated Graph Neural Network for Heterogeneous Graph Embedding ↩︎
An Element-aware Multi-representation Model for Law Article Prediction ↩︎
Augmenting Legal Judgment Prediction with Contrastive Case Relations ↩︎