from transformers import BertPreTrainedModel
from transformers.modeling_bert import BERT_INPUTS_DOCSTRING,_TOKENIZER_FOR_DOC,_CONFIG_FOR_DOC
from transformers.modeling_bert import BertEmbeddings,BertConfig,BertPooler,BertLayer,BaseModelOutput,BaseModelOutputWithPooling
from lebert import BertEncoder
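# BertEncoder is imported from the project-local lebert.py: it is assumed to be a modified encoder that
# fuses the per-character matched-word embeddings (word_embeddings / word_mask) into the Transformer stack,
# rather than the stock transformers BertEncoder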
import torch
from transformers.file_utils import (
    add_code_sample_docstrings,
    add_start_docstrings_to_callable,
)
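# NOTE: these import paths follow transformers 3.x; in transformers >= 4.x modeling_bert lives under
# transformers.models.bert.modeling_bert and add_start_docstrings_to_callable was renamed
# add_start_docstrings_to_model_forward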
class LEBertModel(BertPreTrainedModel):
    def __init__(self,config):
        super().__init__(config)
        self.config = config
        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config)
        self.init_weights()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self,value):
        self.embeddings.word_embeddings = value

    # Prunes the specified attention heads of the model
    def _prune_heads(self,heads_to_prune):
        for layer,heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)
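    # heads_to_prune is a dict of {layer_num: [head indices]}, e.g. {0: [1,2]} prunes heads 1 and 2 of layer 0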

    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format('(batch_size,sequence_length)'))
    @add_code_sample_docstrings(
        tokenizer_class = _TOKENIZER_FOR_DOC,
        checkpoint = 'bert-base-uncased',
        output_type = BaseModelOutputWithPooling,
        config_class = _CONFIG_FOR_DOC,
    )
    def forward(self,
                input_ids = None,
                attention_mask = None,
                token_type_ids = None,
                word_embeddings = None,
                word_mask = None,
                position_ids = None,
                head_mask = None,
                inputs_embeds = None,
                encoder_hidden_states = None,
                encoder_attention_mask = None,
                output_attentions = None,
                output_hidden_states = None,
                return_dict = None
                ):
        # input_ids:[batch_size,max_len]
        # attention_mask:[batch_size,max_len]
        # token_type_ids:[batch_size,max_len]
        # word_embeddings:[batch_size,max_len,num_words,word_dim]
        # word_mask:[batch_size,max_len,num_words]
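        # In LEBert, word_embeddings is expected to hold, for each character position, the vectors of the
        # (up to num_words) lexicon words matched at that position, and word_mask marks which slots are real matches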
        # position_ids,head_mask,inputs_embeds,encoder_hidden_states,encoder_attention_mask,output_attentions,output_hidden_states,return_dict:None
        # output_attentions:False
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        # output_hidden_states:False
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        # return_dict:False
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError('You cannot specify both input_ids and inputs_embeds at the same time.')
        elif input_ids is not None:
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError('You have to specify either input_ids or inputs_embeds')
        device = input_ids.device if input_ids is not None else inputs_embeds.device
        if attention_mask is None:
            attention_mask = torch.ones(input_shape,device = device)
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape,dtype = torch.long,device = device)
        # extended_attention_mask:[batch_size,1,1,max_len]
        extended_attention_mask:torch.Tensor = self.get_extended_attention_mask(attention_mask,input_shape,device)
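        # get_extended_attention_mask broadcasts the 2D mask and turns masked (0) positions into a large
        # negative additive bias, so they are ignored by the softmax in self-attention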
        if self.config.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size,encoder_sequence_length,_ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size,encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape,device = device)
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None
        # head_mask:[num_hidden_layers]
        head_mask = self.get_head_mask(head_mask,self.config.num_hidden_layers)
        # embedding_output:[batch_size,max_len,bert_dim]
        embedding_output = self.embeddings(
            input_ids = input_ids,
            position_ids = position_ids,
            token_type_ids = token_type_ids,
            inputs_embeds = inputs_embeds
        )
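        # BertEmbeddings sums the word, position and token-type embeddings, then applies LayerNorm and dropout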
        # Inputs:
        # embedding_output:[batch_size,max_len,bert_dim]
        # word_embeddings:[batch_size,max_len,num_words,word_dim]
        # word_mask:[batch_size,max_len,num_words]
        # extended_attention_mask:[batch_size,1,1,max_len]
        # head_mask:[num_hidden_layers]
        # encoder_hidden_states,encoder_attention_mask:None,output_attentions,output_hidden_states,return_dict:False
        # Outputs:
        # encoder_outputs:([batch_size,max_len,bert_dim],)
        encoder_outputs = self.encoder(
            embedding_output,
            word_embeddings = word_embeddings,
            word_mask = word_mask,
            attention_mask = extended_attention_mask,
            head_mask = head_mask,
            encoder_hidden_states = encoder_hidden_states,
            encoder_attention_mask = encoder_extended_attention_mask,
            output_attentions = output_attentions,
            output_hidden_states = output_hidden_states,
            return_dict = return_dict
        )
        # sequence_output:[batch_size,max_len,bert_dim]
        sequence_output = encoder_outputs[0]
        # pooled_output:[batch_size,bert_dim]
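        # BertPooler takes the hidden state of the first ([CLS]) token and passes it through a dense layer + tanh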
        pooled_output = self.pooler(sequence_output)
        if not return_dict:
            return (sequence_output,pooled_output) + encoder_outputs[1:]
        return BaseModelOutputWithPooling(
            last_hidden_state = sequence_output,
            pooler_output = pooled_output,
            hidden_states = encoder_outputs.hidden_states,
            attentions = encoder_outputs.attentions
        )
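
# Quick smoke test with random inputs.
# word_embed_dim is the dimension of the lexicon word vectors fed to lebert.BertEncoder; add_layer is
# assumed to select the Transformer layer at which those word features are injected (both are LEBert-specific
# config fields, so the exact semantics depend on lebert.py). The random word_embeddings / word_mask tensors
# below stand in for real matched-lexicon word vectors.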
pretrain_model_path = 'bert-base-chinese'
config = BertConfig.from_pretrained(pretrain_model_path)
config.word_embed_dim = 200
config.add_layer = 0
model = LEBertModel(config)
num_words = 3
max_len = 10
bert_dim = 768
batch_size = 4
input_ids = torch.randint(0,100,[batch_size,max_len])
word_embeddings = torch.randn([batch_size,max_len,num_words,config.word_embed_dim])
word_mask = torch.ones([batch_size,max_len,num_words]).long()
attention_mask = torch.ones([batch_size,max_len]).byte()
token_type_ids = torch.zeros([batch_size,max_len]).long()
outputs = model(
input_ids = input_ids,
attention_mask = attention_mask,
token_type_ids = token_type_ids,
word_embeddings = word_embeddings,
word_mask = word_mask,
position_ids = None,
head_mask = None,
inputs_embeds = None,
encoder_hidden_states = None,
encoder_attention_mask = None,
output_attentions = None,
output_hidden_states = None,
return_dict = None
)
print('outputs[0].shape',outputs[0].shape)
print('outputs[1].shape',outputs[1].shape)
Output:
outputs[0].shape torch.Size([4, 10, 768])
outputs[1].shape torch.Size([4, 768])
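
For context, lebert.BertEncoder itself is not shown above. Below is a minimal, illustrative sketch (not the actual LEBert implementation) of the kind of lexicon-fusion step such an encoder is assumed to insert between BERT layers: each character attends over its matched word vectors and the weighted sum is added back to the hidden states. All names here (LexiconFusionSketch, word_transform, attn_w) are hypothetical.

import torch
import torch.nn as nn

class LexiconFusionSketch(nn.Module):
    # Illustrative only: fuse per-character matched-word vectors into BERT hidden states
    def __init__(self,bert_dim = 768,word_dim = 200):
        super().__init__()
        self.word_transform = nn.Linear(word_dim,bert_dim)        # project word vectors into the BERT space
        self.attn_w = nn.Linear(bert_dim,bert_dim,bias = False)   # bilinear attention weight
        self.layer_norm = nn.LayerNorm(bert_dim)

    def forward(self,hidden_states,word_embeddings,word_mask):
        # hidden_states:   [batch_size,max_len,bert_dim]
        # word_embeddings: [batch_size,max_len,num_words,word_dim]
        # word_mask:       [batch_size,max_len,num_words] (1 = real matched word, 0 = padding)
        word_states = torch.tanh(self.word_transform(word_embeddings))                   # [b,L,W,H]
        scores = torch.einsum('blh,blwh->blw',self.attn_w(hidden_states),word_states)    # char-to-word scores
        scores = scores.masked_fill(word_mask == 0,-1e4)                                 # mask padded word slots
        probs = torch.softmax(scores,dim = -1)                                           # weights over matched words
        word_summary = torch.einsum('blw,blwh->blh',probs,word_states)                   # per-character word vector
        return self.layer_norm(hidden_states + word_summary)                             # inject into hidden states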