# BertModel
## Reference: transformers.modeling_bert.BertModel
```python
class BertModel(BertPreTrainedModel):
    def __init__(self, config):
        super(BertModel, self).__init__(config)
        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config)
        ...

    def forward(self, input_ids, attention_mask=None, token_type_ids=None,
                position_ids=None, head_mask=None):
        ...
        ### Part 1: transform the attention_mask and embed the inputs
        # Broadcast the 2D mask [batch, seq_len] to [batch, 1, 1, seq_len] so it
        # can be added to the attention scores of every head at every position.
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
        # Positions to attend to stay 0.0; padded positions become -10000.0,
        # which drives their softmax weights to ~0.
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        embedding_output = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids)
```
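To see what the mask arithmetic does in isolation, here is a minimal, self-contained sketch of the same transformation on a toy batch. The tensor values are made up for illustration; only the `unsqueeze` / `(1.0 - mask) * -10000.0` steps mirror the code above:

```python
import torch

# Toy batch: two sequences of length 4; 1 = real token, 0 = padding.
attention_mask = torch.tensor([[1., 1., 1., 0.],
                               [1., 1., 0., 0.]])

# Same two unsqueeze calls as in forward(): [batch, seq] -> [batch, 1, 1, seq],
# broadcastable against attention scores of shape [batch, heads, seq, seq].
extended = attention_mask.unsqueeze(1).unsqueeze(2)
extended = (1.0 - extended) * -10000.0
print(extended.shape)     # torch.Size([2, 1, 1, 4])
print(extended[0, 0, 0])  # tensor([    -0.,     -0.,     -0., -10000.])

# Adding -10000 to the raw scores before softmax effectively zeroes out
# the padded positions:
scores = torch.zeros(2, 1, 1, 4) + extended
print(torch.softmax(scores, dim=-1)[0, 0, 0])
# tensor([0.3333, 0.3333, 0.3333, 0.0000])
```

Using a large negative additive bias instead of boolean indexing keeps masking a single broadcasted add inside the attention computation, and the cast to the model's parameter dtype in `forward` is what keeps this safe under fp16.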