1. 基本概念与核心思想
分层注意力是注意力机制的一种扩展,其核心思想是在不同层次上分别应用注意力机制,从而更细致地捕捉数据的多层次结构信息。与传统注意力机制(如自注意力)相比,分层注意力能够同时关注局部细节和全局上下文,特别适合处理具有明显层级结构的数据,如文本(单词→句子→文档)、视频(帧→片段→视频)等。
关键优势:
- 捕捉多尺度特征:通过在不同层级应用注意力,同时关注微观和宏观特征
- 减少信息损失:避免直接将所有信息压缩到单一表示中
- 可解释性强:不同层级的注意力权重可用于分析模型关注的重点
2. 分层注意力的典型结构
分层注意力通常包含两个或多个层级的注意力机制,常见的结构有:
-
词级注意力 → 句子级注意力(用于文本处理)
- 词级注意力:关注句子中哪些单词更重要
- 句子级注意力:关注文档中哪些句子更重要
-
帧级注意力 → 片段级注意力 → 视频级注意力(用于视频处理)
- 帧级注意力:关注视频帧中的关键区域
- 片段级注意力:关注视频片段中的重要动作
- 视频级注意力:整合所有片段信息,生成最终表示
-
通道级注意力 → 空间级注意力(用于图像 / 特征图处理)
- 通道级注意力:关注哪些特征通道更重要
- 空间级注意力:关注特征图中哪些空间位置更重要
3. 数学形式化表示
以文本处理中的词级 + 句子级分层注意力为例,其数学表达如下:
-
词级注意力:
- 输入:词嵌入序列 \(\mathbf{h}_t \in \mathbb{R}^d, t=1,2,\dots,T\)
- 隐层表示:\(\mathbf{u}_t = \tanh(W_w\mathbf{h}_t + b_w)\)
- 注意力权重:\(\alpha_t = \frac{\exp(\mathbf{u}_t^T\mathbf{v}_w)}{\sum_{k=1}^T\exp(\mathbf{u}_k^T\mathbf{v}_w)}\)
- 句子向量:\(\mathbf{s} = \sum_{t=1}^T \alpha_t \mathbf{h}_t\)
-
句子级注意力:
- 输入:句子向量序列 \(\mathbf{s}_i \in \mathbb{R}^d, i=1,2,\dots,M\)
- 隐层表示:\(\mathbf{u}_i = \tanh(W_s\mathbf{s}_i + b_s)\)
- 注意力权重:\(\beta_i = \frac{\exp(\mathbf{u}_i^T\mathbf{v}_s)}{\sum_{k=1}^M\exp(\mathbf{u}_k^T\mathbf{v}_s)}\)
- 文档向量:\(\mathbf{d} = \sum_{i=1}^M \beta_i \mathbf{s}_i\)
4. 关键应用场景
-
文本情感分析:
- 词级注意力捕捉情感关键词(如 "很棒"、"糟糕")
- 句子级注意力区分不同句子对整体情感的贡献
-
视频理解:
- 帧级注意力定位关键动作区域
- 片段级注意力识别重要事件序列
- 视频级注意力整合全局信息进行分类
-
多模态融合:
- 在不同模态(文本、图像、音频)内部应用注意力
- 在模态间应用跨模态注意力,整合多源信息
-
长序列建模:
- 处理超长文本或视频时,分层注意力可有效降低计算复杂度
- 例如,将长文档分为段落,先对段落内的句子应用注意力,再对段落间应用注意力
5. 代表性模型与论文
-
HAN (Hierarchical Attention Networks, 2016)
- 论文:Hierarchical Attention Networks for Document Classification
- 贡献:首次提出词级 + 句子级的分层注意力结构,用于文档分类任务
- 特点:可解释性强,通过注意力权重可视化可分析模型决策依据
-
HA-Net (Hierarchical Attention Network for Video Captioning, 2017)
- 论文:Hierarchical Attention Networks for Video Captioning
- 贡献:应用于视频描述生成,通过分层注意力同时关注视频帧和时间序列
- 结构:帧级注意力(空间)→ 片段级注意力(时间)
-
BERT-based 分层注意力
- 许多基于 BERT 的模型在预训练后加入分层注意力结构
- 例如,在文档级任务中,先使用 BERT 编码句子,再通过句子级注意力整合文档信息
6. 代码实现示例
以下是一个简化的词级 + 句子级分层注意力模型的 PyTorch 实现:
import torch
import torch.nn as nn
import torch.nn.functional as F
class WordAttention(nn.Module):
def __init__(self, input_dim, hidden_dim):
super(WordAttention, self).__init__()
self.linear = nn.Linear(input_dim, hidden_dim)
self.context_vector = nn.Parameter(torch.randn(hidden_dim))
def forward(self, x):
# x shape: [batch_size, seq_len, input_dim]
u = torch.tanh(self.linear(x)) # [batch_size, seq_len, hidden_dim]
alpha = F.softmax(torch.matmul(u, self.context_vector), dim=1) # [batch_size, seq_len]
s = torch.bmm(alpha.unsqueeze(1), x).squeeze(1) # [batch_size, input_dim]
return s, alpha
class SentenceAttention(nn.Module):
def __init__(self, input_dim, hidden_dim):
super(SentenceAttention, self).__init__()
self.linear = nn.Linear(input_dim, hidden_dim)
self.context_vector = nn.Parameter(torch.randn(hidden_dim))
def forward(self, x):
# x shape: [batch_size, num_sentences, input_dim]
u = torch.tanh(self.linear(x)) # [batch_size, num_sentences, hidden_dim]
beta = F.softmax(torch.matmul(u, self.context_vector), dim=1) # [batch_size, num_sentences]
d = torch.bmm(beta.unsqueeze(1), x).squeeze(1) # [batch_size, input_dim]
return d, beta
class HierarchicalAttentionNetwork(nn.Module):
def __init__(self, vocab_size, embed_dim, word_hidden_dim, sent_hidden_dim, num_classes):
super(HierarchicalAttentionNetwork, self).__init__()
self.embedding = nn.Embedding(vocab_size, embed_dim)
self.word_attn = WordAttention(embed_dim, word_hidden_dim)
self.sent_attn = SentenceAttention(word_hidden_dim, sent_hidden_dim)
self.classifier = nn.Linear(sent_hidden_dim, num_classes)
def forward(self, x):
# x shape: [batch_size, num_sentences, seq_len]
batch_size, num_sentences, seq_len = x.size()
# 词嵌入
x = x.view(batch_size * num_sentences, seq_len) # [batch_size*num_sentences, seq_len]
embedded = self.embedding(x) # [batch_size*num_sentences, seq_len, embed_dim]
# 词级注意力
sentence_vectors, word_attn_weights = self.word_attn(embedded) # [batch_size*num_sentences, word_hidden_dim]
# 重塑为句子序列
sentence_vectors = sentence_vectors.view(batch_size, num_sentences, -1) # [batch_size, num_sentences, word_hidden_dim]
# 句子级注意力
document_vector, sent_attn_weights = self.sent_attn(sentence_vectors) # [batch_size, sent_hidden_dim]
# 分类
logits = self.classifier(document_vector) # [batch_size, num_classes]
return logits, word_attn_weights, sent_attn_weights
7. 面试常见问题
Q1:分层注意力与普通注意力的区别是什么? A1: 普通注意力在单一层级上处理输入,例如只关注句子中的单词;而分层注意力在多个层级上依次应用注意力机制,例如先关注单词,再关注句子,能够捕捉更丰富的层次结构信息。
Q2:分层注意力有哪些应用场景? A2: 适合处理具有层级结构的数据,如:
- 文本:词→句子→文档
- 视频:帧→片段→视频
- 图像:像素→区域→整体
- 语音:音素→单词→句子
Q3:如何设计分层注意力的层级结构? A3: 根据数据的固有结构设计,例如:
- 文本处理:通常设计为词级 + 句子级
- 视频处理:可设计为帧级 + 片段级 + 视频级
- 图像 / 特征图:可设计为通道级 + 空间级
Q4:分层注意力的计算复杂度如何? A4: 分层注意力的计算复杂度通常高于普通注意力,因为需要在多个层级上重复计算注意力机制。但通过合理设计层级结构和共享参数,可以控制复杂度。例如,在长文档处理中,分层注意力可以减少序列长度,从而降低总体计算量。
Q5:如何通过可视化理解分层注意力模型的决策过程? A5: 可以可视化不同层级的注意力权重:
- 词级注意力:在文本中高亮显示权重高的单词
- 句子级注意力:在文档中高亮显示权重高的句子
- 帧级注意力:在视频帧上叠加注意力热图 通过这些可视化,可以分析模型关注的关键信息,解释模型决策依据。
8. 总结
分层注意力是一种强大的机制,能够有效捕捉数据的多层次结构信息,在文本、视频、图像等领域都有广泛应用。其核心优势在于同时关注局部细节和全局上下文,提高模型对复杂数据的理解能力。在实际应用中,需要根据数据特性合理设计层级结构,并注意控制计算复杂度。分层注意力不仅能提升模型性能,还能通过注意力权重可视化增强模型的可解释性。