文本分类半监督学习(十)

2021SC@SDUSC​​​​​​​​​​​​

前面分析了两个基础的类,下面对BertEncoder4Mix进行分析。

BertEncoder4Mix是本文件的核心代码,其主要实现的功能是选择在哪一层进行混合。

def forward(self, hidden_states, hidden_states2=None, lbeta=None, mix_layer=1000, attention_mask=None, attention_mask2=None, head_mask=None):

真正的混合隐藏层则是在这里了

hidden_states: 第一个输入的隐藏状态 hidden_states2:第二个输入的隐藏状态 lbeta: beta分布取的值 mix_layer: 要mix的layer,例如这里是bert的第11层, 当不做mix时,设置默认mix_layer为1000,目的是为了下面循环不做mix attention_mask: 输入1的 attention_mask attention_mask2: 输入2的 attention_mask

       all_hidden_states = ()
        all_attentions = ()

这里是保存每一层的隐藏层状态和注意力,放入列表中

首先是在mix_layer层进行的混合

下面这段即是论文中h=lh+(1-l)h'

     if mix_layer == -1:
            if hidden_states2 is not None:
                hidden_states = l * hidden_states + (1-l)*hidden_states2
 for i, layer_module in enumerate(self.layer):

当当前层小于或等于mix_layer时,inputs1 和inputs2 分别计算hidden_states

if i <= mix_layer:

是否输出隐藏状态,如果现在的layer小于要mix_layer,那么

                if self.output_hidden_states:
                    all_hidden_states = all_hidden_states + (hidden_states,)

调用transformers的BertLayer, 计算attention和hidden_states, 输出outputs

 layer_outputs = layer_module(
                    hidden_states, attention_mask, head_mask[i])

获取这一次计算后的得到新的隐藏层状态

hidden_states = layer_outputs[0]

如果self.output_attentions 为True,输出attention

                if self.output_attentions:
                    all_attentions = all_attentions + (layer_outputs[1],)

如果输入2存在,也做同样计算

if hidden_states2 is not None:
layer_outputs2 = layer_module(hidden_states2, attention_mask2, head_mask[i])
hidden_states2 = layer_outputs2[0]

有当循环到等于mix_layer时,使用mixup公式混合

if i == mix_layer:
if hidden_states2 is not None:
hidden_states = lbeta * hidden_states + (1 - lbeta) * hidden_states2

循环到大于mix_layer的层时, 普通方式计算​​​​​​​

if i > mix_layer:
if self.output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)

layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i])
hidden_states = layer_outputs[0]

if self.output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)

最后一层结束后的hidden_states

if self.output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)

outputs是最后一层的的hidden_states[batch_size,seq_len,Embedding_demision]

outputs = (hidden_states,)
if self.output_hidden_states:
outputs = outputs + (all_hidden_states,)
if self.output_attentions:
outputs = outputs + (all_attentions,)

最后一层的hidden state, (all hidden states), (all attentions)

return outputs

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值