1. Encoder
(1) Input embedding
Assume the input is a sequence of length vocab (this note uses vocab for the source sequence length), i.e. a vector of dimension vocab.
The input embedding turns it into a vocab * d_model matrix, i.e.
(vocab * 1) @ (1 * d_model) = vocab * d_model multiplications.
The result is then scaled by d_model ** 0.5, so the input embedding FLOPs = vocab * d_model * 2.
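For reference, here is a minimal sketch of the embedding step being counted, in the style of the Annotated Transformer implementation this note appears to follow (class and argument names are illustrative):

```python
import math
import torch.nn as nn

class Embeddings(nn.Module):
    """Token embedding followed by the d_model ** 0.5 scaling counted above."""
    def __init__(self, d_model, vocab_size):
        super().__init__()
        self.lut = nn.Embedding(vocab_size, d_model)  # lookup table
        self.d_model = d_model

    def forward(self, x):
        # x: (batch, seq_len) token ids -> (batch, seq_len, d_model)
        return self.lut(x) * math.sqrt(self.d_model)
```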
(2) Positional Encoding
The positional encoding is built from the sinusoidal formulas
PE(pos, 2i) = sin(pos / 10000^(2i / d_model)), PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model)),
which is counted here as 5 operations per element; adding the positional encoding to the input embedding costs one more, so 6 operations per element.
Therefore the positional encoding FLOPs = vocab * d_model * 6.
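A sketch of the standard sinusoidal encoding (the function name is an assumption); the per-element cost counted above comes from computing the argument, taking sin or cos, and the final add onto the embedding:

```python
import math
import torch

def positional_encoding(seq_len, d_model):
    # PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
    # PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))
    pe = torch.zeros(seq_len, d_model)
    position = torch.arange(0, seq_len, dtype=torch.float32).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float32)
                         * (-math.log(10000.0) / d_model))
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe  # added elementwise to the embedding output
```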
(3) Sublayer 1
LayerNorm:
The input to LayerNorm is (vocab * d_model).
It first computes the mean of each token vector:
vocab * (d_model - 1) additions + vocab multiplications = vocab * d_model operations.
Then the variance: (d_model + d_model + d_model - 1 + 2) * vocab = vocab * (3d_model + 1) operations.
So in total: LayerNorm FLOPs = vocab * (4d_model + 1) operations.
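For reference, a LayerNorm sketch in the same style (names are illustrative); the mean and the variance are the two passes counted above, while the elementwise normalize/scale/shift at the end is not part of this note's count:

```python
import torch
import torch.nn as nn

class LayerNorm(nn.Module):
    """Row-wise mean/std normalization with learned scale (a_2) and shift (b_2)."""
    def __init__(self, features, eps=1e-6):
        super().__init__()
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)   # ~d_model ops per row (counted above)
        std = x.std(-1, keepdim=True)     # ~3*d_model + 1 ops per row (counted above)
        return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
```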
Self-attention:
1. Projections
The inputs used to form Q, K and V all have dimension vocab * d_model;
W^Q and W^K have dimension d_model * d_k,
and W^V has dimension d_model * d_v.
The paper sets d_k = d_v, so d_k is used throughout.
So the projection FLOPs = (2d_model - 1) * vocab * 3d_k.
2. The attention function
FLOPs = (2d_k - 1) * vocab * vocab + vocab * vocab + 2 * vocab * vocab + vocab * (vocab - 1) + (2 * vocab - 1) * vocab * d_v = vocab * vocab * (4d_k + 3) - vocab - vocab * d_k
(the terms are Q @ K^T, the scaling by sqrt(d_k), the softmax, and the product with V).
3. Since there are h heads, the total FLOPs = h * ((2d_model - 1) * vocab * 3d_k + vocab * vocab * (4d_k + 3) - vocab - vocab * d_k)
4. Finally there is an output linear layer: FLOPs = (2d_model - 1) * vocab * d_model
Putting this together, sublayer 1 FLOPs = h * ((2d_model - 1) * vocab * 3d_k + vocab * vocab * (4d_k + 3) - vocab - vocab * d_k) + (2d_model - 1) * vocab * d_model + vocab * (4d_model + 1)
= 4 * vocab * d_model + h * (6 * d_model * vocab * d_k - 4 * vocab * d_k + 4 * d_k * vocab * vocab + 3 * vocab * vocab - vocab) + 2 * vocab * d_model * d_model - vocab * d_model + vocab
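The term-by-term and expanded forms above can be checked against each other numerically; a small sketch, with the formulas wrapped in hypothetical helper functions:

```python
def encoder_sublayer1_flops(vocab, d_model, d_k, h):
    """Term-by-term: Q/K/V projections + attention per head, output linear, LayerNorm."""
    per_head = ((2 * d_model - 1) * vocab * 3 * d_k
                + vocab * vocab * (4 * d_k + 3) - vocab - vocab * d_k)
    return h * per_head + (2 * d_model - 1) * vocab * d_model + vocab * (4 * d_model + 1)

def encoder_sublayer1_flops_expanded(vocab, d_model, d_k, h):
    """Expanded form, as written above."""
    return (4 * vocab * d_model
            + h * (6 * d_model * vocab * d_k - 4 * vocab * d_k
                   + 4 * d_k * vocab * vocab + 3 * vocab * vocab - vocab)
            + 2 * vocab * d_model * d_model - vocab * d_model + vocab)

# The two expressions agree, e.g. with the values used later (vocab=10, d_model=512, d_k=64, h=8).
assert encoder_sublayer1_flops(10, 512, 64, 8) == encoder_sublayer1_flops_expanded(10, 512, 64, 8)
```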
(4) Sublayer 2
Residual addition FLOPs = vocab * d_model
From the previous section, the LayerNorm FLOPs = vocab * (4d_model + 1)
From the code of the feed-forward layer (a sketch is given after this subsection) we get:
FLOPs = (2d_model - 1) * vocab * d_ff + vocab * d_ff + (2 * d_ff - 1) * vocab * d_model
= vocab * d_model * (4d_ff - 1)
So: sublayer 2 FLOPs = vocab * d_model + vocab * (4d_model + 1) + vocab * d_model * (4d_ff - 1)
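The feed-forward sketch referred to above, in the style of the Annotated Transformer (names are illustrative); the two linear layers and the ReLU are where the per-position terms in the count come from (the bias adds are not counted):

```python
import torch.nn as nn

class PositionwiseFeedForward(nn.Module):
    """Linear d_model -> d_ff, ReLU, then Linear d_ff -> d_model, applied per position."""
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Linear(d_model, d_ff)   # ~(2*d_model - 1) * d_ff FLOPs per position
        self.w_2 = nn.Linear(d_ff, d_model)   # ~(2*d_ff - 1) * d_model FLOPs per position
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.w_2(self.dropout(self.w_1(x).relu()))  # ReLU: d_ff FLOPs per position
```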
Putting it all together, the EncoderLayer FLOPs = vocab * d_model * 2 + vocab * d_model * 6 + h * ((2d_model - 1) * vocab * 3d_k + vocab * vocab * (4d_k + 3) - vocab - vocab * d_k) + (2d_model - 1) * vocab * d_model + vocab * (4d_model + 1) + vocab * d_model + vocab * (4d_model + 1) + vocab * d_model * (4d_ff - 1)
= 16 * vocab * d_model + h * (6 * d_model * vocab * d_k - 4 * vocab * d_k + 4 * d_k * vocab * vocab + 3 * vocab * vocab - vocab) + 2 * vocab * d_model * d_model - vocab * d_model + 2 * vocab + 4 * vocab * d_model * d_ff
So the FLOPs of the 6-layer Encoder stack = 6 * (16 * vocab * d_model + h * (6 * d_model * vocab * d_k - 4 * vocab * d_k + 4 * d_k * vocab * vocab + 3 * vocab * vocab - vocab) + 2 * vocab * d_model * d_model - vocab * d_model + 2 * vocab + 4 * vocab * d_model * d_ff)
(5) LayerNorm
There is one final LayerNorm at the end, with FLOPs = vocab * (4d_model + 1).
So the total Encoder FLOPs = 6 * (16 * vocab * d_model + h * (6 * d_model * vocab * d_k - 4 * vocab * d_k + 4 * d_k * vocab * vocab + 3 * vocab * vocab - vocab) + 2 * vocab * d_model * d_model - vocab * d_model + 2 * vocab + 4 * vocab * d_model * d_ff) + 4 * vocab * d_model + vocab
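The encoder total can be written as a small helper (a sketch; the function name and n_layers default are assumptions):

```python
def encoder_flops(vocab, d_model, d_ff, d_k, h, n_layers=6):
    """n_layers encoder layers (as counted above, embedding/PE included per layer) + final LayerNorm."""
    per_layer = (16 * vocab * d_model
                 + h * (6 * d_model * vocab * d_k - 4 * vocab * d_k
                        + 4 * d_k * vocab * vocab + 3 * vocab * vocab - vocab)
                 + 2 * vocab * d_model * d_model - vocab * d_model
                 + 2 * vocab + 4 * vocab * d_model * d_ff)
    final_layernorm = vocab * (4 * d_model + 1)
    return n_layers * per_layer + final_layernorm
```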
2. Decoder
The FLOPs calculation for the decoder is essentially the same as for the encoder, so only the results are given.
Tip: the multi-head attention in the second sublayer takes the memory coming from the Encoder as an additional input.
(1) Output embedding
Assume the output is a sequence of length d_out; the embedding turns it into a d_out * d_model matrix.
Output embedding FLOPs = d_out * d_model * 2
Positional encoding FLOPs = d_out * d_model * 6
Addition FLOPs = d_out * d_model
Total FLOPs = 9 * d_out * d_model
(2) Sublayer 1
FLOPs = h * ((2d_model - 1) * d_out * 3d_k + d_out * d_out * (4d_k + 3) - d_out - d_out * d_k) + (2d_model - 1) * d_out * d_model + d_out * (4d_model + 1)
(3) Sublayer 2
Residual addition FLOPs = d_out * d_model
LayerNorm FLOPs = vocab * (4d_model + 1)
Multi-head attention (over the encoder memory):
1. Projections
The Q input (the decoder stream) has dimension d_out * d_model,
while the K and V inputs (the encoder memory) both have dimension vocab * d_model.
W^Q and W^K have dimension d_model * d_k, and W^V has dimension d_model * d_v.
The paper sets d_k = d_v, so d_k is used throughout.
So the projection FLOPs = (2d_model - 1) * vocab * 2d_k + (2d_model - 1) * d_out * d_k
2. The attention function
FLOPs = (2d_k - 1) * d_out * vocab + d_out * vocab + 2 * d_out * vocab + d_out * (vocab - 1) + (2 * vocab - 1) * d_out * d_k
3. Since there are h heads, the total FLOPs = h * ((2d_model - 1) * vocab * 2d_k + (2d_model - 1) * d_out * d_k + (2d_k - 1) * d_out * vocab + d_out * vocab + 2 * d_out * vocab + d_out * (vocab - 1) + (2 * vocab - 1) * d_out * d_k)
4. Finally there is an output linear layer: FLOPs = (2d_model - 1) * d_out * d_model
So sublayer 2 FLOPs = d_out * d_model + vocab * (4d_model + 1) + h * ((2d_model - 1) * vocab * 2d_k + (2d_model - 1) * d_out * d_k + (2d_k - 1) * d_out * vocab + d_out * vocab + 2 * d_out * vocab + d_out * (vocab - 1) + (2 * vocab - 1) * d_out * d_k) + (2d_model - 1) * d_out * d_model
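As a sketch, the same sublayer-2 expression wrapped in a hypothetical helper, with the individual terms labeled:

```python
def decoder_cross_attn_sublayer_flops(vocab, d_out, d_model, d_k, h):
    """Residual add + LayerNorm + h cross-attention heads + output linear, as counted above."""
    per_head = ((2 * d_model - 1) * vocab * 2 * d_k        # K and V projections over the memory
                + (2 * d_model - 1) * d_out * d_k          # Q projection over the decoder stream
                + (2 * d_k - 1) * d_out * vocab            # Q @ K^T
                + d_out * vocab                            # scaling by sqrt(d_k)
                + 2 * d_out * vocab + d_out * (vocab - 1)  # softmax
                + (2 * vocab - 1) * d_out * d_k)           # attention weights @ V
    return (d_out * d_model                                # residual addition
            + vocab * (4 * d_model + 1)                    # LayerNorm term as used above
            + h * per_head
            + (2 * d_model - 1) * d_out * d_model)         # output linear layer
```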
(4) Sublayer 3
FLOPs = d_out * d_model + d_out * (4d_model + 1) + d_out * d_model * (4d_ff - 1)
(5) LayerNorm
There is one final LayerNorm at the end, with FLOPs = d_out * (4d_model + 1).
The final FLOPs of the right half of the model = final LayerNorm + 6 * DecoderLayer =
d_out * (4d_model + 1) + 6 * (9 * d_out * d_model + d_out * d_model + d_out * (4d_model + 1) + d_out * d_model * (4d_ff - 1) + d_out * d_model + vocab * (4d_model + 1) + h * ((2d_model - 1) * vocab * 2d_k + (2d_model - 1) * d_out * d_k + (2d_k - 1) * d_out * vocab + d_out * vocab + 2 * d_out * vocab + d_out * (vocab - 1) + (2 * vocab - 1) * d_out * d_k) + (2d_model - 1) * d_out * d_model + h * ((2d_model - 1) * d_out * 3d_k + d_out * d_out * (4d_k + 3) - d_out - d_out * d_k) + (2d_model - 1) * d_out * d_model + d_out * (4d_model + 1))
(the leading 9 * d_out * d_model is the output embedding / positional encoding term from (1), carried along per layer just as in the encoder)
= d_out * (4 * d_model + 1) + 6 * (16 * d_out * d_model + 2 * d_out * d_model * d_model + 4 * d_model * d_out * d_ff + 4 * d_model * vocab + vocab + 2 * d_model * d_model * d_out + h * (8 * d_model * d_out * d_k - 6 * d_k + 4 * d_k * d_out * d_out + 3 * d_out * d_out - 2 * d_out + 4 * d_model * vocab * d_k - 2 * d_k * vocab + 4 * d_k * d_out * vocab + 3 * vocab * d_out))
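And the decoder total as a helper mirroring the expanded expression above (function name and n_layers default are assumptions):

```python
def decoder_flops(vocab, d_out, d_model, d_ff, d_k, h, n_layers=6):
    """n_layers decoder layers (expanded per-layer expression above) + final LayerNorm."""
    per_layer = (16 * d_out * d_model
                 + 2 * d_out * d_model * d_model
                 + 4 * d_model * d_out * d_ff
                 + 4 * d_model * vocab + vocab
                 + 2 * d_model * d_model * d_out
                 + h * (8 * d_model * d_out * d_k - 6 * d_k
                        + 4 * d_k * d_out * d_out + 3 * d_out * d_out - 2 * d_out
                        + 4 * d_model * vocab * d_k - 2 * d_k * vocab
                        + 4 * d_k * d_out * vocab + 3 * vocab * d_out))
    final_layernorm = d_out * (4 * d_model + 1)
    return n_layers * per_layer + final_layernorm
```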
3. Generator
Linear layer: (d_out * d_model) @ (d_model * tgt_vocab), giving a d_out * tgt_vocab matrix.
FLOPs = (2d_model - 1) * d_out * tgt_vocab
Softmax FLOPs = 2 * d_out * tgt_vocab + d_out * (tgt_vocab - 1)
Generator FLOPs = 2 * tgt_vocab * d_out - d_out + 2 * d_model * tgt_vocab * d_out
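A small helper for the generator count (hypothetical name), combining the linear projection and the softmax terms above:

```python
def generator_flops(d_out, d_model, tgt_vocab):
    linear = (2 * d_model - 1) * d_out * tgt_vocab               # projection to tgt_vocab logits
    softmax = 2 * d_out * tgt_vocab + d_out * (tgt_vocab - 1)    # exp, sum, and normalization
    return linear + softmax  # equals 2*tgt_vocab*d_out - d_out + 2*d_model*tgt_vocab*d_out
```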
4. Summary
Using the parameter values given in the paper:
vocab = 10  # vocab can be adjusted
d_model = 512
d_ff = 2048
h = 8
d_k = 64
d_out = 1
tgt_vocab = 10000
encoder = 4 * vocab * d_model + vocab + 6 * (16 * vocab * d_model + h * (6 * d_model * vocab * d_k - 4 * vocab * d_k + 4 * d_k * vocab * vocab + 3 * vocab * vocab - vocab) + 2 * vocab * d_model * d_model - vocab * d_model + 2 * vocab + 4 * vocab * d_model * d_ff)
decoder = d_out * (4 * d_model + 1) + 6 * (16 * d_out * d_model + 2 * d_out * d_model * d_model + 4 * d_model * d_out * d_ff + 4 * d_model * vocab + vocab + 2 * d_model * d_model * d_out + h * (8 * d_model * d_out * d_k - 6 * d_k + 4 * d_k * d_out * d_out + 3 * d_out * d_out - 2 * d_out + 4 * d_model * vocab * d_k - 2 * d_k * vocab + 4 * d_k * d_out * vocab + 3 * vocab * d_out))
generator = 2 * tgt_vocab * d_out - d_out + 2 * d_model * tgt_vocab * d_out
result = (encoder + decoder + generator) / 1e9
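With these values, result works out to roughly 0.497 GFLOPs, which can be read against the reference measurements in the table below.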
Model                     FLOPs (GFLOPs)    Latency (ms)
VGG16                     15.483862016      1791.4
resnet18                  1.819066368       190.82
resnet34                  3.671263232       408.05
resnet50                  4.111514624       445.56
resnet152                 11.558837248      1362.19
GoogleNet                 1.504879712       180.35
vit_tiny_patch16_384      >49.35            2463.74
vit_tiny_patch16_224      >16.85            712.19