2021SC@SDUSC
前面分析了两个基础的类,下面对BertEncoder4Mix进行分析。
BertEncoder4Mix是本文件的核心代码,其主要实现的功能是选择在哪一层进行混合。
def forward(self, hidden_states, hidden_states2=None, lbeta=None, mix_layer=1000, attention_mask=None, attention_mask2=None, head_mask=None):
真正的混合隐藏层则是在这里了
hidden_states: 第一个输入的隐藏状态 hidden_states2:第二个输入的隐藏状态 lbeta: beta分布取的值 mix_layer: 要mix的layer,例如这里是bert的第11层, 当不做mix时,设置默认mix_layer为1000,目的是为了下面循环不做mix attention_mask: 输入1的 attention_mask attention_mask2: 输入2的 attention_mask
all_hidden_states = ()
all_attentions = ()
这里是保存每一层的隐藏层状态和注意力,放入列表中
首先是在mix_layer层进行的混合
下面这段即是论文中h=lh+(1-l)h'
if mix_layer == -1:
if hidden_states2 is not None:
hidden_states = l * hidden_states + (1-l)*hidden_states2
for i, layer_module in enumerate(self.layer):
当当前层小于或等于mix_layer时,inputs1 和inputs2 分别计算hidden_states
if i <= mix_layer:
是否输出隐藏状态,如果现在的layer小于要mix_layer,那么
if self.output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
调用transformers的BertLayer, 计算attention和hidden_states, 输出outputs
layer_outputs = layer_module(
hidden_states, attention_mask, head_mask[i])
获取这一次计算后的得到新的隐藏层状态
hidden_states = layer_outputs[0]
如果self.output_attentions 为True,输出attention
if self.output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
如果输入2存在,也做同样计算
if hidden_states2 is not None:
layer_outputs2 = layer_module(hidden_states2, attention_mask2, head_mask[i])
hidden_states2 = layer_outputs2[0]
有当循环到等于mix_layer时,使用mixup公式混合
if i == mix_layer:
if hidden_states2 is not None:
hidden_states = lbeta * hidden_states + (1 - lbeta) * hidden_states2
循环到大于mix_layer的层时, 普通方式计算
if i > mix_layer:
if self.output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i])
hidden_states = layer_outputs[0]
if self.output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
最后一层结束后的hidden_states
if self.output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
outputs是最后一层的的hidden_states[batch_size,seq_len,Embedding_demision]
outputs = (hidden_states,)
if self.output_hidden_states:
outputs = outputs + (all_hidden_states,)
if self.output_attentions:
outputs = outputs + (all_attentions,)
最后一层的hidden state, (all hidden states), (all attentions)
return outputs