BERT参数计算,RBT3模型结构

B E R T B a s e BERT_{Base} BERTBase的参数来源

  

  论文中的 B E R T b a s e BERT{base} BERTbase的模型由多层双向的Transformer编码器组成,由12层组成,768隐藏单元,12个head,总参数量110M,约1.15亿参数量。
  

在这里插入图片描述
我们在这里来探究一下 B E R T B a s e BERT_{Base} BERTBase的参数来源。

第一部分(计算词向量时不同编码的参数量之和):

  可以看到分别经过了word_embeddings,position_embeddings,token_type_embeddings 编码,由此可以计算在embeddings的参数量的个数。

word_embeddings参数:词汇量的大小为30522,每个词都是768维,共30522*768。

position_embeddings参数:文本输入的最大长度max_position_embeddings=512,也就是可以输入512个词,每个词都是768维,共512*768。

token_type_embeddings参数:(2个句子,0和1区分上下句子) 共2*768。

embedding总参数 = (30522+512+2)*768 = 23,835,648 = 22.7MB


第二部分(计算 W Q W^Q WQ W K W^K WK W V W^V WV W O W^O WO部分的参数个数):

  token经过编码之后X(m,768), W Q W^Q WQ(768,64), W K W^K WK(768,64), W V W^V WV W O W^O WO

  q=X W Q W^Q WQ(m,64),k=X W K W^K WK(m,64)------>q* k T k^T kT(m,m)------>z = softmax(q* k T k^T kT/8)*v (m,64)。

解释:m为输入的单词的数量,768位每个词的维度,64是因为分成了12个head(768/12)。

最后还有一个 W O W^O WO。易得, W O W^O WO(768,768)。最终Z=(m, 768),与输入保持一致。

当然图中的m=2,每个词的维度为4。被分为了8个head。

在这里插入图片描述

故,12multi-heads的参数为:(768* 64* 3 ) *12+768 * 768= 2,359,296

故,12multi-heads的参数为:2,359,296 * 12 = 28,311,552 = 27MB


第三部分:全连接层(FeedForward)参数

  前馈网络feed forword的参数主要由2个全连接层组成,论文中全连接层的公式为:

  FFN(x) = max(0, xW1 + b1)W2 + b2。Bert沿用了惯用的全连接层大小设置,即4 * dmodle = 3072,其中用到了两个参数W1,W2,其中W1(768,3072),W2(3072,768),b1(768,1),b2(3072,1)。

12层的全连接层参数为:12*( 2 * 768 * 3072)+768+3072≈ 56,623,104 ≈ 54MB


第四部分:LayerNorm层参数

LN层有gamma和beta等2个参数。在三个地方用到了layernorm层:embedding层后、multi-head attention后、feed forward后。

12层LN层参数为:768 * 2 + (768 * 2) * 12 + (768 * 2) * 12 = 38,400 = 37.5KB=0.037MB


第五部分:基于预训练过程的Masked LM和NSP(next sentence prediction)的参数:

Masked LM:768*2

NSP:768*768

,预训练过程的参数为:768*2+768 * 768≈0.56MB

所以参数的数量一共有22.7MB+ 27MB+54MB+ 0.037MB+ 0.56MB =104.297MB。

离110M可能还有一些距离,可能有些地方的参数没有加完全,但已经计算量了主体参数量。


RBT3模型的各层的含义:

import torch
from transformers import BertTokenizer, BertConfig, BertForMaskedLM, BertForNextSentencePrediction
from transformers import BertModel

model_name = 'rbt3 (2)'
MODEL_PATH = 'E:/rbt3 (2)/'
# a. 通过词典导入分词器
tokenizer = BertTokenizer.from_pretrained(model_name)
# b. 导入配置文件
model_config = BertConfig.from_pretrained(model_name)
# 修改配置
model_config.output_hidden_states = True
model_config.output_attentions = True
# 通过配置和路径导入模型
bert_model = BertModel.from_pretrained(MODEL_PATH, config=model_config)
print(bert_model)

运行结果(rbt3模型的结构):

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(21128, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-2): 3 x BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (intermediate): BertIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
          (intermediate_act_fn): GELUActivation()
        )
        (output): BertOutput(
          (dense): Linear(in_features=3072, out_features=768, bias=True)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
  )
  (pooler): BertPooler(
    (dense): Linear(in_features=768, out_features=768, bias=True)
    (activation): Tanh()
  )
)



  
在这里插入图片描述

可以根据这张图片来理解:

其中(embeddings)部分(与上图一一对应):

 (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(21128, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    ##进行标准化
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )

rbt3encoder层,一共有3个BertLayer

  (encoder): BertEncoder(
    (layer): ModuleList(
        #一共三层
      (0-2): 3 x BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        #Bert模型中的一个中间层,
        (intermediate): BertIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
          (intermediate_act_fn): GELUActivation()
        )
        (output): BertOutput(
          (dense): Linear(in_features=3072, out_features=768, bias=True)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
  )

其中BertAttention包括BertSelfAttentionBertSelfOutput图和模型的输出一一对应)。

BertSelfAttention主要是自注意力机制,包括了三个矩阵的线性变换(query)(key)(value)。

BertSelfOutput:还通过残差连接(residual connection)和层归一化(layer normalization)技术,将BertSelfAttention层的输出归一化,从而保留了输入的信息,并且有助于缓解梯度消失问题。

BertIntermediateBertIntermediateBert模型中的一个中间层,它负责将输入通过全连接层进行映射,并应用激活函数(通常是GELU激活函数)。这个非线性映射引入了更丰富的特征表示能力,使得Bert模型能够学习到更复杂的语义信息。BertIntermediate充当了encoder层中的非线性变换,帮助模型更好地捕捉输入序列中的上下文关系和语义信息。

BertOutputBertOutputBert模型中的一个输出层,它接收BertIntermediate层的输出,并通过全连接层进行线性映射,将特征维度映射回原始维度。此外,BertOutput还通过残差连接(residual connection)和层归一化(layer normalization)技术,将BertIntermediate层的输出与输入进行相加和归一化,从而保留了输入的信息,并且有助于缓解梯度消失问题。这种残差连接和层归一化技术有助于提高模型的训练稳定性,使得Bert模型更容易训练并且能够更好地捕捉输入序列中的语义信息。

BertPooler层:

 (pooler): BertPooler(
    (dense): Linear(in_features=768, out_features=768, bias=True)
    (activation): Tanh()
  )

其中encoder就是将BERT的所有token经过12个TransformerEncoder进行embeddingpooler就是将[CLS]这个token再过一下全连接层+Tanh激活函数,作为该句子的特征向量

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值