从源码解析 Bert 的 BertEncoder 模块

31 篇文章 3 订阅
23 篇文章 12 订阅

前文链接->从源码解析 Bert 的 Embedding 模块

上一篇文章解析了 Bert 的 BertEmbedding 模块,接下来分析 bert的第二个重要模块 BertEncoder

BertEncoder源码

class BertEncoder(nn.Module):
    def __init__(self, config):
        super(BertEncoder, self).__init__()
        layer = BertLayer(config)
        self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])

    def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True):
        all_encoder_layers = []
        for layer_module in self.layer:
            hidden_states = layer_module(hidden_states, attention_mask)
            if output_all_encoded_layers:
                all_encoder_layers.append(hidden_states)
        if not output_all_encoded_layers:
            all_encoder_layers.append(hidden_states)
        return all_encoder_layers

1. __ init __(self, config) 函数

首先定义了一个 Bert 的encoder层,然后将这个encoder层复制了N层(深拷贝num_hidden_layers 层)最后将这N层网络堆叠在一起组成 BertEncoder 模块的网络结构

2. forward()函数

forward()函数的输入分别为:

  • hiddenstates:BertEmbedding 模块的输出,输入token的embedding,大小为[batch_size,序列长度,embedding维度]
  • attention_mask:标识序列的真实长度,0表示真实的数据,1表示padding的数据
  • output_all_encoded_layers:True表示返回encoder所有中间层的输出,False表示只返回最后一层encoded模块的输出

forward()函数中使用for循环让数据依次通过每一个encoder,然后根据output_all_encoded_layers决定是否返回每一层encoder的输出

评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

energy_百分百

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值