Bert理解
模型模块代码
forward主函数常用参数介绍
input_ids:表示输入的token ID序列,通常由Tokenizer生成,每个token ID对应一个词汇表中的词;
attention_mask:用于指示哪些token是有效的,值为1表示有效token,值为0表示填充值(padding),主要用于句子尾部;
position_ids:指定每个token的位置ID,默认情况下模型会自动生成,但也可以手动提供以控制位置编码;
inputs_embeds:直接提供输入的嵌入表示,而不是通过 input_ids 生成;
encoder_embeds:如果BERT模型作为decoder使用,这是从编码器接收到的嵌入表示
encoder_hidden_states:来自encoder最后一层的输出的隐藏状态序列,常用于图文多模态时候加载图像特征
encoder_attention_mask:指示哪些编码器输出是有效的
mode:text只是文本模态,multimodal用于图文多模态
BertEmbedding
(1)判断输入的是input_ids还是inputs_embeds,如果是input_ids,需要使用word_embeddings进行文本embedding操作;
(2)将position_ids同样进行position_embeddings,之后和文本embedding相加得到维度为[batch, seq_len, channel] ;
class BertEmbeddings(nn.Module):
"""Construct the embeddings from word and position embeddings."""
def __init__(self, config):
super().__init__()
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
# position_ids (1, max_position_embeddings)
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
self.config = config
def forward(
self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
):
if input_ids is not None:
input_shape = input_ids.size()
else:
input_shape = inputs_embeds.size()[:-1]
seq_length = input_shape[1]
if position_ids is None:
position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
embeddings = inputs_embeds
if self.position_embedding_type == "absolute":
position_embeddings = self.position_embeddings(position_ids)
embeddings += position_embeddings
embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings) # [batch, seq_len, channel]
return embeddings
BertEncoder(BertLayer)
主要执行num_hidden_layers次 BertLayer 该网络模块
class BertLayer(nn.Module):
def __init__(self, config, layer_num):
super().__init__()
self.attention = BertAttention(config)
self.layer_num = layer_num
if self.config.add_cross_attention:
self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention)
self.intermediate = BertIntermediate(config)
self.output = BertOutput(config)
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_value=None,
output_attentions=False,
mode=None,
):
# 先对text向量进行自注意力操作,返回维度为[]的特征;主要看BertSelfAttention有关单文本模态的自注意力操作和图文多模态时候的额外交叉注意力操作
self_attention_outputs = self.attention(
hidden_states,
attention_mask,
head_mask,
output_attentions=output_attentions,
past_key_value=self_attn_past_key_value,
)
# 自注意力操作返回的特征,如果多模态,送入交叉注意力
attention_output = self_attention_outputs[0]
if mode=='multimodal':
assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"
cross_attention_outputs = self.crossattention(
attention_output,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
output_attentions=output_attentions,
)
# 残差连接,加上交叉注意力返回的特征
outputs = outputs + cross_attention_outputs[0]
return outputs
BertSelfAttention
主要根据是否是多模态,使用不同的key和value对文本或者图像特征进行映射
class BertSelfAttention(nn.Module):
def __init__(self, config, is_cross_attention):
super().__init__()
# 如果多模态,继续用为视觉特征设定key和value的线性层
self.query = nn.Linear(config.hidden_size, self.all_head_size)
if is_cross_attention:
self.key = nn.Linear(config.encoder_width, self.all_head_size)
self.value = nn.Linear(config.encoder_width, self.all_head_size)
else:
self.key = nn.Linear(config.hidden_size, self.all_head_size)
self.value = nn.Linear(config.hidden_size, self.all_head_size)
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_value=None,
output_attentions=False,
):
# 先对文本特征(hidden_states)进行query映射
mixed_query_layer = self.query(hidden_states)
is_cross_attention = encoder_hidden_states is not None
# 根据多模态使用各自的key和value进行映射
if is_cross_attention:
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
attention_mask = encoder_attention_mask
else:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
# 计算相似度的值
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
if attention_mask is not None:
# Apply the attention mask is (precomputed for all layers in BertModel forward() function)
attention_scores = attention_scores + attention_mask
# softmax注意力矩阵
attention_probs = nn.Softmax(dim=-1)(attention_scores)
# 加权
context_layer = torch.matmul(attention_probs_dropped, value_layer)