一、运算量计算介绍
1、 计算内容:训练过程中的浮点运算次数(FLOP)
- 浮点数的加减乘除运算
- 浮点数的指数函数运算
- 对数函数运算
- 三角函数运算
2、 计算涉及流程
- 多头注意力计算
- 线性变换计算
- 归一化、输出映射和旋转位置编码等运算量较少,输入编码层无需计算,一般会省略这几个部分
二、字母映射
- 模型参数量:P
- 批处理大小:B
- 输出序列长度:T
- 训练词元总数:C = BT
- 多头注意力包含N个头,每个头维度为D,中间状态维度H,满足H=ND
三、多头注意力运算量
1、运算量组成
- 单层前向传播:
- 注意力计算的矩阵乘法:2BT²ND
- 标准化操作:BT²N
- softmax操作需要进行指数、加和、归一化操作:3BT²N
- 结果的矩阵乘法:2BT²ND
- 单层后向传播:是单层前向传播的2倍
- 所有层多头注意力运算量:(单层前向传播+单层后向传播)乘上L层
2、计算过程
- 单层前向传播:4BT²ND + 4BT²N
- 单层多头注意力:3*(4BT²ND + 4BT²N)
- 所有层多头注意力:L3(4BT²ND + 4BT²N)
- 由于C=BT,所以最终运算量:12CTL(H + N)
四、线性变换运算量
1、运算量涉及部分
- 注意力层中四个映射
- 前馈网络层的变换
- 输出层映射
2、运算量组成
- 前向传播:2BTHH′
- 反向传播:4BTHH′
- 由于C=BT,综合运算量:6C*线性变换参数量
- 若采用激活重计算,反向传播时需要额外进行一次前向传播,总运算量:8C*线性变换参数量
五、总运算量
1、多头注意力计算量约为线性变换的T/6H,大模型训练场景下,序列长度小于等于中间状态维度H,因此多头注意力计算量最多为线性变换计算量的1/6,多头注意力运算量影响较小
2、根据参数量计算公式(参考之前文章大模型参数量计算),线性变换运算量占总运算量的95%以上,因此使用参数量P在C个词元上进行预训练的总运算量≈6CP,如果使用了激活冲计算,运算量≈8CP