模型概况:
BERT-Base: L = 12 , H = 768 , A = 12 L = 12, H = 768, A = 12 L=12,H=768,A=12
参数计算:
PART 01:input embedding
- Token Embedding: 30522 × 768 30522 \times 768 30522×768
- Position Embedding: (max_length) 512 × 768 512 \times 768 512×768
- Segment Embedding: 2 × 768 2 \times 768 2×768
- 总参数量 ( 30522 + 512 + 2 ) × 768 = 23 , 835 , 648 (30522 + 512 + 2) \times 768 = 23,835,648 (30522+512+2)×768=23,835,648
PART 02:Multi-Head Attention
-
基本信息
- 12个head
- 生成 Q K V 3个向量
-
单个 head 的参数量
-
768
×
768
/
12
×
3
768 \times 768/12 \times 3
768×768/12×3
-
768
×
768
/
12
×
3
768 \times 768/12 \times 3
768×768/12×3
-
多头拼接的参数
- 12 × 768 / 12 × 768 12 \times 768/12 \times 768 12×768/12×768
-
总参数量 ( 768 × 768 / 12 × 3 ) × 12 + 12 × 768 / 12 × 768 = 2 , 359 , 296 (768 \times 768/12 \times 3)\times {\color{red}12} + 12 \times 768/12 \times 768 = 2,359,296 (768×768/12×3)×12+12×768/12×768=2,359,296
PART 03:Add & Norm (第一次)
- 基本信息
-
针对多头注意力的输出,这里使用的是 L a y e r N o r m ( x + S u b l a y e r ( x ) ) LayerNorm(x + Sublayer(x)) LayerNorm(x+Sublayer(x))
进行层标准化需要计算同一层隐层单元中的如上两个参数。
-
- 总参数量: 768 × 2 = 1 , 536 768 \times 2 = 1,536 768×2=1,536
PART 04:Feed Forward
- 公式 F F N ( x ) = m a x ( 0 , x W 1 + b 1 ) W 2 + b 2 FFN(x)=max(0, xW_{1}+b_{1})W_{2}+b_{2} FFN(x)=max(0,xW1+b1)W2+b2
- 论文指明,feed-forward/filter size 设置为 4H(即 4 × 768 = 3072 4 \times 768 = 3072 4×768=3072)
- 第一层参数: 768 × 3072 + 3072 768 \times 3072 + 3072 768×3072+3072
- 第二层参数: 3072 × 768 + 768 3072 \times 768 + 768 3072×768+768
- 总参数量: ( 768 × 3072 + 3072 ) + ( 3072 × 768 + 768 ) = 4 , 722 , 432 (768 \times 3072 + 3072)+ (3072 \times 768 + 768)= 4,722,432 (768×3072+3072)+(3072×768+768)=4,722,432
PART 05:Add & Norm (第二次)
- 与第一次相同,参数量为 768 × 2 = 1 , 536 768 \times 2 = 1,536 768×2=1,536
计算结果:
- 由于 PART 02-05 在 BERT-Base 模型中共有 12 个 Encoder
- 因此,参数总量为:
23 , 835 , 648 + ( 2 , 359 , 296 + 1 , 536 + 4 , 722 , 432 + 1 , 536 ) × 12 = 108 , 853 , 248 23,835,648 + (2,359,296 + 1,536 + 4,722,432 + 1,536) \times 12 = 108,853,248 23,835,648+(2,359,296+1,536+4,722,432+1,536)×12=108,853,248
参考论文
Transformer: Attention is all you need
Layer Normalization: Layer Normalization
BERT: BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding