self attention(自注意力)
来源: https://arxiv.org/pdf/1706.03762
计算实现:
1.计算出
score
=
Q
K
T
\text{score} =QK^T
score=QKT
2.计算
attention
=
s
o
f
t
m
a
x
(
score
)
\text{attention}=softmax(\text{score})
attention=softmax(score)
3. 计算
weighted
=
a
t
t
e
n
t
i
o
n
∗
V
\text{weighted}=attention*V
weighted=attention∗V
1.计算复杂度是
O
(
L
2
)
O(L^2)
O(L2)
2.因为需要计算 LXL 的 注意力矩阵
softmax ( Q K T d ) \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right) softmax(dQKT)
完整公式
self_attention
(
Q
,
K
,
V
)
=
softmax
(
Q
K
T
d
)
\text{self\_attention}(Q, K, V)=\text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)
self_attention(Q,K,V)=softmax(dQKT)
来源:https://arxiv.org/pdf/2009.14794
代码实现:
import torch
import torch.nn as nn
import torch.nn.functional as F
class SelfAttention(nn.Module):
def __init__(self, input_dim):
super(SelfAttention, self).__init__()
self.input_dim = input_dim
# 定义线性变换层,用于计算查询、键和值
self.query = nn.Linear(input_dim, input_dim) # [batch_size, seq_length, input_dim]
self.key = nn.Linear(input_dim, input_dim) # [batch_size, seq_length, input_dim]
self.value = nn.Linear(input_dim, input_dim) # [batch_size, seq_length, input_dim]
self.softmax = nn.Softmax(dim=2) # 注意力权重的softmax函数,沿着最后一个维度进行
def forward(self, x): # x的形状为 (batch_size, seq_length, input_dim)
queries = self.query(x) # 计算查询矩阵
keys = self.key(x) # 计算键矩阵
values = self.value(x) # 计算值矩阵
# 计算注意力得分
score = torch.bmm(queries, keys.transpose(1, 2)) / (self.input_dim ** 0.5)
attention = self.softmax(score) # 对得分应用softmax函数得到注意力权重
weighted = torch.bmm(attention, values) # 使用注意力权重加权值矩阵
return weighted # 返回加权后的值矩阵
Cross Attention(交叉注意力)
这张图片展示了交叉注意力模块的工作原理。
交叉注意力模块
-
输入:
- “What?”:这是表示“内容”的输入序列,包含值(Value,(V))和键(Key,(K))。
- “Where?”:这是表示“位置”的输入序列,包含查询(Query,(Q))。
-
计算过程:
- 从“内容”输入序列中提取出值 (V) 和键 (K)。
- 从“位置”输入序列中提取出查询 (Q)。
- 计算查询 (Q) 和键 (K) 的点积,得到注意力能量(Attention energy)。
- 将注意力能量除以 (\sqrt{C/h}),其中 (C) 是键的维度,(h) 是注意力头的数量,用以进行缩放。
- 对缩放后的注意力能量应用 softmax 函数,得到注意力权重。
- 将注意力权重应用到值 (V) 上,得到输出上下文(Output context)。
数学公式:
Cross_attention ( Q , K , V ) = Softmax ( Q K T C / h ) ⋅ V \text{Cross\_attention}(Q, K, V) = \text{Softmax}\left(\frac{QK^T}{\sqrt{C/h}}\right) \cdot V \ Cross_attention(Q,K,V)=Softmax(C/hQKT)⋅V
- 解释:
- ( Q ):查询矩阵。
- ( K ):键矩阵。
- ( V ):值矩阵。
- (\text{Softmax}):softmax 函数,用于将注意力能量转换为概率分布。
- ( \sqrt{C/h} ):缩放因子,控制注意力能量的大小。
代码实现
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class CrossAttention(nn.Module):
def __init__(self, n_heads, d_embed, d_cross, in_proj_bias=True, out_proj_bias=True):
super().__init__()
# 定义线性变换层,用于计算查询、键和值
self.q_proj = nn.Linear(d_embed, d_embed, bias=in_proj_bias)
self.k_proj = nn.Linear(d_cross, d_embed, bias=in_proj_bias)
self.v_proj = nn.Linear(d_cross, d_embed, bias=in_proj_bias)
self.out_proj = nn.Linear(d_embed, d_embed, bias=out_proj_bias)
self.n_heads = n_heads # 注意力头的数量
self.d_head = d_embed // n_heads # 每个注意力头的维度
def forward(self, x, y):
# x (潜在表示): (batch_size, seq_len_q, dim_q)
# y (上下文): (batch_size, seq_len_kv, dim_kv) = (batch_size, 77, 768)
input_shape = x.shape
batch_size, sequence_length, d_embed = input_shape
# 将每个查询的嵌入向量划分为多个头,确保 d_heads * n_heads = dim_q
interim_shape = (batch_size, -1, self.n_heads, self.d_head)
# 计算查询矩阵 (batch_size, seq_len_q, dim_q) -> (batch_size, seq_len_q, dim_q)
q = self.q_proj(x)
# 计算键矩阵 (batch_size, seq_len_kv, dim_kv) -> (batch_size, seq_len_kv, dim_q)
k = self.k_proj(y)
# 计算值矩阵 (batch_size, seq_len_kv, dim_kv) -> (batch_size, seq_len_kv, dim_q)
v = self.v_proj(y)
# 将查询矩阵重塑并转置以匹配注意力头 (batch_size, seq_len_q, dim_q) -> (batch_size, seq_len_q, h, dim_q / h) -> (batch_size, h, seq_len_q, dim_q / h)
q = q.view(interim_shape).transpose(1, 2)
# 将键矩阵重塑并转置以匹配注意力头 (batch_size, seq_len_kv, dim_q) -> (batch_size, seq_len_kv, h, dim_q / h) -> (batch_size, h, seq_len_kv, dim_q / h)
k = k.view(interim_shape).transpose(1, 2)
# 将值矩阵重塑并转置以匹配注意力头 (batch_size, seq_len_kv, dim_q) -> (batch_size, seq_len_kv, h, dim_q / h) -> (batch_size, h, seq_len_kv, dim_q / h)
v = v.view(interim_shape).transpose(1, 2)
# 计算注意力得分 (batch_size, h, seq_len_q, dim_q / h) @ (batch_size, h, dim_q / h, seq_len_kv) -> (batch_size, h, seq_len_q, seq_len_kv)
weight = q @ k.transpose(-1, -2)
# 缩放注意力得分 (batch_size, h, seq_len_q, seq_len_kv)
weight /= math.sqrt(self.d_head)
# 对注意力得分应用softmax函数 (batch_size, h, seq_len_q, seq_len_kv)
weight = F.softmax(weight, dim=-1)
# 计算加权后的值矩阵 (batch_size, h, seq_len_q, seq_len_kv) @ (batch_size, h, seq_len_kv, dim_q / h) -> (batch_size, h, seq_len_q, dim_q / h)
output = weight @ v
# 将输出矩阵转置回原始形状 (batch_size, h, seq_len_q, dim_q / h) -> (batch_size, seq_len_q, h, dim_q / h)
output = output.transpose(1, 2).contiguous()
# 将输出矩阵重塑回原始形状 (batch_size, seq_len_q, h, dim_q / h) -> (batch_size, seq_len_q, dim_q)
output = output.view(input_shape)
# 应用最后的线性变换 (batch_size, seq_len_q, dim_q) -> (batch_size, seq_len_q, dim_q)
output = self.out_proj(output)
# 返回最终的输出 (batch_size, seq_len_q, dim_q)
return output
代码来源
https://github.com/hkproj/pytorch-stable-diffusion/blob/main/sd/attention.py