Focal Self Attention技术分析
动机
通过整合fine-grained local attention和coarse-grained golbal attention,来克服各自的问题。
Focal self Attention
所谓的Focal self attention简单来说就是对距离Query越近的区域进行细粒度fine granulity的attention, 对距离Query越远的区域进行粗粒度的attention,通过调整粒度的级别最终会得到一个层次化的feature maps, 然后把这些feature map经过Flatten、Concatenation等操作转化为Vector, 然后对Vector分别进行Linear projection得到 Key和Value, 然后进行常规的Scaled Dot-Product Attention (SDPA)即可。
基本概念
为了控制粒度,作者引入了三个基本定义,分别为:
1.Focal level
L
L
L: 用于进行Focal self attention的粒度级别的数量;
2.Focal Window size
s
w
l
s_{w}^{l}
swl: 第
l
l
l个粒度级别上用于提取tokens的sub-window的尺寸;
3.Focal Region size
s
r
l
s_{r}^{l}
srl: 第
l
l
l个粒度级别上水平/垂直方向上包含的sub-window的数量;
流程
1.网格化
假设Feature map的原始尺寸为
[
H
,
W
,
d
]
[H, W, d]
[H,W,d], 常规做法是对每个token都要进行一次attention, time和memory开销太大, 因此通过网格化将feature map分成一个个的sub windows, 然后逐window进行attention, 即所谓的window-wise self attention。
假设sub window的尺寸为
[
s
p
,
s
p
]
[s_{p}, s_{p}]
[sp,sp], 则总的windows数量为
H
W
s
p
s
p
\frac{HW}{s_{p}s_{p}}
spspHW, 记网格化后的尺寸为
[
s
p
,
s
p
,
H
/
s
p
⋅
W
/
s
p
⋅
d
]
[sp, sp, H/s_{p} \cdot W/s_{p} \cdot d]
[sp,sp,H/sp⋅W/sp⋅d].
注:通过设置不同的大小的
s
p
s_{p}
sp(实际上就是
s
w
l
s_{w}^{l}
swl), 可以得到不同粒度级别的surroundings.
2. 子窗口池化
假设输入:
X
∈
R
H
⋅
M
⋅
d
X \in R^{H \cdot M \cdot d}
X∈RH⋅M⋅d,
对于第
l
l
l个粒度级别, 其网格化操作形式化定义为:
X
^
=
R
e
s
h
a
p
e
(
X
)
∈
R
s
w
l
x
s
w
l
(
H
/
s
w
l
⋅
W
/
s
w
l
⋅
d
)
\hat{X}=Reshape(X) \in R^{s_{w}^{l} {\rm x} s_{w}^{l} {\rm} (H/s_{w}^{l} \cdot W/s_{w}^{l} \cdot d)}
X^=Reshape(X)∈Rswlxswl(H/swl⋅W/swl⋅d),
然后池化操作形式化为:
x
l
=
f
p
l
∈
R
H
/
s
w
l
⋅
W
/
s
w
l
⋅
d
x^{l}=f_{p}^{l} \in R^{H/s_{w}^{l} \cdot W/s_{w}^{l} \cdot d}
xl=fpl∈RH/swl⋅W/swl⋅d.
然后遍历
l
,
l
∈
1
,
2
,
.
.
.
,
h
l, l \in {1,2,...,h}
l,l∈1,2,...,h, 最终得到输入feature map的层次化surroundings:
{
x
l
}
l
=
1
h
\{x^{l}\}_{l=1}^{h}
{xl}l=1h.
3.聚合操作
x
l
x^{l}
xl为feature map,首先通过Flatten操作将其转化为Vector, 然后再把
h
h
h个这样的Vector进行Concat操作, 最终得到的tokens向量为
x
t
o
t
a
l
∈
R
s
x
d
x_{total} \in R^{s {\rm x} d}
xtotal∈Rsxd, 其中
s
=
∑
l
=
1
h
(
s
w
l
)
2
s=\sum_{l=1}^{h} (s_{w}^{l})^{2}
s=∑l=1h(swl)2.
4.计算Query, Key和Value
通过分别进行Linear projection即可得到Query, Key和Value,形式化定义如下:
Q
=
f
q
(
x
1
)
Q=f_{q}(x^{1})
Q=fq(x1)
K
=
f
k
(
x
t
o
t
a
l
)
K=f_{k}(x_{total})
K=fk(xtotal)
V
=
f
v
(
x
t
o
t
a
l
)
V=f_{v}(x_{total})
V=fv(xtotal)
5.Attention计算
对于位于第
i
i
i个sub-window
Q
i
∈
R
s
p
⋅
s
p
⋅
d
Q_{i} \in R^{s_{p} {\cdot s_{p} \cdot d}}
Qi∈Rsp⋅sp⋅d的Query, 其对应的Key和Value分别记为
K
i
∈
R
s
⋅
d
K_{i} \in R^{s \cdot d}
Ki∈Rs⋅d,
V
i
∈
R
s
⋅
d
V_{i} \in R^{s \cdot d}
Vi∈Rs⋅d.
计算公式如下:
A
t
t
e
n
t
i
o
n
(
Q
i
,
K
i
,
V
i
)
=
s
o
f
t
m
a
x
(
Q
i
K
i
T
d
+
B
)
V
i
Attention(Q_{i}, K_{i}, V_{i})=softmax(\frac{Q_{i} K_{i}^{T}}{\sqrt{d}+B}) V_{i}
Attention(Qi,Ki,Vi)=softmax(d+BQiKiT)Vi
最终输出的尺寸依然为
R
s
p
⋅
s
p
⋅
d
R^{s_{p} {\cdot s_{p} \cdot d}}
Rsp⋅sp⋅d
然后遍历
i
,
i
∈
{
1
,
2
,
.
.
,
H
/
s
p
⋅
W
/
s
p
}
i, i \in\{1,2,.., H/s_{p} \cdot W/s_{p}\}
i,i∈{1,2,..,H/sp⋅W/sp}重复上述操作即可。
时间复杂度分析
网格化实际上就是对数据进行Reshape, 假设原始输入
X
∈
R
M
x
N
x
d
X \in R^{M {\rm x} N {\rm x} d}
X∈RMxNxd
,Reshape之后第
l
l
l层
X
^
∈
R
s
w
l
x
s
w
l
x
(
M
s
w
l
⋅
N
s
w
l
⋅
d
)
\hat{X} \in R^{s_{w}^{l} {\rm x} s_{w}^{l} {\rm x} (\frac{M}{s_{w}^{l}} \cdot \frac{N}{s_{w}^{l}} \cdot d)}
X^∈Rswlxswlx(swlM⋅swlN⋅d)
对其池化的时间复杂度为
O
(
s
w
l
⋅
s
w
l
M
s
w
l
⋅
N
s
w
l
⋅
d
)
=
O
(
M
⋅
N
⋅
d
)
O(s_{w}^{l} \cdot s_{w}^{l} \frac{M}{s_{w}^{l}} \cdot \frac{N}{s_{w}^{l}} \cdot d)=O(M\cdot N \cdot d)
O(swl⋅swlswlM⋅swlN⋅d)=O(M⋅N⋅d), 对所有L个层池化的时间复杂度为
O
(
O
(
L
⋅
M
⋅
N
⋅
d
)
)
O(O(L \cdot M\cdot N \cdot d))
O(O(L⋅M⋅N⋅d)).
每个Query所属sub-window Q i ∈ R s p x s p x d Q_{i} \in R^{s_{p} {\rm x} s_{p} {\rm x} d} Qi∈Rspxspxd, 对应 K i , V i ∈ R s x d K_{i}, V_{i} \in R^{s {\rm x} d} Ki,Vi∈Rsxd, 因此计算Attention的时间复杂度为 O ( ( s p ) 2 ⋅ d ⋅ s ) O((s_{p})^{2}\cdot d \cdot s) O((sp)2⋅d⋅s), 共有 M s p ⋅ N s p \frac{M}{s_{p}} \cdot \frac{N}{s_{p}} spM⋅spN个这样的sub-window, 因此进行Attention总的时间复杂度为 M ⋅ N ⋅ d ⋅ ( ∑ 1 L ( s r l ) 2 ) M\cdot N \cdot d \cdot (\sum_{1}^{L} (s_{r}^{l})^{2}) M⋅N⋅d⋅(∑1L(srl)2)。
因此,总的时间复杂度为: O ( M ⋅ N ⋅ d ⋅ ( L + ∑ l = 1 L ( s r l ) 2 ) ) O(M\cdot N \cdot d \cdot (L + \sum_{l=1}^{L} (s_{r}^{l})^{2})) O(M⋅N⋅d⋅(L+∑l=1L(srl)2)).
点评
1.fine-grained local attention不能抓住global的信息,coarse-grained golbal attention虽然能抓住global information但因为粒度粗,因此,这两种方式实际上都不能发挥出NLP中原始Transformer中attention的建模能力。
focal self attention 实际上正是对这两种attention的整合,因为它既能抓住closet surroundings的fine-grained local信息,同时又能抓住far surroundings的coarse-grained global信息。具体来说, Key和Value的来源不是某个单一粒度级别的tokens, 而是多个粒度级别的层次化的tokens的融合。
2.所谓的focal self attention最核心的东西可以看成是常规scaled dot-product attention的一个前置操作。
与Transformer中MHA的比较
1.Transformer中的MHA处理的是序列,对于长度为n,维度为d的序列
X
∈
R
n
x
d
m
o
d
e
l
X \in R^{n {\rm x} d_{model}}
X∈Rnxdmodel,
Q
∈
R
n
x
d
k
Q \in R^{n {\rm x} d_{k}}
Q∈Rnxdk, 对应
K
∈
R
m
x
d
k
,
V
∈
R
m
x
d
v
K \in R^{m {\rm x} d_{k}}, V \in R^{m {\rm x} d_{v}}
K∈Rmxdk,V∈Rmxdv,
d
k
,
d
v
d_{k}, d_{v}
dk,dv完全取决于Linear projection输出的维度,
m
m
m为K-V对的数量;而在focal self attention这里,
Q
i
∈
R
s
p
⋅
s
p
⋅
d
Q_{i} \in R^{s_{p} \cdot {s_{p} \cdot d}}
Qi∈Rsp⋅sp⋅d, 其中
d
d
d为Figure的channel,
s
p
s_{p}
sp取决于网格化时设置的每个sub-window的宽度,对应
K
i
,
V
i
∈
R
s
⋅
d
K_{i}, V_{i} \in R^{s {\cdot} d}
Ki,Vi∈Rs⋅d, 其中s为所有L层池化结果经Flatten后再Concat的长度。
2.原始Transformer中为MHA, 通过多次Linear projection得到多个子空间,分别进行SDPA再聚合,论文中提到这既可以节省memory,提高efficancy,同时提升泛化能力; focal self attention中用的是网格化方法,Attention是在Window级别,而不是每个Query position, 这减少了memory, 节省了time.
几个疑点
1.论文中Fig.4 与正文中描述略不同, 图中是不同level的tokens先concatenation后进行linear projection产生 K i , V i K_{i}, V_{i} Ki,Vi, 文本描述部分是对每个level的分别进行linear projection (即 f k , f v f_{k}, f_{v} fk,fv), 然后再concatenation.
2.Fig. 4中Query position所在区域恰好位于Feature map的中心,因此每个level的tokens都是方形, 然而当Query position位于边缘时,怎么办? 需要做填充吗? 从论文中没有看到类似信息。
3.
s
p
s_{p}
sp,
s
w
l
s_{w}^{l}
swl与
s
r
l
s_{r}^{l}
srl之间应该满足的关系:
这一点论文中并没有明确说明。
很明显,
{
s
w
l
}
l
=
1
L
\{s_{w}^{l}\}_{l=1}^{L}
{swl}l=1L应该都能被
s
p
s_{p}
sp整除, 特别的,当
s
w
l
=
1
s_{w}^{l}=1
swl=1时,粒度级别最细; 当
s
w
l
=
s
p
s_{w}^{l}=s_{p}
swl=sp时, 粒度级别最粗。
三者之间的关系, Fig.4中Query所在的蓝色区域周围上下左右包含的尺寸为 s w l x s w l s_{w}^{l} {\rm x} {s_{w}^{l}} swlxswl的方格的数量都为2, 因此可以很容易得出结论: s r l = 4 + s p s w l s_{r}^{l}=4 + \frac{s_{p}}{s_{w}^{l}} srl=4+swlsp.
Focal Transformer
如图, Focal Transformer包含多个Stage:
{
S
t
a
g
e
i
}
i
=
1
4
\{\rm Stage i\}_{i=1}^{4}
{Stagei}i=14, 每个
S
t
a
g
e
i
{\rm Stage i}
Stagei里面通过Stack相同的building block组成,数量为
N
i
N_{i}
Ni,每个building block主要由Focal Self-Attention和Multi-Layer Perceptron 组成, 最核心的就是上面提到的Focal self-attention.
可以看到,在每个Stage之前都有一个Patch Embedding (PE), PE实际上就是一个Convolution层, 其中第一层的PE作用是将 X ∈ R M x N x 3 X \in R^{M {\rm x} N {\rm x} 3} X∈RMxNx3映射到 R M x N x d R^{M {\rm x} N {\rm x} d} RMxNxd, 然接下来,每个Stage开始之前都会有一个PE将上一个Stage输出的Feature map的Spatial dimension减小为原来的1/2, Channel dimension增大为原来的2倍。
总结
Reference
1.ArXiv, 2021, Focal Self-attention for Local-Global Interactions in Vision Transformers.