CV中的self-attention操作
基本概念
self-attention中的self是什么意思呢?
self指的是关注自己内部的信息的相关性:对于NLP任务来说,指的是关于一句话中词与词的相关性;对于CV任务来说,指的是一张图像中若干互不重叠的图像块之间的相关性(后面会细说)。
普通的self-attention
数学原理
我们有一组行向量
{
x
T
i
}
i
=
1
n
,
x
T
i
∈
R
1
×
d
\{\mathbf{x}_T^i\}_{i=1}^n, \mathbf{x}_T^i\in \mathbb{R}^{1\times d}
{xTi}i=1n,xTi∈R1×d,
那么,就有
X
=
[
x
T
1
x
T
2
⋮
x
T
n
]
∈
R
n
×
d
\text{那么,就有 }\mathbf{X} = \begin{bmatrix} \mathbf{x}_T^1 \\ \mathbf{x}_T^2\\ \vdots \\ \mathbf{x}_T^n \end{bmatrix}\in \mathbb{R}^{n\times d}
那么,就有 X=
xT1xT2⋮xTn
∈Rn×d
有三个矩阵
W
Q
,
W
K
,
W
V
W^Q, W^K, W^V
WQ,WK,WV,它们的大小都是
d
×
m
d\times m
d×m。那么,经过:
q
T
i
=
x
T
i
⋅
W
Q
k
T
i
=
x
T
i
⋅
W
K
v
T
i
=
x
T
i
⋅
W
V
\begin{aligned} \mathbf{q}_T^i = \mathbf{x}_T^i\cdot W^Q\\ \mathbf{k}_T^i =\mathbf{x}_T^i\cdot W^K\\ \mathbf{v}_T^i= \mathbf{x}_T^i\cdot W^V\\ \end{aligned}
qTi=xTi⋅WQkTi=xTi⋅WKvTi=xTi⋅WV,我们就获得了对于行向量$ \mathbf{x}_T^i$的三个编码:query, key & value。
对于整体来说,有如下表示:
Q
=
X
W
Q
K
=
X
W
K
V
=
X
W
V
\begin{aligned} \mathbf{Q} = \mathbf{X}W^Q\\ \mathbf{K} = \mathbf{X}W^K\\ \mathbf{V} = \mathbf{X}W^V \end{aligned}
Q=XWQK=XWKV=XWV
其中, Q \mathbf{Q} Q表示query(查询), K \mathbf{K} K表示key, V \mathbf{V} V表示value,这三个矩阵的大小都是 n × m n\times m n×m。
self-attention的工作原理就好像是查字典:要求 x T i \mathbf{x}_T^i xTi对应的输出 o T i \mathbf{o}_T^i oTi,即对于给定的 query q T i \mathbf{q}_T^i qTi,计算 query 与其他所有 key 的相关性,然后将 query 与 key 的相关性作为对应 value 的权值,将所有value加权求和获得输出。
可以用数学公式清晰地表达这个过程:
∀
i
,
o
T
i
=
q
T
i
⋅
(
∑
j
=
1
n
k
T
j
T
⋅
v
T
j
)
=
q
T
i
⋅
(
K
T
V
)
∈
R
1
×
m
.
\begin{aligned} \forall i, \mathbf{o}_T^i &= \mathbf{q}_T^i\cdot (\sum_{j=1}^n {\mathbf{k}_T^j}^T\cdot \mathbf{v}_T^j) \\ &=\mathbf{q}_T^i\cdot (\mathbf{K}^T \mathbf{V}) \in \mathbb{R}^{1\times m}. \end{aligned}
∀i,oTi=qTi⋅(j=1∑nkTjT⋅vTj)=qTi⋅(KTV)∈R1×m.
故有 O = [ o T 1 o T 2 ⋮ o T n ] = [ q T 1 q T 2 ⋮ q T n ] ⋅ ( K T V ) = Q K T V ∈ R n × m , \text{故有 }\mathbf{O} = \begin{bmatrix} \mathbf{o}_T^1 \\ \mathbf{o}_T^2\\ \vdots \\ \mathbf{o}_T^n \end{bmatrix}= \begin{bmatrix} \mathbf{q}_T^1\\ \mathbf{q}_T^2\\ \vdots\\ \mathbf{q}_T^n \end{bmatrix} \cdot (\mathbf{K}^T \mathbf{V}) =\mathbf{Q} \mathbf{K}^T \mathbf{V} \in \mathbb{R}^{n\times m}, 故有 O= oT1oT2⋮oTn = qT1qT2⋮qTn ⋅(KTV)=QKTV∈Rn×m,
我们真正的self-attention公式如下:
O = Attention ( Q , K , V ) = softmax ( Q K T m ) V \mathbf{O} = \operatorname{Attention}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \operatorname{softmax}(\frac{\mathbf{Q}\mathbf{K}^T}{\sqrt{m}})\mathbf{V} O=Attention(Q,K,V)=softmax(mQKT)V巧记:矩阵排放次序为Q->K->V
为什么要对 Q K T \mathbf{Q}\mathbf{K}^T QKT做 ⋅ m \frac{\cdot}{\sqrt{m}} m⋅以及 softmax ( ⋅ ) \operatorname{softmax}(\cdot) softmax(⋅)操作呢?
- softmax ( ⋅ ) \operatorname{softmax}(\cdot) softmax(⋅)操作的意义: 我们在上文中提到“query 与 key 的相关性是对应 value 的权值",权值之和应为1。 softmax ( ⋅ ) \operatorname{softmax}(\cdot) softmax(⋅)操作实际上是将 Q K T \mathbf{Q}\mathbf{K}^T QKT中每一行(对于 q i q_i qi)的权值之和变换为1。
- ⋅ m \frac{\cdot}{\sqrt{m}} m⋅操作的意义: 用于消除编码长度 m m m对于最终的权值分布的影响。 q T i ⋅ k T j T = ∑ t = 1 m Q i , t ⋅ K j , t \mathbf{q}_T^i\cdot {\mathbf{k}_T^j}^T =\sum_{t=1}^{m}Q_{i,t}\cdot K_{j,t} qTi⋅kTjT=∑t=1mQi,t⋅Kj,t的值的变化范围会随着其编码长度 m m m的变大而变大,query q T i \mathbf{q}_T^i qTi 与其他 key k T j \mathbf{k}_T^j kTj的相关性的方差太大时(因为 m m m太大而导致的),经过 softmax ( ⋅ ) \operatorname{softmax}(\cdot) softmax(⋅)操作后会导致大的权值被更加放大,小的权值被更加缩小。
代码实现
class SelfAttention(nn.Module):
"""
自注意力机制
dim: input features中每个sub特征向量的编码长度
dim_head: q, k, v中每个编码向量的长度
"""
def __init__(self, dim, dim_head):
super().__init__()
self.generateQKV = nn.Linear(dim, 3 * dim_head, bias=False)
self.softmax = nn.Softmax(dim=-1) # 归一化
self._norm_fact = 1 / sqrt(dim_head) # 消除编码长度的影响
self.isProject = not (dim == dim_head)
if self.isProject:
self.linearProject = nn.Linear(dim_head, dim, bias=False)
def forward(self, x):
b, n, dim = x.shape
# 获取Q, K, V矩阵
QKV = self.generateQKV(x) # (b, n, 3 * dim_head)
Q, K, V = torch.chunk(QKV, 3, dim=-1) # (b, n, dim_head)
# Attention公式的核心计算
# mat.shape = (b, n, n)
mat = self.softmax(Q @ torch.transpose(K, 1, 2) * self._norm_fact)
out = mat @ V # out.shape = (b, n, dim_head)
if self.isProject:
out = self.linearProject(out) # out.shape = (b, n, dim)
return out
把self-attention应用到CV任务中
我们知道,self-attention是Transformer网络的一大核心操作,而Transformer网络本来是应用与NLP任务中的,这也意味这,self-attention接受的输入是向量组(一个句子由若干个words组成,将这些words编码成词向量)。
那么,怎么将self-attention操作迁移到图像任务中来呢?
很简单,只要将其图像块转化成向量就可以了。(在spatial dimension上做文章)
对于一张输入图片:
I
∈
R
C
×
H
×
W
\mathcal{I}\in \mathbb{R}^{C\times H \times W}
I∈RC×H×W
以不重叠的方式将其切成若干大小为
C
×
p
×
p
C\times p \times p
C×p×p的图像块(patches),然后将每个图像块都扁平化成长度为
C
p
2
Cp^2
Cp2的向量。
可以用如下的python代码实现该操作:
from einops import rearrange
import torch
import torch.nn as nn
b, c, h, w = I.shape
patch_size = 64
pn_h, pn_w = h // patch_size, w // patch_size
I = rearrange(I, "b c (p ph) (p pw) -> b c p p (ph pw)", p=patch_size, ph=pn_h, pw=pn_w)
I = torch.transpose(I, 4, 1) # (b, n, c, p, p)
I = nn.Flatten(I, start_dim=2, end_dim=4) # (b, n, c*p^2)
至此,我们就构造出了符合要求的一组向量,然后就可以对它们进行编码的操作啦^_^
还有一点,当我们用self-attention处理image restormation任务的时候,我们通常不希望经过self-attention操作之后,图像的大小发生了变化!但是通常编码长度 m m m会比原来的向量长度 d d d要大,这个时候就还需要重新投影变换回原来的长度。
self-attention + multi-head机制
数学原理
mult-head机制是为了进一步解决什么问题呢?
举个最简单的例子,我们使用的单词是有一词多义的情况的,它们在不同的上下文中可能是不同的意思。类似的,不同的物体,在不同的视角下可能会对应相同的2D图案;相同的物体,在不同的视角下也可能有不同的表现形式。
那么,我们就需要有多套编码来覆盖所有的情况。
multi-head机制其实很简单,就是对于一组向量,有heads组 { W Q , W K , W V } \{W^Q, W^K, W^V\} {WQ,WK,WV}矩阵,将这组向量编码成heads组 Q , K , V \mathbf{Q}, \mathbf{K}, \mathbf{V} Q,K,V。然后每组都运用 Attention \operatorname{Attention} Attention公式,获得heads组输出,再对这些输出做加权求和,获得最终的输出。
代码实现
基于pytorch框架的代码仓库:ViT
class SelfAttention(nn.Module):
"""
多头自注意力机制
dim: input features的通道数
heads: 一共有多少组Q, K, V
dim_head: q, k, v的通道数
"""
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
# 如果dim_head == dim且不用多头注意力机制,就无需做投影变换操作
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()
def forward(self, x):
# x.shape = (b, n, dim)
# n是vector的数量,dim是vector的长度
x = self.norm(x)
# 编码,将长度为dim的x_i编码成长度为dim_head的q_i, k_i, v_i
# q, k, v.shape = (b, n, heads * dim_head)
qkv = self.to_qkv(x).chunk(3, dim = -1)
# q, k, v.shape = (b, heads, n, dim_head)
# 每一个Q, K, V矩阵的大小为(n, dim_head),共有heads组
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
# dots.shape = (b, h, n, n)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
# softmax操作,将dots中每一行的权值之和变换为1
attn = self.attend(dots)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
## 对多头算出来的结果进行加权求和,重编码回长度为dim的vector
# output.shape = (b, n, dim)
return self.to_out(out)
小结
前面的两种self-attention的实现都没有跳脱开NLP任务中对于self-attention的实现。无可否认self-attention的优越性,但是它也有弊端!固定下切图像块的大小,对于high-resolution的图片来说,意味着切出来的图像块的数量更多,在 O = Attention ( Q , K , V ) = softmax ( Q K T m ) V \mathbf{O} = \operatorname{Attention}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \operatorname{softmax}(\frac{\mathbf{Q}\mathbf{K}^T}{\sqrt{m}})\mathbf{V} O=Attention(Q,K,V)=softmax(mQKT)V中,计算 Q K T \mathbf{Q}\mathbf{K}^T QKT的时间复杂度为 O ( H 2 W 2 ) \mathcal{O}(H^2W^2) O(H2W2):计算乘加的次数为 n 2 ( 2 m ) n^2(2m) n2(2m),其中 n = H / p or W / p n = H/p \text{ or }W/p n=H/p or W/p。
那么,我们有必要设计出一种新的/针对CV任务的self-attention机制,以缓解做大图时间慢、消耗更多内存(显存)资源的问题。
换一种方式实现self-attention
数学原理
经典论文:Restormer
核心思想:apply self-attention across channel rather than the spatial dimension
怎么理解这里的channel?
在channel dimension上做self-attention,意味着:
对于一张输入图片:
I
∈
R
C
×
H
×
W
\mathcal{I}\in \mathbb{R}^{C\times H \times W}
I∈RC×H×W
先通过
1
×
1
1\times1
1×1卷积做像素级通道上下文的融合
再通过
3
×
3
3\times3
3×3 depth-wise conv做通道级spatial上下文的融合
获得大小为
m
×
H
×
W
m\times H \times W
m×H×W的张量
Q
,
K
,
V
\mathcal{Q}, \mathcal{K}, \mathcal{V}
Q,K,V。
对比于之前的、有强烈数学意义的SA,Restormer给出的做法并没有什么数学意义,单纯为了降低计算量。大家也不用那么纠结。
将张量 Q , K , V \mathcal{Q}, \mathcal{K}, \mathcal{V} Q,K,V reshape一下,获得大小为 H W × m HW\times m HW×m的 Q ^ , V ^ \hat{\mathbf{Q}}, \hat{\mathbf{V}} Q^,V^矩阵和大小为 m × H W m\times HW m×HW的 K ^ \hat{\mathbf{K}} K^矩阵。
然后,运用下面的公式来实现单头的self-attention:
O
^
=
Attention
(
Q
^
,
K
^
,
V
^
)
=
V
^
softmax
(
K
^
Q
^
α
)
∈
R
H
W
×
m
.
\hat{\mathbf{O}} = \operatorname{Attention}(\hat{\mathbf{Q}}, \hat{\mathbf{K}}, \hat{\mathbf{V}}) = \hat{\mathbf{V}}\operatorname{softmax}(\frac{\hat{\mathbf{K}}\hat{\mathbf{Q}}}{\alpha}) \in \mathbb{R}^{HW\times m}.
O^=Attention(Q^,K^,V^)=V^softmax(αK^Q^)∈RHW×m.其中
α
\alpha
α是一个learnable的参数。
巧记:矩阵排放次数和普通的SA操作恰好相反:V->K->Q
可以看出来,现在计算 Q ^ K ^ \hat{\mathbf{Q}}\hat{\mathbf{K}} Q^K^的计算量为 O ( m 2 ( H W ) ) \mathcal{O}(m^2(HW)) O(m2(HW)),与图像的空间分辨率仅是一次相关,大幅缓解了self-attention在大图运算上存在的问题。
然后,我们再将 O ^ \hat{\mathbf{O}} O^经过 1 × 1 1\times 1 1×1卷积操作,重新投影变换回 C × H × W C\times H \times W C×H×W的大小即完成一次完整的self-attention操作。
代码实现
##########################################################################
## Multi-DConv Head Transposed Self-Attention (MDTA)
class Attention(nn.Module):
"""
dim: input feature x 的通道数(c = dim_head * num_heads)
num_heads: multi-head中的组数
"""
def __init__(self, dim, num_heads, bias):
super(Attention, self).__init__()
self.num_heads = num_heads
# learnable alpha:其中每个elem的作用与 1/sqrt(dim_head) 类似
self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
self.qkv = nn.Conv2d(dim, dim*3, kernel_size=1, bias=bias)
# 分组卷积(depth-wise conv):将feature沿channel-axis均分成dim*3 groups
# 即将(b, 3*dim, h, w)的feature分成dim*3组(b, 1, h, w)的块
# 同理,也将kernel分成dim*3组(1, 1, 3, 3)的小kernel
# 对应组小kernel卷积对应组feature块,再沿channal-axis cat起来,获得输出
self.qkv_dwconv = nn.Conv2d(dim*3, dim*3, kernel_size=3, stride=1, padding=1, groups=dim*3, bias=bias)
self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
def forward(self, x):
b,c,h,w = x.shape
# 将x先经过1*1 conv,再经过3*3 depth-wise conv
# (b,c,h,w) =>self.qkv()=> (b,3c,h,w) =>self.qkv_dwconv()=> (b,3c,h,w)
qkv = self.qkv_dwconv(self.qkv(x))
# q, k, v.shape = (b, c, h, w)
q,k,v = qkv.chunk(3, dim=1)
# 这下面的c是dim_head,即q, k, v中向量编码的长度
# q, k, v.shape = (b, num_heads, dim_head, n)
q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
q = torch.nn.functional.normalize(q, dim=-1)
k = torch.nn.functional.normalize(k, dim=-1)
# attn.shape = (b, num_heads, dim_head, dim_head) 跟原来的self-attention最不同的地方!!!
attn = (q @ k.transpose(-2, -1)) * self.temperature
attn = attn.softmax(dim=-1)
# out.shape = (b, num_heads, dim_head, n)
out = (attn @ v)
out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)
out = self.project_out(out)
return out