目录
Self-Attention:一种Attention机制,用于处理单个输入序列中的依赖关系。
Cross-Attention:一种Attention机制,用于处理两个或多个输入序列之间的依赖关系。
Gated Self-Attention:一种改进的Self-Attention机制,引入了门控机制来控制Attention输出。
Generalized Query Attention:一种扩展的Self-Attention机制,支持多个Query和多个Key-Value对。
这些Attention机制都可以用于自然语言处理、计算机视觉等领域,用于捕获输入数据中的依赖关系和语义信息。
以下是Attention机制的异同点表格,输出为Markdown格式:
Attention机制 | Self-Attention | Cross-Attention | Gated Self-Attention | Generalized Query Attention |
---|---|---|---|---|
输入 | 单个输入序列 | 两个或多个输入序列 | 单个输入序列 | 多个Query和多个Key-Value对 |
输出 | Attention输出 | Attention输出 | Attention输出 | Attention输出 |
依赖关系 | 单个输入序列中的依赖关系 | 两个或多个输入序列之间的依赖关系 | 单个输入序列中的依赖关系 | 多个Query和多个Key-Value对之间的依赖关系 |
门控机制 | 无 | 无 | 有 | 无 |
支持多个Query | 否 | 否 | 否 | 是 |
支持多个Key-Value对 | 否 | 否 | 否 | 是 |
异同点总结
- Self-Attention和Gated Self-Attention都用于处理单个输入序列中的依赖关系,但Gated Self-Attention引入了门控机制来控制Attention输出。
- Cross-Attention用于处理两个或多个输入序列之间的依赖关系。
- Generalized Query Attention支持多个Query和多个Key-Value对,用于处理更复杂的依赖关系。
代码实现
Self-Attention
import torch
import torch.nn as nn
import torch.nn.functional as F
class SelfAttention(nn.Module):
def __init__(self, hidden_size, attention_heads):
super(SelfAttention, self).__init__()
self.hidden_size = hidden_size
self.attention_heads = attention_heads
self.query_linear = nn.Linear(hidden_size, hidden_size)
self.key_linear = nn.Linear(hidden_size, hidden_size)
self.value_linear = nn.Linear(hidden_size, hidden_size)
self.dropout = nn.Dropout(0.1)
def forward(self, x):
# x: [batch_size, sequence_length, hidden_size]
batch_size, sequence_length, _ = x.size()
# Linear transformations
query = self.query_linear(x)
key = self.key_linear(x)
value = self.value_linear(x)
# Attention weights
attention_weights = torch.matmul(query, key.transpose(-1, -2)) / math.sqrt(self.hidden_size)
attention_weights = F.softmax(attention_weights, dim=-1)
# Attention output
attention_output