Vision Transformer (ViT) Complexity Analysis

Transformer complexity analysis

  • 1. Vision Transformer (ViT) complexity
    • 1.1 Dimensions
    • 1.2 Patch embedding complexity
    • 1.3 Transformer encoder complexity
    • 1.4 Feed-forward network
    • 1.5 Total complexity

1. Vision Transformer (ViT) Complexity

Pipeline: split the input image into patches and project them ----> generate query, key, and value vectors ----> compute (multi-head) attention weights ----> apply the feed-forward network
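The shape flow of this pipeline can be traced with a few lines of NumPy. This is a minimal sketch for a single image, assuming an illustrative 224×224×3 input, 16×16 patches, and $d_{model}=768$ (none of these sizes are fixed by the text), and showing one full-width attention head for brevity:

```python
import numpy as np

# Illustrative sizes (assumptions, not fixed by this article)
H, W, C, P = 224, 224, 3, 16           # image and patch size
N = (H * W) // (P * P)                 # number of patches = 196
d_model = 768                          # embedding dimension
d_k = d_model // 12                    # per-head dimension if h = 12

x = np.random.randn(N, P * P * C)      # flattened patches, (N, d_patch)

# 1) Patch embedding: linear projection to d_model
W_embed = np.random.randn(P * P * C, d_model)
X = x @ W_embed                        # (N, d_model)

# 2) Q, K, V projections (one full-width head shown for brevity)
W_q, W_k, W_v = (np.random.randn(d_model, d_model) for _ in range(3))
Q, K, V = X @ W_q, X @ W_k, X @ W_v    # each (N, d_model)

# 3) Scaled dot-product attention
scores = Q @ K.T / np.sqrt(d_k)        # (N, N)
attn = np.exp(scores - scores.max(axis=-1, keepdims=True))
attn /= attn.sum(axis=-1, keepdims=True)
out = attn @ V                         # (N, d_model)

# 4) Feed-forward network with hidden width 4 * d_model
W1 = np.random.randn(d_model, 4 * d_model)
W2 = np.random.randn(4 * d_model, d_model)
y = np.maximum(out @ W1, 0) @ W2       # (N, d_model)

print(X.shape, attn.shape, y.shape)    # (196, 768) (196, 196) (196, 768)
```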

1.1 Dimensions

Input image: $H \times W \times C$
Patch size: $P \times P$
Number of patches: $N = \frac{W \times H}{P \times P}$
Dimension of each patch: $P \times P \times C$
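As a quick check, these quantities can be computed directly; the 224×224×3 image and 16×16 patch size below are assumed example values only:

```python
H, W, C, P = 224, 224, 3, 16   # assumed example image / patch sizes

N = (H * W) // (P * P)         # number of patches
d_patch = P * P * C            # dimension of one flattened patch
print(N, d_patch)              # -> 196 768
```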

1.2 Patch embedding complexity

  • Dimension of each patch: $d_{patch} = P \times P \times C$
    Each patch is flattened into an embedding vector: $patch \in \Bbb R^{1 \times d_{patch}}$
  • Each patch is linearly projected to $d_{model}$
    • Projection weights: $W_{weight} \in \Bbb R^{d_{patch} \times d_{model}}$
    • One projected patch: $X_{embedding} \in \Bbb R^{1 \times d_{model}}$
    • N projected patches: $X_{embedding} \in \Bbb R^{N \times d_{model}}$

Projection complexity: $O_{embedding} = N \times d_{patch} \times d_{model}$
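This corresponds to a single matrix multiply of shape (N, d_patch) × (d_patch, d_model). A minimal multiply-accumulate count, reusing the assumed example sizes (N = 196, d_patch = 768, d_model = 768):

```python
N, d_patch, d_model = 196, 768, 768              # assumed example sizes

# (N, d_patch) @ (d_patch, d_model) -> (N, d_model)
O_embedding = N * d_patch * d_model
print(f"patch embedding MACs: {O_embedding:,}")  # 115,605,504
```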

1.3 Transformer encoder complexity

  • Input features $X \in \Bbb R^{N \times d_{model}}$, where N is the number of patches
    $Q = XW_Q$, $K = XW_K$, $V = XW_V$
    $W_Q, W_K, W_V \in \Bbb R^{d_{model} \times d_{model}}$, $Q, K, V \in \Bbb R^{N \times d_{model}}$
    Complexity: computing each of Q, K, V costs $O_{Q,K,V} = N \times d_{model}^2$
    In total: $3 \times O_{Q,K,V}$

  • Attention score computation
    $Attention(Q,K,V) = softmax\left(\frac{QK^T}{\sqrt{d_k}}\right)V$
    Dot product of Q and K: $QK^T \in \Bbb R^{N \times N}$. Complexity: $O_{QK^T} = N^2 \cdot d_{model}$.
    $softmax(\cdot)V \in \Bbb R^{N \times d_{model}}$. Complexity: $O_{softmax(\cdot)V} = N^2 \cdot d_{model}$

  • Single-head complexity: $O_{head} = 3Nd_{model}^2 + 2N^2d_{model}$

  • Multi-head complexity (see the sketch below):
    With h heads, each head has dimension $d_k = \frac{d_{model}}{h}$
    Total complexity: $O_{multi-head} = h(3Nd_{model}d_k + 2N^2d_k) = 3Nd_{model}^2 + 2N^2d_{model}$
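The sketch below counts multiply-accumulates for both views of attention under the assumed example sizes (N = 196, d_model = 768, h = 12) and confirms that splitting into heads leaves the total unchanged:

```python
N, d_model, h = 196, 768, 12             # assumed example sizes
d_k = d_model // h

# One head operating on the full d_model
O_head = 3 * N * d_model**2 + 2 * N**2 * d_model

# h heads, each projecting d_model -> d_k and attending in d_k dimensions
O_multi_head = h * (3 * N * d_model * d_k + 2 * N**2 * d_k)

assert O_head == O_multi_head
print(f"attention MACs per layer: {O_multi_head:,}")  # 405,823,488
```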

1.4 Feed-forward network

$FFN(X) = ReLU(XW_1 + b_1)W_2 + b_2$
Hidden dimension $d_{ff} = 4 \cdot d_{model}$, typically four times $d_{model}$
$W_1 \in \Bbb R^{d_{model} \times d_{ff}}$, $W_2 \in \Bbb R^{d_{ff} \times d_{model}}$
$X \in \Bbb R^{N \times d_{model}}$
$XW_1 \in \Bbb R^{N \times d_{ff}}$, complexity: $O_1 = N \cdot d_{model} \cdot d_{ff}$
$ReLU(XW_1 + b_1)W_2 \in \Bbb R^{N \times d_{model}}$, complexity: $O_2 = N \cdot d_{ff} \cdot d_{model}$
Total complexity: $O_{FFN} = O_1 + O_2 = 2 \cdot N \cdot d_{ff} \cdot d_{model} = 8 \cdot N \cdot d_{model}^2$
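A corresponding count for the FFN, again under the assumed example sizes:

```python
N, d_model = 196, 768                    # assumed example sizes
d_ff = 4 * d_model

# (N, d_model) @ (d_model, d_ff), then (N, d_ff) @ (d_ff, d_model)
O_ffn = N * d_model * d_ff + N * d_ff * d_model   # = 8 * N * d_model**2
print(f"FFN MACs per layer: {O_ffn:,}")           # 924,844,032
```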

1.5 Total complexity

Per layer, the Transformer encoder and decoder together contain 3 multi-head attention blocks and 2 fully connected (feed-forward) layers.
L: number of Transformer layers
$O_{vit} = O_{embedding} + L(3 \cdot O_{multi-head} + 2 \cdot O_{FFN}) = N \cdot d_{patch} \cdot d_{model} + L(3 \cdot 3 \cdot N \cdot d_{model}^2 + 3 \cdot 2 \cdot N^2 \cdot d_{model} + 2 \cdot 8 \cdot N \cdot d_{model}^2)$
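Evaluating this formula as written, here is a short sketch for an assumed ViT-Base-like configuration (224×224×3 input, 16×16 patches, d_model = 768, L = 12; all example values, not taken from the text):

```python
H, W, C, P = 224, 224, 3, 16             # assumed image / patch sizes
d_model, L = 768, 12                     # assumed width and depth

N = (H * W) // (P * P)                   # 196
d_patch = P * P * C                      # 768

O_embedding  = N * d_patch * d_model
O_multi_head = 3 * N * d_model**2 + 2 * N**2 * d_model
O_ffn        = 8 * N * d_model**2

# Per-layer count of 3 attention blocks and 2 FFNs, following the formula above
O_vit = O_embedding + L * (3 * O_multi_head + 2 * O_ffn)
print(f"total MACs: {O_vit:,}")          # 36,921,507,840 (~36.9 GMACs)
```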

Figure: Transformer architecture diagram
