BertSelfAttention
1. The __init__ function
1) Three size attributes are computed:
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
num_attention_heads = 12
attention_head_size = 768 / 12 = 64 (the size of each head)
all_head_size = 12 * 64 = 768
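A quick numeric check of these three values, assuming the standard bert-base configuration (hidden_size = 768, 12 attention heads):

hidden_size = 768           # config.hidden_size for bert-base
num_attention_heads = 12    # config.num_attention_heads
attention_head_size = int(hidden_size / num_attention_heads)    # 64
all_head_size = num_attention_heads * attention_head_size       # 768
print(num_attention_heads, attention_head_size, all_head_size)  # 12 64 768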
2) Defining Q, K, V
self.query = nn.Linear(config.hidden_size, self.all_head_size)
self.key = nn.Linear(config.hidden_size, self.all_head_size)
self.value = nn.Linear(config.hidden_size, self.all_head_size)
config.hidden_size = 768 and self.all_head_size = 768, so nn.Linear(a, b) is created with input dimension 768 and output dimension 768.
The resulting Q, K, V tensors each have shape (1, 128, 768), i.e. (batch_size, seq_len, hidden_size).
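A minimal sketch of these three projections, assuming batch_size = 1 and seq_len = 128 (the example shapes used throughout this note):

import torch
import torch.nn as nn

hidden_size = 768
query = nn.Linear(hidden_size, hidden_size)
key = nn.Linear(hidden_size, hidden_size)
value = nn.Linear(hidden_size, hidden_size)

hidden_states = torch.randn(1, 128, hidden_size)  # (batch_size, seq_len, hidden_size)
print(query(hidden_states).shape)  # torch.Size([1, 128, 768])
print(key(hidden_states).shape)    # torch.Size([1, 128, 768])
print(value(hidden_states).shape)  # torch.Size([1, 128, 768])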
3) A dropout layer applied later to the attention probabilities (used in forward as self.dropout):
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
2. The transpose_for_scores function
Purpose: split Q, K, V into the 12 attention heads (the multi-head split).
After the split, Q, K, V are stored in query_layer, key_layer, value_layer, each with shape (1, 12, 128, 64).
In the forward function we have:
query_layer = self.transpose_for_scores(mixed_query_layer)
key_layer = self.transpose_for_scores(mixed_key_layer)
value_layer = self.transpose_for_scores(mixed_value_layer)
This step calls the transpose_for_scores function:
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(*new_x_shape)
return x.permute(0, 2, 1, 3)
(0,1,2,3)-->(0,2,1,3):
(batch_size,seq_len,num_heads,attn_head_size)--->(batch_size,num_heads,seq_len,attn_head_size)
i.e. (1, 12, 128, 64)
(1, 128, 768)
↓
self.transpose_for_scores()
↓
(1, 12, 128, 64)
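A small runnable sketch of this split, using a dummy (1, 128, 768) tensor and a standalone rewrite of the method above (num_attention_heads = 12 and attention_head_size = 64 hard-coded for illustration):

import torch

num_attention_heads = 12
attention_head_size = 64

def transpose_for_scores(x):
    # (batch_size, seq_len, hidden_size) -> (batch_size, seq_len, num_heads, head_size)
    new_x_shape = x.size()[:-1] + (num_attention_heads, attention_head_size)
    x = x.view(*new_x_shape)
    # -> (batch_size, num_heads, seq_len, head_size)
    return x.permute(0, 2, 1, 3)

mixed_query_layer = torch.randn(1, 128, 768)
query_layer = transpose_for_scores(mixed_query_layer)
print(query_layer.shape)  # torch.Size([1, 12, 128, 64])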
3. The forward function
def forward(self, hidden_states, attention_mask):
mixed_query_layer = self.query(hidden_states)
mixed_key_layer = self.key(hidden_states)
mixed_value_layer = self.value(hidden_states)
query_layer = self.transpose_for_scores(mixed_query_layer)
key_layer = self.transpose_for_scores(mixed_key_layer)
value_layer = self.transpose_for_scores(mixed_value_layer)
Everything above has now been explained; next comes the computation of the attention scores.
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
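# math.sqrt(self.attention_head_size) = sqrt(64) = 8; scaling keeps the dot products from growing with the head size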
# Apply the attention mask (precomputed for all layers in the BertModel forward() function)
attention_scores = attention_scores + attention_mask
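# attention_mask is 0.0 at real token positions and a large negative value (e.g. -10000.0) at padding
# positions, so padded positions end up with a probability close to 0 after the softmax below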
# Normalize the attention scores to probabilities.
attention_probs = F.softmax(attention_scores, dim=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)
Next comes the computation of the context layer: the computed attention_probs are multiplied by the value matrix V (value_layer, which, like K, is projected from the same hidden_states).
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
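# permute back to (batch_size, seq_len, num_heads, head_size) = (1, 128, 12, 64) before merging the heads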
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = torch.reshape(context_layer, new_context_layer_shape)
return context_layer
attention_probs      *     value_layer        =     context_layer
(1, 12, 128, 128)          (1, 12, 128, 64)         (1, 12, 128, 64)
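As a quick check of these shapes with dummy tensors (batch_size = 1, 12 heads, seq_len = 128, head_size = 64):

import torch

attention_probs = torch.randn(1, 12, 128, 128)
value_layer = torch.randn(1, 12, 128, 64)
context_layer = torch.matmul(attention_probs, value_layer)
print(context_layer.shape)  # torch.Size([1, 12, 128, 64])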
Afterwards context_layer is reshaped back to its original layout:
context_layer.size()[:-2] + (self.all_head_size,)
The dimensions are first moved back with permute(0, 2, 1, 3); then the last two dimensions (num_heads and attn_head_size) are dropped and replaced with all_head_size = 768, merging the 12 heads back together. The resulting context_layer shape is
(1, 128, 768)
Finally, context_layer is returned.
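Putting the pieces together, here is a minimal self-contained sketch of the module as walked through above. It follows the HuggingFace-style BertSelfAttention code; the small Config class and the example input shapes (batch_size = 1, seq_len = 128) are assumptions added for illustration:

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class BertSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        # (batch_size, seq_len, hidden_size) -> (batch_size, num_heads, seq_len, head_size)
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states, attention_mask):
        query_layer = self.transpose_for_scores(self.query(hidden_states))
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        attention_scores = attention_scores + attention_mask
        attention_probs = F.softmax(attention_scores, dim=-1)
        attention_probs = self.dropout(attention_probs)
        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        return torch.reshape(context_layer, new_context_layer_shape)

# Usage with the shapes used in this note (batch_size = 1, seq_len = 128):
class Config:  # hypothetical minimal config, for illustration only
    hidden_size = 768
    num_attention_heads = 12
    attention_probs_dropout_prob = 0.1

attn = BertSelfAttention(Config())
hidden_states = torch.randn(1, 128, 768)
attention_mask = torch.zeros(1, 1, 1, 128)  # 0.0 for real tokens, -10000.0 for padding
print(attn(hidden_states, attention_mask).shape)  # torch.Size([1, 128, 768])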