### 交叉注意力机制的实现
交叉注意力机制是一种扩展形式的自注意力机制,其核心思想是在两个不同的上下文中计算注意力权重。具体来说,查询(Query)来自一个序列,而键(Key)和值(Value)则来源于另一个序列。以下是基于 PyTorch 和 TensorFlow 的交叉注意力机制实现代码。
#### 使用 PyTorch 实现交叉注意力机制
```python
import torch
import torch.nn as nn
import math
class CrossAttention(nn.Module):
    """Multi-head cross-attention: one sequence attends over another.

    Queries are projected from ``query``; keys and values are projected from
    ``key_value_input``. The two inputs may differ in sequence length and,
    via ``kv_dim``, in feature size.

    Args:
        dim_model: Feature size of ``query`` and of the output.
        num_heads: Number of attention heads; must divide ``dim_model``.
        dropout: Dropout probability applied to the attention weights.
        kv_dim: Feature size of ``key_value_input``. Defaults to
            ``dim_model`` (the original behavior).

    Raises:
        ValueError: If ``dim_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, dim_model, num_heads, dropout=0.1, kv_dim=None):
        super().__init__()
        if dim_model % num_heads != 0:
            raise ValueError("dim_model must be divisible by num_heads")
        if kv_dim is None:
            kv_dim = dim_model
        self.num_heads = num_heads
        self.head_dim = dim_model // num_heads
        self.query_projection = nn.Linear(dim_model, dim_model)
        # Keys/values may come from a sequence with a different feature size.
        self.key_projection = nn.Linear(kv_dim, dim_model)
        self.value_projection = nn.Linear(kv_dim, dim_model)
        self.dropout = nn.Dropout(dropout)
        self.out_projection = nn.Linear(dim_model, dim_model)

    def _split_heads(self, x):
        """Reshape (batch, seq, dim_model) -> (batch, num_heads, seq, head_dim)."""
        batch_size, seq_len, _ = x.shape
        return x.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, query, key_value_input, mask=None):
        """Attend from ``query`` over ``key_value_input``.

        Args:
            query: Tensor of shape (batch, seq_len_q, dim_model).
            key_value_input: Tensor of shape (batch, seq_len_kv, kv_dim).
            mask: Optional tensor broadcastable to
                (batch, num_heads, seq_len_q, seq_len_kv). Positions where
                ``mask == 0`` are excluded from attention.

        Returns:
            A tuple ``(output, attn_weights)``. ``output`` has shape
            (batch, seq_len_q, dim_model). ``attn_weights`` has shape
            (batch, num_heads, seq_len_q, seq_len_kv) and is returned after
            dropout.
        """
        batch_size, seq_len_q, _ = query.shape

        queries = self._split_heads(self.query_projection(query))
        keys = self._split_heads(self.key_projection(key_value_input))
        values = self._split_heads(self.value_projection(key_value_input))

        # Scaled dot-product attention: (batch, heads, seq_len_q, seq_len_kv)
        scores = torch.matmul(queries, keys.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if mask is not None:
            # Use the most negative finite value rather than -inf. If every
            # key in a row is masked, -inf makes softmax return NaN; a finite
            # value yields uniform weights and keeps the output finite.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
        attn_weights = torch.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        context = torch.matmul(attn_weights, values)  # (batch, heads, seq_len_q, head_dim)
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len_q, -1)
        output = self.out_projection(context)  # (batch, seq_len_q, dim_model)
        return output, attn_weights
```
上述代码实现了交叉注意力模块:`query` 是发起查询的序列,`key_value_input` 是提供键和值的序列(例如在编码器-解码器结构中,`query` 来自解码器,`key_value_input` 来自编码器的输出)[^1]。`mask` 中取值为 0 的位置会被屏蔽,不参与注意力计算。
---
#### 使用 TensorFlow 实现交叉注意力机制
```python
import tensorflow as tf
def scaled_dot_product_attention(q, k, v, mask=None):
    """Compute scaled dot-product attention softmax(q k^T / sqrt(d_k)) v.

    Args:
        q: Query tensor of shape (..., seq_len_q, depth).
        k: Key tensor of shape (..., seq_len_k, depth).
        v: Value tensor of shape (..., seq_len_k, depth_v).
        mask: Optional tensor broadcastable to (..., seq_len_q, seq_len_k).
            Positions where ``mask == 1`` are masked OUT: they receive a
            large negative additive bias. Numeric or boolean dtypes are
            accepted.

    Returns:
        A tuple ``(output, attention_weights)`` with shapes
        (..., seq_len_q, depth_v) and (..., seq_len_q, seq_len_k).
    """
    matmul_qk = tf.matmul(q, k, transpose_b=True)  # (..., seq_len_q, seq_len_k)
    # Cast to the logits' dtype instead of hard-coding float32, so float16
    # (mixed precision) or float64 inputs do not fail on the division.
    dk = tf.cast(tf.shape(k)[-1], matmul_qk.dtype)
    scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)
    if mask is not None:
        logits_dtype = scaled_attention_logits.dtype
        # -1e9 overflows to -inf in float16, and 0 * -inf = NaN at unmasked
        # positions; clamp the bias to the dtype's most negative finite value.
        # For float32/float64 this is still exactly -1e9.
        large_negative = max(-1e9, logits_dtype.min)
        # Cast so integer/boolean masks can be combined with float logits.
        scaled_attention_logits += tf.cast(mask, logits_dtype) * large_negative
    attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)  # (..., seq_len_q, seq_len_k)
    output = tf.matmul(attention_weights, v)  # (..., seq_len_q, depth_v)
    return output, attention_weights
class MultiHeadCrossAttention(tf.keras.layers.Layer):
    """Multi-head cross-attention layer: queries from ``q``, keys/values from ``kv``.

    Args:
        d_model: Size of each projection and of the output features.
        num_heads: Number of attention heads; must divide ``d_model``.
        **kwargs: Standard Keras layer arguments (e.g. ``name``, ``dtype``),
            forwarded to ``tf.keras.layers.Layer``.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model, num_heads, **kwargs):
        if d_model % num_heads != 0:
            raise ValueError("d_model must be divisible by num_heads")
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.d_model = d_model
        self.depth = d_model // num_heads
        self.wq = tf.keras.layers.Dense(d_model)
        self.wk = tf.keras.layers.Dense(d_model)
        self.wv = tf.keras.layers.Dense(d_model)
        self.dense = tf.keras.layers.Dense(d_model)

    def split_heads(self, x, batch_size):
        """Split the last dimension into (num_heads, depth).

        Reshapes (batch_size, seq_len, d_model) to
        (batch_size, num_heads, seq_len, depth).
        """
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, q, kv, mask=None):
        """Attend from ``q`` over ``kv``.

        Args:
            q: Tensor of shape (batch_size, seq_len_q, features).
            kv: Tensor of shape (batch_size, seq_len_kv, features); the
                feature size may differ from ``q``'s because each input has
                its own Dense projection.
            mask: Optional tensor broadcastable to
                (batch_size, num_heads, seq_len_q, seq_len_kv). It is passed
                to ``scaled_dot_product_attention``, where positions with
                value 1 are masked OUT.

        Returns:
            A tuple ``(output, attention_weights)`` with shapes
            (batch_size, seq_len_q, d_model) and
            (batch_size, num_heads, seq_len_q, seq_len_kv).
        """
        batch_size = tf.shape(q)[0]
        q = self.wq(q)  # (batch_size, seq_len_q, d_model)
        k = self.wk(kv)  # (batch_size, seq_len_kv, d_model)
        v = self.wv(kv)  # (batch_size, seq_len_kv, d_model)
        q = self.split_heads(q, batch_size)  # (batch_size, num_heads, seq_len_q, depth)
        k = self.split_heads(k, batch_size)  # (batch_size, num_heads, seq_len_kv, depth)
        v = self.split_heads(v, batch_size)  # (batch_size, num_heads, seq_len_kv, depth)
        scaled_attention, attention_weights = scaled_dot_product_attention(q, k, v, mask)
        # Merge heads back: (batch_size, seq_len_q, num_heads, depth) -> (batch_size, seq_len_q, d_model)
        scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])
        concat_attention = tf.reshape(scaled_attention, (batch_size, -1, self.d_model))
        output = self.dense(concat_attention)  # (batch_size, seq_len_q, d_model)
        return output, attention_weights
```
此代码展示了如何在 TensorFlow 中构建一个多头交叉注意力层,使不同序列之间能够交互[^3]。注意:此实现中 `mask` 取值为 1 的位置会被屏蔽,与上面 PyTorch 版本(取值为 0 的位置被屏蔽)的约定相反。
---
### 参数调整建议
为了优化交叉注意力机制的表现,可以尝试以下方法:
- 调整模型维度 (`d_model`) 和注意力头数量 (`num_heads`) 来适应特定任务的需求(`d_model` 必须能被 `num_heads` 整除)[^4]。
- 应用正则化技术(如 Dropout 或 L2 正则化)防止过拟合。
- 对于大规模数据集,可采用稀疏注意力或其他高效变体减少计算开销。
---