Transformer参数量和复杂度

在算法岗面试中经常会问到Transformer相关的基础知识。
首先需要清楚Transformer的参数量和复杂度分别在算什么。

定义:

  • 参数量:神经网络中有很多参数矩阵,这个矩阵大小的和就是参数量,静态的,摆在那就在那,定量的。
  • 复杂度:与输入的数据有关,动态的,跟计算公式有关,定性的。

Transformer的架构主要分成2部分:

  • Encoder:6层,每层包括 Multi-Head Self-Attention(MHSA)和FFN
  • Decoder:6层,每层包括Masked MHSA、Multi-Head Cross-Attention和FFN

其他部分包括Input Embedding 、Postion Encoding以及最后解码的Linear层。

其中,每个MHSA、Masked MHSA、Multi-Head Cross-Attention和FFN中有含有Add&Norm操作。
在这里插入图片描述

参数量计算:

InputEmbedding: 将vocab中的词映射到d维度,所以: vocab*d
Encoder+Decoder:

  • MHSA/Multi-Head Cross-Attention/Masked MHSA:Q K V O 四个矩阵 【没有bias】
    4 ∗ d ∗ d = 4 d 2 4*d*d=4d^2 4dd=4d2
  • FFN :第一个矩阵先增到4d,第二个矩阵减到d。所以参数量为
    d ∗ ( 4 d ) + 4 ∗ d + ( 4 ∗ d ) ∗ d + d = 8 d 2 + 5 d d*(4d) +4*d + (4*d)*d +d= 8d^2+5d d(4d)+4d+(4d)d+d=8d2+5d
  • layerNorm : 参数量就是 γ \gamma γ β \beta β ,所以是 2d

其实可以看到,MHSA的参数量仅有 FFNN 的一半。

T o t a l = I n p u t E m b e d d i n g + E n c o d e r + D e c o d e r = v o c a b ∗ d + 6 ∗ ( 4 d 2 + 2 d + 8 d 2 + 5 d + 2 d ) + 6 ( 4 d 2 + 2 d + 4 d 2 + 2 d + 8 d 2 + 5 d + 2 d ) = v o c a b ∗ d + 6 ∗ ( 12 d 2 + 9 d ) + 6 ∗ ( 16 d 2 + 11 d ) = v o c a b ∗ d + 168 d 2 + 120 d Total = InputEmbedding + Encoder + Decoder \\ = vocab*d + 6*( 4d^2 + 2d + 8d^2+5d + 2d) + 6(4d^2 + 2d + 4d^2 + 2d + 8d^2+5d + 2d)\\ = vocab*d + 6*(12d^2+9d ) + 6*(16d^2+11d) \\ = vocab*d + 168d^2 + 120d Total=InputEmbedding+Encoder+Decoder=vocabd+6(4d2+2d+8d2+5d+2d)+64d2+2d+4d2+2d+8d2+5d+2d=vocabd+6(12d2+9d)+6(16d2+11d)=vocabd+168d2+120d

复杂度分析

复杂度分成时间复杂度和空间复杂度。
神经网络中,最常见的就是线性映射,涉及到矩阵运算。这里用到一个矩阵相乘运算复杂度的前置知识:

矩阵M1=m * n 矩阵M2=n * k,得到矩阵M=m * k,所以时间复杂度为O(mnk),空间复杂度为O(m*k)

假设输入序列长度为N

InputEmbedding: 将vocab中的词映射到d维度,类似检索哈希表。
时间复杂度:O(N)
空间复杂度:O(N*d)
Encoder+Decoder:

  • MHSA/Multi-Head Cross-Attention/Masked MHSA:
    attention计算是复杂度的关键。
    时间复杂度: O ( N 2 d ) O(N^2d) O(N2d)
    self-attention中计算attention score那里就是
    Q ∗ K T Q*K^T QKT O ( N ∗ d ∗ N ) = O ( N 2 d ) O(N*d*N)=O(N^2d) O(NdN)=O(N2d)
    ● 让 s c o r e s = s o f t m a x ( Q ∗ K T / d ) scores=softmax(Q*K^T/\sqrt{d}) scores=softmax(QKT/d ) 那么 s c o r e s ∗ V scores*V scoresV O ( N 2 d ) O(N^2d) O(N2d)
    空间复杂度: O ( N 2 + N d ) O(N^2+Nd) O(N2+Nd)
    Q ∗ K T Q*K^T QKT O ( N 2 ) O(N^2) O(N2)
    s c o r e s ∗ V scores*V scoresV O ( N d ) O(Nd) O(Nd)
  • FFN :第一个矩阵先增到4d,第二个矩阵减到d。
    时间复杂度: O ( N ∗ d ∗ 4 d ) = O ( N d 2 ) O(N*d* 4d) =O(Nd^2) O(Nd4d)=O(Nd2)
    空间复杂度: O ( N d ) O(Nd) O(Nd)

所以:
时间复杂度: O ( N 2 d + N d 2 ) O(N^2d+Nd^2) O(N2d+Nd2)
空间复杂度: O ( N 2 + N d ) O(N^2+Nd) O(N2+Nd)

通常把d作为常量,所以:
时间复杂度: O ( N 2 ) O(N^2) O(N2)
空间复杂度: O ( N 2 ) O(N^2) O(N2)

参考:https://zhuanlan.zhihu.com/p/661804092

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值