BERT代码结构


基于 Transformers 版本 4.4.2(2021年3月19日发布)项目中pytorch版的BERT相关代码进行分析。

BertTokenizer

BertTokenizer是基于BasicTokenizer和WordPieceTokenizer的分词器。

bt = BertTokenizer.from_pretrained('bert-base-uncased')
bt('I like natural language progressing!')

BasicTokenizer

BasicTokenizer负责处理的第一步——按标点、空格等分割句子,并处理是否统一小写,以及清理非法字符。

  • 对于中文字符,通过预处理(加空格)来按字分割;
  • 同时可以通过never_split指定对某些词不进行分割;

WordPieceTokenizer

WordPieceTokenizer在词的基础上,进一步将词分解为子词(subword)。subword 介于 char 和 word 之间,既在一定程度保留了词的含义,又能够照顾到英文中单复数、时态导致的词表爆炸和未登录词的 OOV(Out-Of-Vocabulary)问题,将词根与时态词缀等分割出来,从而减小词表,也降低了训练难度;
BertTokenizer 有以下常用方法:

  • from_pretrained:从包含词表文件(vocab.txt)的目录中初始化一个分词器;
  • tokenize:将文本(词或者句子)分解为子词列表;
  • convert_tokens_to_ids:将子词列表转化为子词对应下标的列表;
  • convert_ids_to_tokens :与上一个相反;
  • convert_tokens_to_string:将 subword 列表按“##”拼接回词或者句子;
  • encode:对于单个句子输入,分解词并加入特殊词形成“[CLS], x, [SEP]”的结构并转换为词表对应下标的列表;对于两个句子输入(多个句子只取前两个),分解词并加入特殊词形成“[CLS], x1, [SEP], x2, [SEP]”的结构并转换为下标列表;
  • decode:可以将 encode 方法的输出变为完整句子。

BertModel

BertModel主要为transformer encoder结构,包含三个部分:

  • BertEmbeddings类
  • BertEncoder类
  • BertPooler类(这一部分是可选的)

BertModel前向传播过程中各个参数的含义以及返回值:

def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    )
  • input_ids:经过 tokenizer 分词后的 subword 对应的下标列表;
  • attention_mask:在 self-attention 过程中,这一块 mask 用于标记 subword所处句子和padding的区别,将padding部分填充为0;
  • token_type_ids:标记 subword 当前所处句子(第一句/第二句/ padding);
  • position_ids:标记当前词所在句子的位置下标;
  • head_mask:用于将某些层的某些注意力计算无效化;
  • inputs_embeds:如果提供了,那就不需要input_ids,跨过 embedding lookup 过程直接作为 Embedding 进入 Encoder 计算;
  • encoder_hidden_states:这一部分在BertModel配置为decoder时起作用,将执行 cross-attention 而不是 self-attention;
  • encoder_attention_mask:同上,在cross-attention中用于标记 encoder端输入的padding;
  • past_key_values:把预先计算好的 K-V 乘积传入,以降低 cross-attention 的开销(因为原本这部分是重复计算);
  • use_cache:将保存上一个参数并传回,加速 decoding;
  • output_attentions:是否返回中间每层的 attention 输出;
  • output_hidden_states:是否返回中间每层的输出;
  • return_dict:是否按键值对的形式(ModelOutput 类,也可以当作 tuple 用)返回输出,默认为真。

BertModel 的其他方法:

  • get_input_embeddings:提取embedding中的 word_embeddings 即词向量部分;
  • set_input_embeddings:为embedding中的 word_embeddings赋值;
  • _prune_heads:提供了将注意力头剪枝的函数,输入为{layer_num: list of heads to prune in this layer}的字典,可以将指定层的某些注意力头剪枝。

注:剪枝是一个复杂的操作,需要将保留的注意力头部分的 Wq、Kq、Vq 和拼接后全连接部分的权重拷贝到一个新的较小的权重矩阵(!先禁止 grad 再拷贝),并实时记录被剪掉的头以防下标出错。具体参考BertAttention部分的prune_heads方法。

BertEmbeddings类

在这里插入图片描述From:BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding

求和
求和
求和
input
word_embeddings
token_type_embeddings
position_embeddings
embeddings
LayerNorm+Dropout
output
  • word_embeddings:即为图中TokenEmbedding,为subword对应的嵌入。
  • token_type_embeddings:即为图中SegmentEmbedding,用于表示当前词所在的句子,辅助区别句子与padding、句子对间的差异。
  • position_embeddings:句子中每个词的位置嵌入,用于区别词的顺序。和 transformer论文中的设计不同,这里的位置编码是训练出来的,而不是通过Sinusoidal函数计算得到的固定嵌入。这种方法有更好的拓展性可以直接迁移到更长的句子中。
  • layer normalization可以使得前向传播的输入分布变得稳定,同时使得后向的梯度更加稳定。

BertEncoder类

由多个BertLayer组成。

BertLayer

在每层layer会利用gradient checkpointing(梯度检查点),通过减少保存的计算图节点压缩模型占用空间,降低训练的显存占用。调用函数为:torch.utils.checkpoint.checkpoint。

BertAttention
  • self:多头注意力的实现

  • output:实现 attention 后的多头全连接+dropout+residual+LayerNorm 一系列操作

  • pruned_heads:剪枝操作

class BertAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.self = BertSelfAttention(config)
        self.output = BertSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        self_outputs = self.self(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
        )
        attention_output = self.output(self_outputs[0], hidden_states)
        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs
BertSelfAttention
class BertSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

        self.is_decoder = config.is_decoder

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        mixed_query_layer = self.query(hidden_states)

        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention and past_key_value is not None:
            # reuse k,v, cross_attentions
            key_layer = past_key_value[0]
            value_layer = past_key_value[1]
            attention_mask = encoder_attention_mask
        elif is_cross_attention:
            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        query_layer = self.transpose_for_scores(mixed_query_layer)

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_layer, value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            seq_length = hidden_states.size()[1]
            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r
            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        if self.is_decoder:
            outputs = outputs + (past_key_value,)
        return outputs

注意:

  • hidden_size 和 all_head_size 在一开始是一样的。剪枝操作后以后 all_head_size会减小;
  • hidden_size 必须是 num_attention_heads 的整数倍,以 bert-base 为例,每个 attention 包含 12 个 head,hidden_size 是 768,所以每个 head 大小即 attention_head_size=768/12=64;
  • 对于不同的positional_embedding_type,有三种操作:
    • absolute:默认值,这部分就不用处理;
    • relative_key:对 key_layer 作处理,将其与这里的positional_embedding和 key 矩阵相乘作为 key 相关的位置编码;
    • relative_key_query:对 key 和 value 都进行相乘以作为位置编码。
BertSelfOutput

这里是先 Dropout,在进行残差连接,最后再进行 LayerNorm。残差连接的目的就是降低网络层数过深带来的训练难度,对原始输入更加敏感。

class BertSelfOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states
BertIntermediate
  • 全连接:做一个扩展,以 bert-base 为例,扩展维度为 3072,是原始维度 768 的 4 倍之多;
  • 激活函数默认实现为 gelu(Gaussian Error Linerar Units(GELUS),它是无法直接计算的,可以用一个包含tanh的表达式进行近似。
class BertIntermediate(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states
BertOutput

和BertSelfOutput一样,进行全连接、残差连接、dropout、LayerNorm。

BertPooler类

取出句子的第一个token(即[CLS]对应的向量),然后过一个全连接层和一个激活函数后输出

class BertPooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output  

参考

datawhale文档

  • 4
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值