小黑开题结束，继续征程：LEBertModel整体架构

from transformers import BertPreTrainedModel
from transformers.modeling_bert import BERT_INPUTS_DOCSTRING,_TOKENIZER_FOR_DOC,_CONFIG_FOR_DOC
from transformers.modeling_bert import BertEmbeddings,BertConfig,BertPooler,BertLayer,BaseModelOutput,BaseModelOutputWithPooling
from lebert import BertEncoder
import torch
from transformers.file_utils import (
    add_code_sample_docstrings,
    add_start_docstrings_to_callable,
)

class LEBertModel(BertPreTrainedModel):
    """BERT backbone whose encoder additionally consumes lexicon (word) features.

    Pipeline: ``BertEmbeddings`` -> ``BertEncoder`` (imported from ``lebert``,
    which also receives ``word_embeddings`` / ``word_mask``) -> ``BertPooler``.
    """

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config)
        self.init_weights()

    def get_input_embeddings(self):
        """Return the token embedding module of the embedding layer."""
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        """Replace the token embedding module of the embedding layer with ``value``."""
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """Prune attention heads.

        Args:
            heads_to_prune: mapping ``{layer_index: heads}``; for each entry the
                heads are passed to that encoder layer's ``attention.prune_heads``.
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format('(batch_size,sequence_length)'))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint='bert-base-uncased',
        output_type=BaseModelOutputWithPooling,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        word_embeddings=None,
        word_mask=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Embed the tokens, run the lexicon-aware encoder, and pool the result.

        LEBERT-specific inputs (expected shapes):
            word_embeddings: ``[batch_size, max_len, num_words, word_dim]`` -- word
                vectors attached to every token position.
            word_mask: ``[batch_size, max_len, num_words]`` -- mask over those words.

        Other arguments: ``input_ids`` / ``attention_mask`` / ``token_type_ids`` are
        ``[batch_size, max_len]``; the rest mirror ``transformers``' BertModel.

        Returns:
            If ``return_dict`` is falsy: ``(sequence_output, pooled_output) +
            encoder_outputs[1:]`` with ``sequence_output`` of shape
            ``[batch_size, max_len, bert_dim]`` and ``pooled_output`` of shape
            ``[batch_size, bert_dim]``. Otherwise a ``BaseModelOutputWithPooling``.

        Raises:
            ValueError: if both or neither of ``input_ids`` and ``inputs_embeds``
                are given.
        """
        # Unset output flags fall back to the model config.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Exactly one of input_ids / inputs_embeds must be provided.
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError('You cannot specify both input_ids and inputs_embeds at the same time.')
        elif input_ids is not None:
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError('You have to specify either input_ids or inputs_embeds')

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        # Defaults: attend to every position; every token belongs to segment 0.
        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
        # extended_attention_mask: [batch_size, 1, 1, max_len], broadcastable over heads.
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)

        # Cross-attention mask is only needed when used as a decoder with encoder states.
        if self.config.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        # head_mask is expanded to one entry per hidden layer.
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
        # embedding_output: [batch_size, max_len, bert_dim]
        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
        )

        # The encoder fuses token states with the lexicon features. With
        # return_dict falsy and no extra outputs requested, encoder_outputs is
        # ([batch_size, max_len, bert_dim],).
        encoder_outputs = self.encoder(
            embedding_output,
            word_embeddings=word_embeddings,
            word_mask=word_mask,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # sequence_output: [batch_size, max_len, bert_dim]
        sequence_output = encoder_outputs[0]
        # pooled_output: [batch_size, bert_dim]
        pooled_output = self.pooler(sequence_output)

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]
        # The dataclass field is ``attentions`` (plural); ``attention=`` raised TypeError.
        return BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

# Smoke test: build an LEBertModel and feed it random inputs, then print the
# shapes of the sequence output and the pooled output.
batch_size = 4
max_len = 10
num_words = 3      # lexicon words attached to each token position
bert_dim = 768     # informational only; not referenced below

pretrain_model_path = 'bert-base-chinese'
config = BertConfig.from_pretrained(pretrain_model_path)
config.word_embed_dim = 200   # size of each lexicon word vector
config.add_layer = 0          # presumably picks the encoder layer used for lexicon fusion -- TODO confirm in lebert
model = LEBertModel(config)

# Random inputs; keep the RNG call order (model init, then randint, then randn).
input_ids = torch.randint(0, 100, [batch_size, max_len])
word_embeddings = torch.randn([batch_size, max_len, num_words, config.word_embed_dim])
word_mask = torch.ones([batch_size, max_len, num_words], dtype=torch.long)
attention_mask = torch.ones([batch_size, max_len], dtype=torch.uint8)
token_type_ids = torch.zeros([batch_size, max_len], dtype=torch.long)

# Every omitted keyword defaults to None in LEBertModel.forward.
outputs = model(
    input_ids=input_ids,
    attention_mask=attention_mask,
    token_type_ids=token_type_ids,
    word_embeddings=word_embeddings,
    word_mask=word_mask,
)
print('outputs[0].shape', outputs[0].shape)
print('outputs[1].shape', outputs[1].shape)
输出：

outputs[0].shape torch.Size([4, 10, 768])
outputs[1].shape torch.Size([4, 768])

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值