自注意力(Self-Attention)与交叉注意力(Cross-Attention)PyTorch 简单实现
在深度学习中,注意力机制是现代 Transformer 架构的核心思想之一。本文将介绍两种常见的注意力机制:自注意力(Self-Attention)与交叉注意力(Cross-Attention),并通过 PyTorch 给出简单实现与使用示例。
📦 必要导入
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import einsum
from einops import rearrange, repeat
from inspect import isfunction
# 一些基础工具函数
def exists(val):
return val is not None
def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d
🔍 什么是注意力机制?
注意力机制允许模型在处理输入序列时自动聚焦于最相关的部分,从而增强建模能力。以 Transformer 为例,它通过注意力机制建立了序列中不同位置之间的信息关联。
🤖 自注意力(Self-Attention)
自注意力是指查询(Query)、键(Key)、值(Value)都来自同一个输入序列。这种机制允许序列中的每个元素关注其它所有位置的信息,是 BERT、GPT 等模型的基本构件。
✅ PyTorch 实现:
class SelfAttention(nn.Module):
def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
super().__init__()
inner_dim = dim_head * heads
self.scale = dim_head ** -0.5
self.heads = heads
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x, mask=None):
h = self.heads
qkv = self.to_qkv(x)
q, k, v = qkv.chunk(3, dim=-1)
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
if exists(mask):
...
attn = sim.softmax(dim=-1)
out = einsum("b i j, b j d -> b i d", attn, v)
out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
return self.to_out(out)
🧪 使用示例:
attn = SelfAttention(dim=64)
x = torch.randn(1, 10, 64)
out = attn(x)
print(out.shape) # torch.Size([1, 10, 64])
🔁 交叉注意力(Cross-Attention)
交叉注意力允许模型在处理一个序列时,从另一个序列中获取信息。常用于:
- 编码器-解码器结构(如 Transformer 翻译模型)
- 图文跨模态对齐
- 条件生成任务
与自注意力的不同在于:
- Query 来自当前输入(例如解码器)
- Key 与 Value 来自另一个序列(例如编码器)
✅ PyTorch 实现:
class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
self.scale = dim_head ** -0.5
self.heads = heads
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, query_dim),
nn.Dropout(dropout)
)
def forward(self, x, context=None, mask=None):
h = self.heads
context = default(context, x)
q = self.to_q(x)
k = self.to_k(context)
v = self.to_v(context)
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
if exists(mask):
...
attn = sim.softmax(dim=-1)
out = einsum("b i j, b j d -> b i d", attn, v)
out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
return self.to_out(out)
🧪 使用示例:
ca = CrossAttention(query_dim=64, context_dim=77)
x = torch.randn(1, 10, 64) # 解码器输入
context = torch.randn(1, 20, 77) # 编码器输出
out = ca(x, context)
print(out.shape) # torch.Size([1, 10, 64])
🧠 总结对比
模块 | Query 来自 | Key/Value 来自 | 典型应用 |
---|---|---|---|
Self-Attention | 当前输入 | 当前输入 | BERT、GPT、自注意图像建模 |
Cross-Attention | 当前输入 | 外部上下文 | 编解码结构、跨模态、条件生成 |