A PyTorch implementation of BERT

This implementation is not complete; it only covers BERT's basic building blocks. The code reflects my own understanding and may contain mistakes (parts I have not yet understood are simply left out). For the actual construction of downstream-task models such as NSP and MLM, please refer to huggingface's transformers.

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import inspect
# BertForMaskedLM - BertModel,BertOnlyMLMHead

# BertModel - BertEmbedding, BertEncoder(no BertPooler)
# BertOnlyMLMHead - BertLMPredictionHead

# BertEncoder - BertLayer
# BertLayer - BertAttention, BertIntermediate, BertOutput

class Config(object):
    def __init__(self):
        self.vocab_size = 1000
        self.hidden_size = 768
        self.max_position_embeddings = 32  # maximum input sequence length
        self.hidden_dropout_prob = 0.5
        self.layer_norm_eps = 1e-6  # epsilon that keeps the LayerNorm denominator away from zero
        self.type_vocab_size = 1  # number of sentence (segment) types, e.g. question vs. answer
        self.pad_token_id = 0
        self.num_attention_heads = 12
        self.attention_probs_dropout_prob = 0.5
        self.position_embedding_type = "absolute"
        self.intermediate_size = 512
        self.hidden_act = "relu"
        self.chunk_size_feed_forward = 0
        self.num_hidden_layers = 12
        self.initializer_range = 0.02  # std of the normal init used in _init_weights (the standard BERT value)


class BertEmbedding(nn.Module):
    def __init__(self, config):
        super(BertEmbedding, self).__init__()
        # the three learnable embeddings: word, position and token type
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)  # at most max_position_embeddings positions in a sequence
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)  # token_type distinguishes the sentence (segment) type
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False)
        # registered as a buffer: it moves with the model and is used in the forward pass, but is not a trainable parameter;
        # with persistent=False it is not saved to or restored from the state_dict, it is recreated every time the model is built
        self.register_buffer("token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False)

    def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0):
        # In some BERT variants (e.g. BERT used for sequence generation), the model reuses cached key/value pairs when computing the current step;
        # past_key_values_length tells the embedding how many positions have already been consumed. I have not run into this case yet.
        if input_ids is None:  # if input_ids is not provided, the input arrives as inputs_embeds
            input_shape = inputs_embeds.size()[:-1]
        else:
            input_shape = input_ids.size()
        seq_length = input_shape[1]
        if position_ids is None:  # if no position_ids are passed, slice them from the registered buffer
            position_ids = self.position_ids[:, past_key_values_length:seq_length+past_key_values_length]  # [1, seq_len]
        if token_type_ids is None:
            if hasattr(self, "token_type_ids"):
                buffer_token_type_ids = self.token_type_ids[:, :seq_length]  # [1,seq_len]
                token_type_ids_expanded = buffer_token_type_ids.expand(input_shape[0], seq_length)  # [bs, seq_len]
                token_type_ids = token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)  # a freshly created tensor must be placed on the right device
        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)
        embeddings = inputs_embeds + token_type_embeddings
        if self.position_embedding_type == "absolute":  # only add position embeddings when using the absolute variant
            position_embeddings = self.position_embeddings(position_ids)  # [1, seq_len, hidden_size]
            embeddings = embeddings + position_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
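
def _demo_bert_embedding():
    # Hypothetical usage sketch (assuming the Config above): a quick shape check of BertEmbedding.
    # seq_len must stay <= config.max_position_embeddings.
    config = Config()
    embedding = BertEmbedding(config)
    input_ids = torch.randint(1, config.vocab_size, (2, 16))  # [bs=2, seq_len=16]
    out = embedding(input_ids=input_ids)
    print(out.shape)  # torch.Size([2, 16, 768])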


class BertSelfAttention(nn.Module):  # BERT's self-attention
    def __init__(self, config, position_embedding_type=None):
        super(BertSelfAttention, self).__init__()
        assert config.hidden_size % config.num_attention_heads == 0
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = config.hidden_size // config.num_attention_heads
        self.all_head_size = config.hidden_size
        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = position_embedding_type or getattr(config, "position_embedding_type", "absolute")

    def transpose_for_scores(self, x):
        # [bs, seq_len, head, head_size]
        new_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_shape)
        return x.permute(0, 2, 1, 3)  # [bs, head, seq_len, head_size]

    def forward(self, hidden_states, attention_mask=None, head_mask=None):
        # huggingface's source also has parameters such as encoder_inputs; I do not yet know what they are for, so they are left out for now
        # head_mask disables the computation of selected attention heads, e.g. to study how individual heads affect model performance
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        query_layer = self.transpose_for_scores(self.query(hidden_states))
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) / math.sqrt(self.attention_head_size)  # [bs, head, seq_len, seq_len]
        if attention_mask is not None:
            attention_scores += attention_mask
            # add attention_mask: masked positions carry a large negative value, so they get ~0 weight after softmax
        attention_probs = F.softmax(attention_scores, dim=-1)
        attention_probs = self.dropout(attention_probs)
        if head_mask is not None:
            attention_probs = attention_probs * head_mask
        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        view_shape = context_layer.size()[: -2] + (self.all_head_size, )
        outputs = context_layer.view(view_shape)
        return (outputs, )  # same shape as the input, [bs, seq_len, hidden_size]
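
def _demo_bert_self_attention():
    # Hypothetical usage sketch (assuming the Config above): a quick shape check of BertSelfAttention.
    config = Config()
    attn = BertSelfAttention(config)
    hidden_states = torch.randn(2, 16, config.hidden_size)  # [bs, seq_len, hidden_size]
    out = attn(hidden_states)[0]                            # forward returns a tuple
    print(out.shape)                                        # torch.Size([2, 16, 768])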

class BertSelfOutput(nn.Module):  # projection, dropout and add & norm applied to the self-attention output
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
    
    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states

def prune_linear_layer(layer: nn.Linear, index: torch.LongTensor, dim: int = 0):
    # Prune a linear layer, keeping only the entries listed in `index` along `dim`.
    # layer.weight.size() = (output_features, input_features)
    # i.e. remove some of the connections between input and output units
    index = index.to(layer.weight.device)
    W = layer.weight.index_select(dim, index).clone().detach()
    # dim=0 keeps selected rows (output units); dim=1 keeps selected columns (input units),
    # in which case all output units, and therefore the whole bias, are kept
    if layer.bias is not None:  # bias has shape [output_features]
        if dim == 1:  # selecting on the input dimension: the bias is untouched
            b = layer.bias.clone().detach()
        else:  # dim == 0: keep only the bias entries of the selected output units
            b = layer.bias[index].clone().detach()
    new_size = list(layer.weight.size())
    new_size[dim] = len(index)
    new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None).to(layer.weight.device)
    new_layer.weight.requires_grad = False  # avoid tracking gradients while copying
    new_layer.weight.copy_(W.contiguous())
    new_layer.weight.requires_grad = True
    if layer.bias is not None:
        new_layer.bias.requires_grad = False
        new_layer.bias.copy_(b.contiguous())
        new_layer.bias.requires_grad = True
    return new_layer
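
def _demo_prune_linear_layer():
    # Hypothetical usage sketch: keep only output units 0, 2 and 5 of a 768 -> 10 linear layer.
    layer = nn.Linear(768, 10)
    index = torch.tensor([0, 2, 5])
    pruned = prune_linear_layer(layer, index, dim=0)
    print(pruned.weight.shape, pruned.bias.shape)  # torch.Size([3, 768]) torch.Size([3])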

class BertAttention(nn.Module):
    # huggingface's version also prunes the q/k/v linear layers here; I had not seen that used when writing this, so it is not implemented
    def __init__(self, config):
        super().__init__()
        self.attention = BertSelfAttention(config)
        self.output = BertSelfOutput(config)
    
    def forward(self, hidden_states, attention_mask=None, head_mask=None):
        self_outputs = self.attention(hidden_states, attention_mask, head_mask)  # the self-attention returns a tuple; self_outputs[0] is the attention output
        attention_output = self.output(self_outputs[0], hidden_states)
        outputs = (attention_output,) + self_outputs[1:]  # also returned as a tuple
        return outputs

ACT2FN = {"relu": F.relu, "gelu": F.gelu}  # maps activation-function names to callables

class BertIntermediate(nn.Module):  # expands the hidden dimension and adds non-linearity via the activation function
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act
    
    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states

class BertOutput(nn.Module):  # used together with BertIntermediate: projects back down, then dropout, residual add and LayerNorm
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)  # back to hidden_size
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
    
    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states

def apply_chunking_to_forward(forward_fn, chunk_size, chunk_dim, *input_tensors):
    # Splits the input tensors into chunks of chunk_size along chunk_dim, applies forward_fn to each chunk
    # independently, and concatenates the results, trading compute for lower peak memory usage.
    # Steps: argument check, tensor splitting, per-chunk forward, result concatenation.
    num_args_in_forward_chunk_fn = len(inspect.signature(forward_fn).parameters)
    if num_args_in_forward_chunk_fn != len(input_tensors):  # the number of forward_fn arguments must match the number of input tensors
        raise ValueError(
            f"forward_chunk_fn expects {num_args_in_forward_chunk_fn} arguments, but only {len(input_tensors)} input "
            "tensors are given"
        )
    if chunk_size > 0:
        tensor_shape = input_tensors[0].shape[chunk_dim]
        for input_tensor in input_tensors:  # every input must have the same size along chunk_dim
            if input_tensor.shape[chunk_dim] != tensor_shape:
                raise ValueError(
                    f"All input tenors have to be of the same shape: {tensor_shape}, "
                    f"found shape {input_tensor.shape[chunk_dim]}"
                )
        if input_tensors[0].shape[chunk_dim] % chunk_size != 0:
            raise ValueError(
                f"The dimension to be chunked {input_tensors[0].shape[chunk_dim]} has to be a multiple of the chunk "
                f"size {chunk_size}"
            )
        num_chunks = input_tensors[0].shape[chunk_dim] // chunk_size
        input_tensors_chunks = tuple(input_tensor.chunk(num_chunks, dim=chunk_dim) for input_tensor in input_tensors)
        output_chunks = tuple(forward_fn(*input_tensors_chunk) for input_tensors_chunk in zip(*input_tensors_chunks))
        return torch.cat(output_chunks, dim=chunk_dim)
    return forward_fn(*input_tensors)
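
def _demo_apply_chunking_to_forward():
    # Hypothetical usage sketch: run a feed-forward layer over the sequence dimension in chunks of 4 positions.
    dense = nn.Linear(8, 8)
    x = torch.randn(2, 16, 8)                                # [bs, seq_len, hidden]
    out = apply_chunking_to_forward(dense.forward, 4, 1, x)  # chunk_size=4 along dim 1 (seq_len)
    print(out.shape)                                         # torch.Size([2, 16, 8])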
        


class BertLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = BertAttention(config)
        # cross-attention is not used here
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def feed_forward_chunk(self, attention_output):
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output

    def forward(self, hidden_states, attention_mask=None, head_mask=None):
        self_attention_output = self.attention(hidden_states, attention_mask, head_mask)
        attention_output = self_attention_output[0]
        outputs = self_attention_output[1:]
        layer_output = apply_chunking_to_forward(self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output)
        outputs = (layer_output,) + outputs
        return outputs

class BertEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
    
    def forward(self, hidden_states, attention_mask=None, output_hidden_states=False):
        # output_hidden_states: whether to also return the hidden states of every layer
        all_hidden_states = () if output_hidden_states else None
        for layer in self.layers:
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)
            hidden_states = layer(hidden_states, attention_mask)[0]  # each layer returns a tuple; element 0 is the layer output
        if output_hidden_states:  # append the final layer's hidden_states
            all_hidden_states += (hidden_states,)
        return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)  # a tuple with one or two elements
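
def _demo_bert_encoder():
    # Hypothetical usage sketch (assuming the Config above): a shape check for the 12-layer encoder stack.
    config = Config()
    encoder = BertEncoder(config)
    hidden_states = torch.randn(2, 16, config.hidden_size)
    outputs = encoder(hidden_states, output_hidden_states=True)
    print(outputs[0].shape)  # last hidden state: torch.Size([2, 16, 768])
    print(len(outputs[1]))   # 13: the encoder input plus the output of each of the 12 layers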


class BertPooler(nn.Module):  # pools the hidden state of the first token ([CLS]) into one vector per sequence
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()
    
    def forward(self, hidden_states):
        first_token_tensor = hidden_states[:, 0]  # [bs, hidden_size]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output
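
def _demo_bert_pooler():
    # Hypothetical usage sketch (assuming the Config above): the pooler maps [bs, seq_len, hidden] to [bs, hidden].
    config = Config()
    pooler = BertPooler(config)
    hidden_states = torch.randn(2, 16, config.hidden_size)
    print(pooler(hidden_states).shape)  # torch.Size([2, 768])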

class BertPredictionHeadTransform(nn.Module):  # dense + activation + LayerNorm applied before the vocabulary projection
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        if isinstance(config.hidden_act, str):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
    
    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states

class BertLMPredictionHead(nn.Module):  # language-model prediction head: projects hidden states to vocabulary size
    def __init__(self, config):
        super().__init__()
        self.transform = BertPredictionHeadTransform(config)
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
        self.decoder.bias = self.bias
    
    def forward(self, hidden_states):
        hidden_states = self.transform(hidden_states)
        hidden_states = self.decoder(hidden_states)
        return hidden_states

class BertOnlyMLMHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.predictions = BertLMPredictionHead(config)

    def forward(self, sequence_output):
        prediction_scores = self.predictions(sequence_output)
        return prediction_scores


class BertOnlyNSPHead(nn.Module):  # head for the NSP (next sentence prediction) task
    def __init__(self, config):
        super().__init__()
        self.seq_relationship = nn.Linear(config.hidden_size, 2)  # binary: is sentence B the actual next sentence of sentence A?

    def forward(self, pooled_output):
        seq_relationship_score = self.seq_relationship(pooled_output)
        return seq_relationship_score

class BertPreTrainedModel(nn.Module):
    # A simplified stand-in for transformers' PreTrainedModel; here it only provides weight initialization.
    config_class = Config
    base_model_prefix = "bert"
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def init_weights(self):
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()  # nn.Embedding is a lookup table: each token id indexes a row of the weight matrix, and the padding row is zeroed
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

class BertModel(BertPreTrainedModel):
    def __init__(self, config, add_pooling_layer=True):
        super().__init__()
        self.config = config
        self.embeddings = BertEmbedding(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config) if add_pooling_layer else None
        self.init_weights()
    
    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, output_hidden_states=None):
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
            attention_mask[input_ids == self.config.pad_token_id] = 0
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)
        # [bs, 1, 1, seq_len]: broadcasts against attention scores of shape [bs, head, seq_len, seq_len]
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        extended_attention_mask = extended_attention_mask.float()
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        
        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids
        )
        
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            output_hidden_states=output_hidden_states
        )
        
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
        return (sequence_output, pooled_output) + encoder_outputs[1:]
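
class BertForMaskedLM(BertPreTrainedModel):
    # A minimal sketch of how BertModel and BertOnlyMLMHead could be combined for masked language modeling;
    # it omits the loss computation and weight tying, see huggingface's transformers for the full version.
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.bert = BertModel(config, add_pooling_layer=False)
        self.cls = BertOnlyMLMHead(config)
        self.init_weights()

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        sequence_output = self.bert(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)[0]
        prediction_scores = self.cls(sequence_output)  # [bs, seq_len, vocab_size]
        return prediction_scores


def _demo_bert_for_masked_lm():
    # Hypothetical end-to-end shape check with the Config above.
    config = Config()
    model = BertForMaskedLM(config)
    input_ids = torch.randint(1, config.vocab_size, (2, 16))  # [bs, seq_len]; 0 is the padding id
    scores = model(input_ids)
    print(scores.shape)  # torch.Size([2, 16, 1000])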

For building the downstream-task models (MLM, NSP, etc.) on top of these components, please refer to huggingface's transformers.
