bert主要的实现是基于transformer的encoder部分,参数维度不同的地方是1)输入多了一项segment embedding,2)中间维度基本是768,以及多头注意力以及前向网络重复了12次。
在统计bert参数的时候,一共要考虑5部分。
1)第一部分:输入层包含三项
token embedding | 词表大小*768 |
position emb | max_len(512*768) |
segment emb | 两个取值0,1(2*768) |
2)第二部分:多头注意力
12个头,其中每个头包括Q\K\V三组参数
768(原始维度)*768/12(每个头的q\k\v的维度)*3*12(头的个数)
然后concat起来所有输出,再变换一下 768*768+768
3)第三部分:Add and Norm
add不需要参数,norm有两个参数需要学习:shift和scale(2*768)
4)第四部分:前向网络
两层全连接网络(W,b):第一层是768*3072(4H)+3072
第二层是3072*768+768
5)第五部分:Add and Norm
同第三部分:2*768
总参数: 第一部分+12*(第二+第三+第四+第五部分)