Transformer complexity analysis
- 1. Vision Transformer (ViT) complexity
- 1.1 Dimensions
- 1.2 Complexity of the patch embedding
- 1.3 Complexity of the Transformer encoder
- 1.4 Feed-forward network
- 1.5 Total complexity
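1. Vision Transformer (ViT) complexity
Pipeline: split the input image into patches and project them → generate query, key and value vectors → compute (multi-head) attention weights → feed-forward network. The complexities below count scalar multiply-adds in the matrix products; element-wise work (biases, activations, softmax) is lower order and is left out.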
1.1 Dimensions
Input image: $H\times W\times C$
Patch size: $P\times P$
Number of patches: $N=\frac{W\times H}{P\times P}$
Dimension of each patch: $P\times P\times C$
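As a quick check of these two formulas, here is a minimal Python sketch that computes the patch count and the flattened patch length. The function name `patch_dims` is hypothetical, the 224×224×3 input with 16×16 patches is the ViT-B/16 setting, and the sketch assumes non-overlapping patches with $H$ and $W$ exact multiples of $P$.

```python
def patch_dims(H, W, C, P):
    """Return (N, d_patch) for an H x W x C image split into P x P patches."""
    # Assumption: non-overlapping patches and no padding, so P must divide H and W.
    if H % P or W % P:
        raise ValueError("H and W must be multiples of the patch size P")
    N = (W * H) // (P * P)  # number of patches
    d_patch = P * P * C     # length of one flattened patch
    return N, d_patch


# ViT-B/16 setting: 224 x 224 RGB image, 16 x 16 patches
N, d_patch = patch_dims(H=224, W=224, C=3, P=16)
print(N, d_patch)  # 196 768
```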
1.2 Complexity of the patch embedding
- Dimension of each patch: $d_{patch}=P\times P\times C$.
Embedding each patch: each patch is a row vector $patch\in\mathbb{R}^{1\times d_{patch}}$.
- Each patch is linearly projected to $d_{model}$ dimensions.
- Projection weights: $W_{weight}\in\mathbb{R}^{d_{patch}\times d_{model}}$
- One projected patch: $X_{embedding}\in\mathbb{R}^{1\times d_{model}}$
- All $N$ projected patches: $X_{embedding}\in\mathbb{R}^{N\times d_{model}}$
Projection complexity: $O_{embedding}=N\times d_{patch}\times d_{model}$
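The projection is one $(N\times d_{patch})(d_{patch}\times d_{model})$ matrix product. The pure-Python sketch below avoids NumPy so that every multiply-add can be counted explicitly: it flattens the patches of a small random image and projects them. The helpers `extract_patches` and `matmul_count` are hypothetical names, and the tiny sizes are chosen only so it runs instantly.

```python
import random


def extract_patches(img, P):
    """Split an H x W x C image (nested lists) into N flattened patches of length P*P*C."""
    H, W, C = len(img), len(img[0]), len(img[0][0])
    patches = []
    for r in range(0, H, P):          # top-left row of each patch
        for c in range(0, W, P):      # top-left column of each patch
            patches.append([img[r + i][c + j][k]
                            for i in range(P) for j in range(P) for k in range(C)])
    return patches


def matmul_count(A, B):
    """Naive A @ B that also counts every scalar multiply-add it performs."""
    out = [[0.0] * len(B[0]) for _ in A]
    macs = 0
    for i, row in enumerate(A):
        for t, a in enumerate(row):   # A[i][t] pairs with row t of B
            for j, b in enumerate(B[t]):
                out[i][j] += a * b
                macs += 1
    return out, macs


random.seed(0)
H, W, C, P, d_model = 8, 8, 3, 4, 5
img = [[[random.random() for _ in range(C)] for _ in range(W)] for _ in range(H)]
patches = extract_patches(img, P)                       # N x d_patch
N, d_patch = len(patches), len(patches[0])
W_weight = [[random.random() for _ in range(d_model)] for _ in range(d_patch)]
X_embedding, macs = matmul_count(patches, W_weight)     # N x d_model
print(N, d_patch, macs, N * d_patch * d_model)  # 4 48 960 960
```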
1.3 Complexity of the Transformer encoder
Input features: $X\in\mathbb{R}^{N\times d_{model}}$, where $N$ is the number of patches.
$$Q=XW_Q,\quad K=XW_K,\quad V=XW_V$$
where $W_Q,W_K,W_V\in\mathbb{R}^{d_{model}\times d_{model}}$ and $Q,K,V\in\mathbb{R}^{N\times d_{model}}$.
Complexity: each of $Q$, $K$ and $V$ is an $(N\times d_{model})(d_{model}\times d_{model})$ product, so $O_{Q,K,V}=N\times d_{model}^2$.
In total: $3\times O_{Q,K,V}=3Nd_{model}^2$.
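The three projections are three independent $(N\times d_{model})(d_{model}\times d_{model})$ products. This minimal pure-Python sketch uses an instrumented matrix multiply to confirm the $3Nd_{model}^2$ count; the helpers `matmul_count` and `rand_matrix` and the toy sizes are assumptions for illustration only.

```python
import random


def matmul_count(A, B):
    """Naive A @ B that also counts every scalar multiply-add it performs."""
    out = [[0.0] * len(B[0]) for _ in A]
    macs = 0
    for i, row in enumerate(A):
        for t, a in enumerate(row):   # A[i][t] pairs with row t of B
            for j, b in enumerate(B[t]):
                out[i][j] += a * b
                macs += 1
    return out, macs


def rand_matrix(rows, cols):
    return [[random.random() for _ in range(cols)] for _ in range(rows)]


random.seed(0)
N, d_model = 6, 4
X = rand_matrix(N, d_model)
W_Q, W_K, W_V = (rand_matrix(d_model, d_model) for _ in range(3))

Q, macs_q = matmul_count(X, W_Q)
K, macs_k = matmul_count(X, W_K)
V, macs_v = matmul_count(X, W_V)
total = macs_q + macs_k + macs_v
print(total, 3 * N * d_model ** 2)  # 288 288
```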
Attention scores:
$$Attention(Q,K,V)=softmax\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$
Dot product of $Q$ and $K$: $QK^T\in\mathbb{R}^{N\times N}$, with complexity $O_{QK^T}=N^2\cdot d_{model}$.
Weighted sum: $softmax(\cdot)V\in\mathbb{R}^{N\times d_{model}}$, with complexity $O_{softmax(\cdot)V}=N^2\cdot d_{model}$.
The scaling by $\sqrt{d_k}$ and the softmax are element-wise over the $N\times N$ score matrix, so they add only $O(N^2)$ work and are dropped.
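The two $N^2$ products can be seen directly in a small single-head scaled dot-product attention. In this pure-Python sketch (helper names are hypothetical, sizes are toy values) only the two matrix products are counted, matching the text; the scaling and the numerically stable row-wise softmax are computed but not counted.

```python
import math
import random


def matmul_count(A, B):
    """Naive A @ B that also counts every scalar multiply-add it performs."""
    out = [[0.0] * len(B[0]) for _ in A]
    macs = 0
    for i, row in enumerate(A):
        for t, a in enumerate(row):   # A[i][t] pairs with row t of B
            for j, b in enumerate(B[t]):
                out[i][j] += a * b
                macs += 1
    return out, macs


def transpose(M):
    return [list(col) for col in zip(*M)]


def softmax_rows(S):
    """Row-wise softmax; subtracting the row max keeps exp() from overflowing."""
    out = []
    for row in S:
        m = max(row)
        e = [math.exp(x - m) for x in row]
        z = sum(e)
        out.append([x / z for x in e])
    return out


def attention(Q, K, V):
    """softmax(Q K^T / sqrt(d_k)) V; returns output, weights and matmul multiply-adds."""
    d_k = len(Q[0])
    scores, macs_qk = matmul_count(Q, transpose(K))            # N x N
    scaled = [[s / math.sqrt(d_k) for s in row] for row in scores]
    weights = softmax_rows(scaled)                             # O(N^2), not counted
    out, macs_av = matmul_count(weights, V)                    # N x d_k
    return out, weights, macs_qk + macs_av


random.seed(0)
N, d_k = 5, 4
Q, K, V = ([[random.random() for _ in range(d_k)] for _ in range(N)] for _ in range(3))
out, weights, macs = attention(Q, K, V)
print(macs, 2 * N * N * d_k)  # 200 200
```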
Single-head complexity ($h=1$, so $d_k=d_{model}$): $O_{head}=3Nd_{model}^2+2N^2d_{model}$
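Multi-head complexity:
With $h$ heads, each head has dimension $d_k=\frac{d_{model}}{h}$. Each head projects $X$ with $d_{model}\times d_k$ matrices, so its projections cost $3Nd_{model}d_k$ and its attention costs $2N^2d_k$. Summing over the heads:
$$O_{multi-head}=h(3Nd_{model}d_k+2N^2d_k)=3Nd_{model}^2+2N^2d_{model}$$
Splitting into heads therefore leaves the total unchanged. This count omits the output projection $W_O\in\mathbb{R}^{d_{model}\times d_{model}}$ applied to the concatenated heads, which would add another $Nd_{model}^2$.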
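The per-head bookkeeping can be checked with plain arithmetic. This sketch (function names are hypothetical) sums $3Nd_{model}d_k+2N^2d_k$ over $h$ heads and compares the result with the closed form for several head counts. It assumes $h$ divides $d_{model}$ and, like the text, leaves out $W_O$; $N=196$ and $d_{model}=768$ are the ViT-B/16 values.

```python
def multi_head_macs(N, d_model, h):
    """Multiply-adds of h-head self-attention: Q/K/V projections + QK^T + softmax(.)V."""
    assert d_model % h == 0, "assumes h divides d_model"
    d_k = d_model // h
    per_head = 3 * N * d_model * d_k + 2 * N * N * d_k
    return h * per_head


def closed_form(N, d_model):
    return 3 * N * d_model ** 2 + 2 * N ** 2 * d_model


N, d_model = 196, 768
for h in (1, 4, 12):
    # every head count gives the same total as the closed form
    print(h, multi_head_macs(N, d_model, h), closed_form(N, d_model))
```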
1.4 Feed-forward network
$$FFN(X)=ReLU(XW_1+b_1)W_2+b_2$$
(ViT's MLP block uses GELU instead of ReLU; the activation is element-wise, so it does not change the count below.)
Hidden (expansion) dimension: typically $d_{ff}=4\cdot d_{model}$, i.e. four times $d_{model}$.
$W_1\in\mathbb{R}^{d_{model}\times d_{ff}}$, $W_2\in\mathbb{R}^{d_{ff}\times d_{model}}$
$X\in\mathbb{R}^{N\times d_{model}}$
$XW_1\in\mathbb{R}^{N\times d_{ff}}$, complexity: $O_1=N\cdot d_{model}\cdot d_{ff}$
$ReLU(XW_1+b_1)W_2\in\mathbb{R}^{N\times d_{model}}$, complexity: $O_2=N\cdot d_{ff}\cdot d_{model}$
Total complexity:
$$O_{FFN}=O_1+O_2=2\cdot N\cdot d_{ff}\cdot d_{model}=8\cdot N\cdot d_{model}^2$$
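The $8Nd_{model}^2$ result can be reproduced with a tiny two-layer FFN. This pure-Python sketch (helper names and sizes are illustrative assumptions) uses ReLU as in the formula, counts only the two matrix products, and treats the bias additions and ReLU as uncounted $O(N\cdot d_{ff})$ work.

```python
import random


def matmul_count(A, B):
    """Naive A @ B that also counts every scalar multiply-add it performs."""
    out = [[0.0] * len(B[0]) for _ in A]
    macs = 0
    for i, row in enumerate(A):
        for t, a in enumerate(row):   # A[i][t] pairs with row t of B
            for j, b in enumerate(B[t]):
                out[i][j] += a * b
                macs += 1
    return out, macs


def add_bias(M, b):
    return [[x + y for x, y in zip(row, b)] for row in M]


def ffn(X, W1, b1, W2, b2):
    """FFN(X) = ReLU(X W1 + b1) W2 + b2; returns output and matmul multiply-adds."""
    hidden, macs1 = matmul_count(X, W1)                                  # N x d_ff
    hidden = [[max(0.0, x) for x in row] for row in add_bias(hidden, b1)]  # ReLU
    out, macs2 = matmul_count(hidden, W2)                                # N x d_model
    return add_bias(out, b2), macs1 + macs2


def rand_matrix(rows, cols):
    return [[random.uniform(-1, 1) for _ in range(cols)] for _ in range(rows)]


random.seed(0)
N, d_model = 3, 4
d_ff = 4 * d_model
X = rand_matrix(N, d_model)
out, macs = ffn(X, rand_matrix(d_model, d_ff), [0.0] * d_ff,
                rand_matrix(d_ff, d_model), [0.0] * d_model)
print(macs, 8 * N * d_model ** 2)  # 384 384
```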
1.5 Total complexity
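The original encoder-decoder Transformer contains 3 multi-head attention blocks (encoder self-attention, decoder self-attention, decoder cross-attention) and 2 feed-forward networks per encoder/decoder layer pair, which would give $L(3\cdot O_{multi-head}+2\cdot O_{FFN})$. ViT uses only the encoder, so each of its layers contains one multi-head self-attention block and one FFN.
$L$: number of Transformer encoder layers
$$O_{vit}=O_{embedding}+L(O_{multi-head}+O_{FFN})=N\cdot d_{patch}\cdot d_{model}+L(3\cdot N\cdot d_{model}^2+2\cdot N^2\cdot d_{model}+8\cdot N\cdot d_{model}^2)$$
$$O_{vit}=N\cdot d_{patch}\cdot d_{model}+L(11\cdot N\cdot d_{model}^2+2\cdot N^2\cdot d_{model})$$
The $N\cdot d_{model}^2$ terms grow linearly with the number of patches, while the $2N^2\cdot d_{model}$ attention term grows quadratically: since $N=\frac{W\times H}{P\times P}$, halving the patch size $P$ multiplies $N$ by 4 and the quadratic term by 16.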
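Plugging numbers in makes the linear and quadratic terms concrete. The sketch below assumes the ViT-Base configuration ($d_{model}=768$, $L=12$, $P=16$) and combines the per-block costs derived above ($O_{embedding}$, $O_{multi-head}$, $O_{FFN}$), so it ignores the output projection $W_O$, biases, LayerNorm, softmax and the class token. The function name `vit_macs` is hypothetical; its defaults `attn_per_layer=1, ffn_per_layer=1` model an encoder-only ViT, and passing `3, 2` reproduces the encoder-decoder count.

```python
def vit_macs(N, d_patch, d_model, L, attn_per_layer=1, ffn_per_layer=1):
    """Multiply-adds from the per-block formulas (no W_O, biases, norms, softmax)."""
    embedding = N * d_patch * d_model
    multi_head = 3 * N * d_model ** 2 + 2 * N ** 2 * d_model
    ffn = 8 * N * d_model ** 2
    return embedding + L * (attn_per_layer * multi_head + ffn_per_layer * ffn)


# Assumed ViT-Base settings with 16 x 16 patches on an RGB image
P, C, d_model, L = 16, 3, 768, 12
d_patch = P * P * C
for side in (224, 384):
    N = (side // P) ** 2
    total = vit_macs(N, d_patch, d_model, L)
    quadratic = L * 2 * N ** 2 * d_model   # contribution of the N^2 attention terms
    print(side, N, total, round(quadratic / total, 3))
```

For a 224×224 input ($N=196$) this gives about $1.6\times10^{10}$ multiply-adds, of which the $N^2$ terms are roughly 4%; at 384×384 ($N=576$) their share rises to roughly 12%.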