Making familiar models faster: Performer, a Transformer with linear complexity

A new Transformer architecture: Performer

Paper link

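The Transformer makes full use of the attention mechanism, and Transformer-based models (BERT, GPT, PLATO, ERNIE, and others) have been very successful in NLP. However, for an input sequence of length L the Transformer builds an L × L attention matrix, so memory and compute grow quadratically with sequence length, which makes long sequences expensive to process. Performer uses FAVOR+ (Fast Attention Via positive Orthogonal Random features), an algorithm proposed by Google, to provide a scalable, low-variance, unbiased estimate of the attention mechanism while keeping space and time complexity close to linear.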
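To make the idea concrete, here is a minimal pure-Python sketch of attention computed through positive random features. It is an illustration under stated assumptions, not the code shipped with this project: the function names, toy sizes, and input scales are made up for the example, and the random projections are drawn i.i.d. from a Gaussian rather than orthogonalized as in the paper. The key point is that the keys and values are summarized once (`kv` and `k_sum`), so no L × L matrix is ever formed.

```python
import math
import random


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def softmax_attention(q, k, v):
    """Exact attention: every query scores every key, so cost grows as L * L."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for qi in q:
        weights = [math.exp(scale * dot(qi, kj)) for kj in k]
        total = sum(weights)
        out.append([sum(w * vj[c] for w, vj in zip(weights, v)) / total
                    for c in range(len(v[0]))])
    return out


def positive_features(x, w):
    """phi(x)_j = exp(w_j . x - |x|^2 / 2) / sqrt(m); every entry is positive."""
    m = len(w)
    return [[math.exp(dot(wj, row) - dot(row, row) / 2.0) / math.sqrt(m) for wj in w]
            for row in x]


def random_feature_attention(q, k, v, w):
    """Approximate attention without forming the L x L matrix (cost ~ L * m)."""
    scale = len(q[0]) ** -0.25  # split the 1/sqrt(d) factor between q and k
    q_f = positive_features([[scale * a for a in row] for row in q], w)
    k_f = positive_features([[scale * a for a in row] for row in k], w)
    d_v = len(v[0])
    # Summaries over all keys, computed once:
    # kv[j] = sum_i phi(k_i)_j * v_i   and   k_sum[j] = sum_i phi(k_i)_j
    kv = [[sum(kf[j] * vi[c] for kf, vi in zip(k_f, v)) for c in range(d_v)]
          for j in range(len(w))]
    k_sum = [sum(kf[j] for kf in k_f) for j in range(len(w))]
    out = []
    for qf in q_f:
        denom = dot(qf, k_sum)  # normalizer, plays the role of the softmax sum
        out.append([sum(qf[j] * kv[j][c] for j in range(len(w))) / denom
                    for c in range(d_v)])
    return out


random.seed(0)
L, d, m = 8, 4, 1000  # toy sizes, chosen for illustration only
q = [[random.gauss(0, 0.3) for _ in range(d)] for _ in range(L)]
k = [[random.gauss(0, 0.3) for _ in range(d)] for _ in range(L)]
v = [[random.uniform(-1, 1) for _ in range(d)] for _ in range(L)]
w = [[random.gauss(0, 1) for _ in range(d)] for _ in range(m)]  # i.i.d., not orthogonal

exact = softmax_attention(q, k, v)
approx = random_feature_attention(q, k, v, w)
max_err = max(abs(a - b) for ra, rb in zip(exact, approx) for a, b in zip(ra, rb))
print(f"max abs error with m={m} features: {max_err:.4f}")
```

The estimate is unbiased because, for a standard Gaussian w, the expectation of exp(w·x − |x|²/2) · exp(w·y − |y|²/2) equals exp(x·y). Increasing `m` lowers the variance at a cost that stays linear in the sequence length.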

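See the original paper for details. This project provides a Paddle implementation of the architecture (the paper comes with its own source code).

Code and tests

Unzip Performers.zip to get the Paddle code. The code is laid out as follows: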

-----Performers

      -------paddlepaddle  // Paddle code

           ----------fast_attention.py  // core Performer code

           ----------fast_attention_test.py  // simple test code

           ----------transformer.py  // Transformer built on Performer

           ----------bert  // BERT model built on Performer

      -------tf2  // original Performer source code


```python
!unzip -o Performers.zip
```

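Calling fast_attention_test.py as shown below shows how much Performer speeds things up relative to the Transformer. The speed-up ratio is the ratio of the times taken to run a single attention block (time_transformer / time_performer). The results show that Performer gives a substantial speed-up once the sequence length reaches 1024 (about 3.2x, rising to about 28x at 8192). At 512 it is slightly slower than the Transformer (ratio 0.91), and at shorter lengths it is much slower (0.15 at 128). Whether to use this model for acceleration therefore depends on the actual workload; in practice, sequence lengths above 1000 are not very common.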

from Performers.paddlepaddle import fast_attention_test
import prettytable

test = fast_attention_test.TransformerLayersTest()

# Measure the speed-up ratio at each sequence length
lengths = [128, 256, 512, 1024, 2048, 4096, 8192]
ratios = []
for length in lengths:
    ratios.append(test.test_softmax_noncausal_attention_block_output(length))

# Print the results as a table
table = prettytable.PrettyTable()
table.field_names = ['length'] + [str(length) for length in lengths]
table.add_row(['speed up ratio'] + ['%.5f' % r for r in ratios])
print(table)
+----------------+---------+---------+---------+---------+---------+----------+----------+
|     length     |   128   |   256   |   512   |   1024  |   2048  |   4096   |   8192   |
+----------------+---------+---------+---------+---------+---------+----------+----------+
| speed up ratio | 0.15373 | 0.35420 | 0.91381 | 3.22711 | 7.37718 | 14.36226 | 28.28287 |
+----------------+---------+---------+---------+---------+---------+----------+----------+
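
The crossover in the table above can be roughly understood by counting multiply-adds. The sketch below is a back-of-the-envelope cost model, not a measurement: it ignores softmax, normalization, memory traffic, and kernel launch overhead, and `dim` and `num_features` are assumed values that are not read from fast_attention_test.py.

```python
def attention_cost(length, dim):
    """Multiply-adds for exact attention: Q K^T, then (weights) V."""
    return 2 * length * length * dim


def performer_cost(length, dim, num_features):
    """Multiply-adds for random-feature attention.

    Two feature maps (one for Q, one for K), then K'^T V and Q' (K'^T V).
    """
    return 4 * length * num_features * dim


dim, num_features = 64, 256  # assumed values, for illustration only
for length in [128, 256, 512, 1024, 2048, 4096, 8192]:
    ratio = attention_cost(length, dim) / performer_cost(length, dim, num_features)
    print(f"length={length:5d}  estimated ratio={ratio:.2f}")
```

Under these assumptions the estimated ratio is length / (2 * num_features), so it doubles whenever the length doubles and drops below 1 for short sequences. The measured ratios follow the same trend, although the absolute values differ because real timings depend on much more than the operation count.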

Example

The following uses an earlier project, automatic poetry writing with BERT, to show how to use a Performer-based model.

Data preparation

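# Download the poetry dataset (using the mirror github.com.cnpmjs.org can speed up the download)
!git clone https://github.com.cnpmjs.org/chinese-poetry/chinese-poetry
# Download the Traditional-to-Simplified Chinese conversion tool
!git clone https://github.com.cnpmjs.org/fiyen/cht2chs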
import os
import json
import re
from cht2chs.langconv import cht_to_chs

def sentenceParse(para):
    """
    Strip non-text symbols and digits from a poem string
    """
    # Remove notes in full-width parentheses, {...} and 《...》 spans, and square brackets
    result, number = re.subn(r"（.*）", "", para)
    result, number = re.subn(r"{.*}", "", result)
    result, number = re.subn(r"《.*》", "", result)
    result, number = re.subn(r"[\]\[]", "", result)
    r = ""
    for s in result:
        if s not in set('0123456789-'):
            r += s
    # Collapse a doubled full stop left behind by the removals
    r, number = re.subn(r"。。", "。", r)
    return r


def data_preprocess(poem_dir='./chinese-poetry/json', len_limit=120):
    """
    Preprocess the poetry data and return the list of poems that meet the length limit
    """
    poems = []
    # Only the first directory entry is read, which keeps the example small
    for f in os.listdir(poem_dir)[:1]:
        if f.endswith('.json'):
            with open(os.path.join(poem_dir, f), encoding='utf-8') as fp:
                json_data = json.load(fp)
            for d in json_data:
                try:
                    poem = ''.join(d['paragraphs'])
                    poem = sentenceParse(poem)
                    # Enforce the length limit and convert Traditional to Simplified Chinese
                    if len(poem) <= len_limit:
                        poems.append(cht_to_chs(poem))
                except Exception:
                    # Skip entries that are malformed or fail conversion
                    continue
    return poems
from paddlenlp.transformers import BertTokenizer

bert_tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
[2021-04-22 20:51:45,286] [    INFO] - Found /home/aistudio/.paddlenlp/models/bert-base-chinese/bert-base-chinese-vocab.txt
import paddle
from paddle.io import Dataset
import numpy as np

class PoemData(Dataset):
    """
    Poetry dataset, subclassing paddle.io.Dataset
    Parameters:
        poems (list): list of poems; each element is one poem as raw, unencoded text
        tokenizer: tokenizer used to encode each poem
        max_len: maximum accepted poem length
    """
    def __init__(self, poems, tokenizer, max_len=128):
        super(PoemData, self).__init__()
        self.poems = poems
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __getitem__(self, idx):
        line = self.poems[idx]
        token_line = self.tokenizer.encode(line)
        token, token_type = token_line['input_ids'], token_line['token_type_ids']
        if len(token) > self.max_len + 1:
            # Truncate but keep the final token; slicing keeps both operands as lists
            token = token[:self.max_len] + token[-1:]
            token_type = token_type[:self.max_len] + token_type[-1:]
        # Inputs drop the last token; labels are the same sequence shifted by one
        input_token, input_token_type = token[:-1], token_type[:-1]
        label_token = np.array((token[1:] + [0] * self.max_len)[:self.max_len], dtype='int64')
        # Pad the inputs
        input_token = np.array((input_token + [0] * self.max_len)[:self.max_len], dtype='int64')
        input_token_type = np.array((input_token_type + [0] * self.max_len)[:self.max_len], dtype='int64')
        input_pad_mask = (input_token != 0).astype('float32')
        return input_token, input_token_type, input_pad_mask, label_token, input_pad_mask

    def __len__(self):
        return len(self.poems)
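
The input/label construction in `__getitem__` is the usual next-token setup: the label at each position is the token that follows it. The sketch below replays that logic on plain lists so the shift and padding are easy to see. `make_example` is a hypothetical helper written for this illustration, and the ids are made up (101 and 102 stand in for the start and end tokens).

```python
def make_example(token_ids, max_len, pad_id=0):
    """Turn one tokenized poem into (inputs, labels, mask) for next-token prediction."""
    if len(token_ids) > max_len + 1:
        # Keep the first max_len tokens plus the final (end) token
        token_ids = token_ids[:max_len] + token_ids[-1:]
    inputs = token_ids[:-1]  # everything except the last token
    labels = token_ids[1:]   # the same sequence shifted left by one

    def pad(seq):
        return (seq + [pad_id] * max_len)[:max_len]

    inputs, labels = pad(inputs), pad(labels)
    mask = [1.0 if t != pad_id else 0.0 for t in inputs]
    return inputs, labels, mask


inputs, labels, mask = make_example([101, 7, 8, 9, 102], max_len=6)
print(inputs)  # [101, 7, 8, 9, 0, 0]
print(labels)  # [7, 8, 9, 102, 0, 0]
print(mask)    # [1.0, 1.0, 1.0, 1.0, 0.0, 0.0]
```

The mask marks the real (non-padding) input positions; it is used both to build the attention mask and to zero out the loss on padding.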

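Building the models

Below, 'T' stands for Transformer and 'P' for Performer. A poetry generation model is built with each so that the results can be compared.

BertP is the BERT model built on the Performer-based Transformer. Because Performer does not change the structure of the attention module, the pretrained model's parameters can still be used.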

from paddle.nn import Layer, Linear, Softmax
from paddlenlp.transformers import BertModel as BertT
from paddlenlp.transformers import BertForTokenClassification as BertClassT

class PoetryBertModelT(Layer):
    """
    Poetry generation model based on a pretrained BERT model
    """
    def __init__(self, pretrained_bert_model: str, input_length: int):
        super(PoetryBertModelT, self).__init__()
        bert_model = BertT.from_pretrained(pretrained_bert_model)
        self.vocab_size, self.hidden_size = bert_model.embeddings.word_embeddings.parameters()[0].shape
        self.bert_for_class = BertClassT(bert_model, self.vocab_size)
        # Build a lower-triangular matrix that masks out the later positions of the sentence
        self.sequence_length = input_length
        self.lower_triangle_mask = paddle.tril(paddle.tensor.full((input_length, input_length), 1, 'float32'))

    def forward(self, token, token_type, input_mask, input_length=None):
        # Compute the attention mask
        mask_left = paddle.reshape(input_mask, input_mask.shape + [1])
        mask_right = paddle.reshape(input_mask, [input_mask.shape[0], 1, input_mask.shape[1]])
        # Valid positions in the input sentence
        mask_left = paddle.cast(mask_left, 'float32')
        mask_right = paddle.cast(mask_right, 'float32')
        attention_mask = paddle.matmul(mask_left, mask_right)
        # Valid positions in the attention computation
        if input_length is not None:
            lower_triangle_mask = paddle.tril(paddle.tensor.full((input_length, input_length), 1, 'float32'))
        else:
            lower_triangle_mask = self.lower_triangle_mask
        attention_mask = attention_mask * lower_triangle_mask
        # Set invalid positions to a very large negative value
        attention_mask = (1 - paddle.unsqueeze(attention_mask, axis=[1])) * -1e10
        attention_mask = paddle.cast(attention_mask, self.bert_for_class.parameters()[0].dtype)

        output_logits = self.bert_for_class(token, token_type_ids=token_type, attention_mask=attention_mask)

        return output_logits
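
The attention mask built in `forward` combines two conditions: both positions must be real tokens (the outer product of the padding mask with itself), and a position may only attend to itself and earlier positions (the lower-triangular matrix). The sketch below reproduces that logic for a single sequence with plain lists. `build_attention_mask` is a hypothetical helper for illustration and leaves out the batch and head dimensions.

```python
def build_attention_mask(pad_mask, neg=-1e10):
    """Additive mask for one sequence: 0 where attention is allowed, `neg` elsewhere."""
    n = len(pad_mask)
    mask = []
    for i in range(n):      # query position
        row = []
        for j in range(n):  # key position
            allowed = pad_mask[i] * pad_mask[j] * (1.0 if j <= i else 0.0)
            row.append((1.0 - allowed) * neg)
        mask.append(row)
    return mask


# Three real tokens followed by one padding position
for row in build_attention_mask([1.0, 1.0, 1.0, 0.0]):
    print(["ok" if x == 0 else "masked" for x in row])
# ['ok', 'masked', 'masked', 'masked']
# ['ok', 'ok', 'masked', 'masked']
# ['ok', 'ok', 'ok', 'masked']
# ['masked', 'masked', 'masked', 'masked']
```

The last row belongs to a padding position, so every entry is masked. Its output carries no information, and the loss mask removes it from training.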
from paddle.nn import Layer, Linear, Softmax
from Performers.paddlepaddle.bert.modeling import BertModel as BertP
from Performers.paddlepaddle.bert.modeling import BertForTokenClassification as BertClassP

class PoetryBertModelP(Layer):
    """
    Poetry generation model based on a pretrained BERT model (Performer version)
    """
    def __init__(self, pretrained_bert_model: str, input_length: int):
        super(PoetryBertModelP, self).__init__()
        bert_model = BertP.from_pretrained(pretrained_bert_model)
        self.vocab_size, self.hidden_size = bert_model.embeddings.word_embeddings.parameters()[0].shape
        self.bert_for_class = BertClassP(bert_model, self.vocab_size)
        # Build a lower-triangular matrix that masks out the later positions of the sentence
        self.sequence_length = input_length
        self.lower_triangle_mask = paddle.tril(paddle.tensor.full((input_length, input_length), 1, 'float32'))

    def forward(self, token, token_type, input_mask, input_length=None):
        # Compute the attention mask
        mask_left = paddle.reshape(input_mask, input_mask.shape + [1])
        mask_right = paddle.reshape(input_mask, [input_mask.shape[0], 1, input_mask.shape[1]])
        # Valid positions in the input sentence
        mask_left = paddle.cast(mask_left, 'float32')
        mask_right = paddle.cast(mask_right, 'float32')
        attention_mask = paddle.matmul(mask_left, mask_right)
        # Valid positions in the attention computation
        if input_length is not None:
            lower_triangle_mask = paddle.tril(paddle.tensor.full((input_length, input_length), 1, 'float32'))
        else:
            lower_triangle_mask = self.lower_triangle_mask
        attention_mask = attention_mask * lower_triangle_mask
        # Set invalid positions to a very large negative value
        attention_mask = (1 - paddle.unsqueeze(attention_mask, axis=[1])) * -1e10
        attention_mask = paddle.cast(attention_mask, self.bert_for_class.parameters()[0].dtype)

        output_logits = self.bert_for_class(token, token_type_ids=token_type, attention_mask=attention_mask)

        return output_logits
class PoetryBertModelLossCriterion(Layer):
    def forward(self, pred_logits, label, input_mask):
        loss = paddle.nn.functional.cross_entropy(pred_logits, label, ignore_index=0, reduction='none')
        masked_loss = paddle.mean(loss * input_mask, axis=0)
        return paddle.sum(masked_loss)
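
The reduction in the loss criterion can be checked by hand on a small case. The sketch below mirrors it with plain lists, assuming the per-token losses and the mask both have shape [batch][length]: multiply by the mask, average over the batch axis, then sum over positions. `masked_loss` and the numbers are made up for illustration.

```python
def masked_loss(token_losses, mask):
    """Mask per-token losses, average over the batch, then sum over positions."""
    batch = len(token_losses)
    length = len(token_losses[0])
    per_position = [
        sum(token_losses[b][t] * mask[b][t] for b in range(batch)) / batch
        for t in range(length)
    ]
    return sum(per_position)


losses = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
mask = [[1.0, 1.0, 0.0], [1.0, 0.0, 0.0]]
print(masked_loss(losses, mask))  # 3.5
```

Note that the average divides by the batch size, not by the number of unmasked tokens, so positions where most sequences are padding contribute proportionally less to the total.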

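Training and testing

The two poetry generation models are trained separately below so that the differences in training behavior can be observed.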

from paddle.static import InputSpec
from paddlenlp.metrics import Perplexity
from paddle.optimizer import AdamW

length = 1024

net_t = PoetryBertModelT('bert-base-chinese', length)
net_p = PoetryBertModelP('bert-base-chinese', length)

token_ids = InputSpec((-1, length), 'int64', 'token')
token_type_ids = InputSpec((-1, length), 'int64', 'token_type')
input_mask = InputSpec((-1, length), 'float32', 'input_mask')
label = InputSpec((-1, length), 'int64', 'label')

inputs = [token_ids, token_type_ids, input_mask]
labels = [label, input_mask]

model_t = paddle.Model(net_t, inputs, labels)
model_t.prepare(optimizer=AdamW(learning_rate=0.0001, parameters=model_t.parameters()), loss=PoetryBertModelLossCriterion(), metrics=[Perplexity()])
model_t.summary(inputs, [input.dtype for input in inputs])

model_p = paddle.Model(net_p, inputs, labels)
model_p.prepare(optimizer=AdamW(learning_rate=0.0001, parameters=model_p.parameters()), loss=PoetryBertModelLossCriterion(), metrics=[Perplexity()])
model_p.summary(inputs, [input.dtype for input in inputs])

[2021-04-22 20:51:45,354] [    INFO] - Already cached /home/aistudio/.paddlenlp/models/bert-base-chinese/bert-base-chinese.pdparams
[2021-04-22 20:51:49,461] [    INFO] - Weights from pretrained model not used in BertModel: ['cls.predictions.decoder_weight', 'cls.predictions.decoder_bias', 'cls.predictions.transform.weight', 'cls.predictions.transform.bias', 'cls.predictions.layer_norm.weight', 'cls.predictions.layer_norm.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']
[2021-04-22 20:51:50,310] [    INFO] - Already cached /home/aistudio/.paddlenlp/models/bert-base-chinese/bert-base-chinese.pdparams
[2021-04-22 20:51:54,397] [    INFO] - Weights from pretrained model not used in BertModel: ['cls.predictions.decoder_weight', 'cls.predictions.decoder_bias', 'cls.predictions.transform.weight', 'cls.predictions.transform.bias', 'cls.predictions.layer_norm.weight', 'cls.predictions.layer_norm.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']


----------------------------------------------------------------------------------------------------------------------------------------------
        Layer (type)                                     Input Shape                                     Output Shape            Param #    
==============================================================================================================================================
        Embedding-1                                      [[1, 1024]]                                    [1, 1024, 768]         16,226,304   
        Embedding-2                                      [[1, 1024]]                                    [1, 1024, 768]           393,216    
        Embedding-3                                      [[1, 1024]]                                    [1, 1024, 768]            1,536     
        LayerNorm-1                                    [[1, 1024, 768]]                                 [1, 1024, 768]            1,536     
         Dropout-1                                     [[1, 1024, 768]]                                 [1, 1024, 768]              0       
      BertEmbeddings-1                                        []                                        [1, 1024, 768]              0       
          Linear-1                                     [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
          Linear-2                                     [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
          Linear-3                                     [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
          Linear-4                                     [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
    MultiHeadAttention-1     [[1, 1024, 768], [1, 1024, 768], [1, 1024, 768], [1, 1, 1024, 1024]]       [1, 1024, 768]              0       
         Dropout-3                                     [[1, 1024, 768]]                                 [1, 1024, 768]              0       
        LayerNorm-2                                    [[1, 1024, 768]]                                 [1, 1024, 768]            1,536     
          Linear-5                                     [[1, 1024, 768]]                                [1, 1024, 3072]          2,362,368   
         Dropout-2                                    [[1, 1024, 3072]]                                [1, 1024, 3072]              0       
          Linear-6                                    [[1, 1024, 3072]]                                 [1, 1024, 768]          2,360,064   
         Dropout-4                                     [[1, 1024, 768]]                                 [1, 1024, 768]              0       
        LayerNorm-3                                    [[1, 1024, 768]]                                 [1, 1024, 768]            1,536     
 TransformerEncoderLayer-1                             [[1, 1024, 768]]                                 [1, 1024, 768]              0       
          Linear-7                                     [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
          Linear-8                                     [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
          Linear-9                                     [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
         Linear-10                                     [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
    MultiHeadAttention-2     [[1, 1024, 768], [1, 1024, 768], [1, 1024, 768], [1, 1, 1024, 1024]]       [1, 1024, 768]              0       
         Dropout-6                                     [[1, 1024, 768]]                                 [1, 1024, 768]              0       
        LayerNorm-4                                    [[1, 1024, 768]]                                 [1, 1024, 768]            1,536     
         Linear-11                                     [[1, 1024, 768]]                                [1, 1024, 3072]          2,362,368   
         Dropout-5                                    [[1, 1024, 3072]]                                [1, 1024, 3072]              0       
         Linear-12                                    [[1, 1024, 3072]]                                 [1, 1024, 768]          2,360,064   
         Dropout-7                                     [[1, 1024, 768]]                                 [1, 1024, 768]              0       
        LayerNorm-5                                    [[1, 1024, 768]]                                 [1, 1024, 768]            1,536     
 TransformerEncoderLayer-2                             [[1, 1024, 768]]                                 [1, 1024, 768]              0       
         Linear-13                                     [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
         Linear-14                                     [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
         Linear-15                                     [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
         Linear-16                                     [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
    MultiHeadAttention-3     [[1, 1024, 768], [1, 1024, 768], [1, 1024, 768], [1, 1, 1024, 1024]]       [1, 1024, 768]              0       
         Dropout-9                                     [[1, 1024, 768]]                                 [1, 1024, 768]              0       
        LayerNorm-6                                    [[1, 1024, 768]]                                 [1, 1024, 768]            1,536     
         Linear-17                                     [[1, 1024, 768]]                                [1, 1024, 3072]          2,362,368   
         Dropout-8                                    [[1, 1024, 3072]]                                [1, 1024, 3072]              0       
         Linear-18                                    [[1, 1024, 3072]]                                 [1, 1024, 768]          2,360,064   
         Dropout-10                                    [[1, 1024, 768]]                                 [1, 1024, 768]              0       
        LayerNorm-7                                    [[1, 1024, 768]]                                 [1, 1024, 768]            1,536     
 TransformerEncoderLayer-3                             [[1, 1024, 768]]                                 [1, 1024, 768]              0       
         Linear-19                                     [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
         Linear-20                                     [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
         Linear-21                                     [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
         Linear-22                                     [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
    MultiHeadAttention-4     [[1, 1024, 768], [1, 1024, 768], [1, 1024, 768], [1, 1, 1024, 1024]]       [1, 1024, 768]              0       
         Dropout-12                                    [[1, 1024, 768]]                                 [1, 1024, 768]              0       
        LayerNorm-8                                    [[1, 1024, 768]]                                 [1, 1024, 768]            1,536     
         Linear-23                                     [[1, 1024, 768]]                                [1, 1024, 3072]          2,362,368   
         Dropout-11                                   [[1, 1024, 3072]]                                [1, 1024, 3072]              0       
         Linear-24                                    [[1, 1024, 3072]]                                 [1, 1024, 768]          2,360,064   
         Dropout-13                                    [[1, 1024, 768]]                                 [1, 1024, 768]              0       
        LayerNorm-9                                    [[1, 1024, 768]]                                 [1, 1024, 768]            1,536     
 TransformerEncoderLayer-4                             [[1, 1024, 768]]                                 [1, 1024, 768]              0       
         Linear-25                                     [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
         Linear-26                                     [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
         Linear-27                                     [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
         Linear-28                                     [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
    MultiHeadAttention-5     [[1, 1024, 768], [1, 1024, 768], [1, 1024, 768], [1, 1, 1024, 1024]]       [1, 1024, 768]              0       
         Dropout-15                                    [[1, 1024, 768]]                                 [1, 1024, 768]              0       
        LayerNorm-10                                   [[1, 1024, 768]]                                 [1, 1024, 768]            1,536     
         Linear-29                                     [[1, 1024, 768]]                                [1, 1024, 3072]          2,362,368   
         Dropout-14                                   [[1, 1024, 3072]]                                [1, 1024, 3072]              0       
         Linear-30                                    [[1, 1024, 3072]]                                 [1, 1024, 768]          2,360,064   
         Dropout-16                                    [[1, 1024, 768]]                                 [1, 1024, 768]              0       
        LayerNorm-11                                   [[1, 1024, 768]]                                 [1, 1024, 768]            1,536     
 TransformerEncoderLayer-5                             [[1, 1024, 768]]                                 [1, 1024, 768]              0       
         Linear-31                                     [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
         Linear-32                                     [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
         Linear-33                                     [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
         Linear-34                                     [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
    MultiHeadAttention-6     [[1, 1024, 768], [1, 1024, 768], [1, 1024, 768], [1, 1, 1024, 1024]]       [1, 1024, 768]              0       
         Dropout-18                                    [[1, 1024, 768]]                                 [1, 1024, 768]              0       
        LayerNorm-12                                   [[1, 1024, 768]]                                 [1, 1024, 768]            1,536     
         Linear-35                                     [[1, 1024, 768]]                                [1, 1024, 3072]          2,362,368   
         Dropout-17                                   [[1, 1024, 3072]]                                [1, 1024, 3072]              0       
         Linear-36                                    [[1, 1024, 3072]]                                 [1, 1024, 768]          2,360,064   
         Dropout-19                                    [[1, 1024, 768]]                                 [1, 1024, 768]              0       
        LayerNorm-13                                   [[1, 1024, 768]]                                 [1, 1024, 768]            1,536     
 TransformerEncoderLayer-6                             [[1, 1024, 768]]                                 [1, 1024, 768]              0       
         Linear-37                                     [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
         Linear-38                                     [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
         Linear-39                                     [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
         Linear-40                                     [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
    MultiHeadAttention-7     [[1, 1024, 768], [1, 1024, 768], [1, 1024, 768], [1, 1, 1024, 1024]]       [1, 1024, 768]              0       
         Dropout-21                                    [[1, 1024, 768]]                                 [1, 1024, 768]              0       
        LayerNorm-14                                   [[1, 1024, 768]]                                 [1, 1024, 768]            1,536     
         Linear-41                                     [[1, 1024, 768]]                                [1, 1024, 3072]          2,362,368   
         Dropout-20                                   [[1, 1024, 3072]]                                [1, 1024, 3072]              0       
         Linear-42                                    [[1, 1024, 3072]]                                 [1, 1024, 768]          2,360,064   
         Dropout-22                                    [[1, 1024, 768]]                                 [1, 1024, 768]              0       
        LayerNorm-15                                   [[1, 1024, 768]]                                 [1, 1024, 768]            1,536     
 TransformerEncoderLayer-7                             [[1, 1024, 768]]                                 [1, 1024, 768]              0       
         Linear-43                                     [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
         Linear-44                                     [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
         Linear-45                                     [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
         Linear-46                                     [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
    MultiHeadAttention-8     [[1, 1024, 768], [1, 1024, 768], [1, 1024, 768], [1, 1, 1024, 1024]]       [1, 1024, 768]              0       
         Dropout-24                                    [[1, 1024, 768]]                                 [1, 1024, 768]              0       
        LayerNorm-16                                   [[1, 1024, 768]]                                 [1, 1024, 768]            1,536     
         Linear-47                                     [[1, 1024, 768]]                                [1, 1024, 3072]          2,362,368   
         Dropout-23                                   [[1, 1024, 3072]]                                [1, 1024, 3072]              0       
         Linear-48                                    [[1, 1024, 3072]]                                 [1, 1024, 768]          2,360,064   
         Dropout-25                                    [[1, 1024, 768]]                                 [1, 1024, 768]              0       
        LayerNorm-17                                   [[1, 1024, 768]]                                 [1, 1024, 768]            1,536     
 TransformerEncoderLayer-8                             [[1, 1024, 768]]                                 [1, 1024, 768]              0       
         Linear-49                                     [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
         Linear-50                                     [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
         Linear-51                                     [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
         Linear-52                                     [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
    MultiHeadAttention-9     [[1, 1024, 768], [1, 1024, 768], [1, 1024, 768], [1, 1, 1024, 1024]]       [1, 1024, 768]              0       
         Dropout-27                                    [[1, 1024, 768]]                                 [1, 1024, 768]              0       
        LayerNorm-18                                   [[1, 1024, 768]]                                 [1, 1024, 768]            1,536     
         Linear-53                                     [[1, 1024, 768]]                                [1, 1024, 3072]          2,362,368   
         Dropout-26                                   [[1, 1024, 3072]]                                [1, 1024, 3072]              0       
         Linear-54                                    [[1, 1024, 3072]]                                 [1, 1024, 768]          2,360,064   
         Dropout-28                                    [[1, 1024, 768]]                                 [1, 1024, 768]              0       
        LayerNorm-19                                   [[1, 1024, 768]]                                 [1, 1024, 768]            1,536     
 TransformerEncoderLayer-9                             [[1, 1024, 768]]                                 [1, 1024, 768]              0       
         Linear-55                                     [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
         Linear-56                                     [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
         Linear-57                                     [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
         Linear-58                                     [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
   MultiHeadAttention-10     [[1, 1024, 768], [1, 1024, 768], [1, 1024, 768], [1, 1, 1024, 1024]]       [1, 1024, 768]              0       
         Dropout-30                                    [[1, 1024, 768]]                                 [1, 1024, 768]              0       
        LayerNorm-20                                   [[1, 1024, 768]]                                 [1, 1024, 768]            1,536     
         Linear-59                                     [[1, 1024, 768]]                                [1, 1024, 3072]          2,362,368   
         Dropout-29                                   [[1, 1024, 3072]]                                [1, 1024, 3072]              0       
         Linear-60                                    [[1, 1024, 3072]]                                 [1, 1024, 768]          2,360,064   
         Dropout-31                                    [[1, 1024, 768]]                                 [1, 1024, 768]              0       
        LayerNorm-21                                   [[1, 1024, 768]]                                 [1, 1024, 768]            1,536     
 TransformerEncoderLayer-10                            [[1, 1024, 768]]                                 [1, 1024, 768]              0       
         Linear-61                                     [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
         Linear-62                                     [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
         Linear-63                                     [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
         Linear-64                                     [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
   MultiHeadAttention-11     [[1, 1024, 768], [1, 1024, 768], [1, 1024, 768], [1, 1, 1024, 1024]]       [1, 1024, 768]              0       
         Dropout-33                                    [[1, 1024, 768]]                                 [1, 1024, 768]              0       
        LayerNorm-22                                   [[1, 1024, 768]]                                 [1, 1024, 768]            1,536     
         Linear-65                                     [[1, 1024, 768]]                                [1, 1024, 3072]          2,362,368   
         Dropout-32                                   [[1, 1024, 3072]]                                [1, 1024, 3072]              0       
         Linear-66                                    [[1, 1024, 3072]]                                 [1, 1024, 768]          2,360,064   
         Dropout-34                                    [[1, 1024, 768]]                                 [1, 1024, 768]              0       
        LayerNorm-23                                   [[1, 1024, 768]]                                 [1, 1024, 768]            1,536     
 TransformerEncoderLayer-11                            [[1, 1024, 768]]                                 [1, 1024, 768]              0       
         Linear-67                                     [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
         Linear-68                                     [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
         Linear-69                                     [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
         Linear-70                                     [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
   MultiHeadAttention-12     [[1, 1024, 768], [1, 1024, 768], [1, 1024, 768], [1, 1, 1024, 1024]]       [1, 1024, 768]              0       
         Dropout-36                                    [[1, 1024, 768]]                                 [1, 1024, 768]              0       
        LayerNorm-24                                   [[1, 1024, 768]]                                 [1, 1024, 768]            1,536     
         Linear-71                                     [[1, 1024, 768]]                                [1, 1024, 3072]          2,362,368   
         Dropout-35                                   [[1, 1024, 3072]]                                [1, 1024, 3072]              0       
         Linear-72                                    [[1, 1024, 3072]]                                 [1, 1024, 768]          2,360,064   
         Dropout-37                                    [[1, 1024, 768]]                                 [1, 1024, 768]              0       
        LayerNorm-25                                   [[1, 1024, 768]]                                 [1, 1024, 768]            1,536     
 TransformerEncoderLayer-12                            [[1, 1024, 768]]                                 [1, 1024, 768]              0       
    TransformerEncoder-1                     [[1, 1024, 768], [1, 1, 1024, 1024]]                       [1, 1024, 768]              0       
         Linear-73                                        [[1, 768]]                                       [1, 768]              590,592    
           Tanh-2                                         [[1, 768]]                                       [1, 768]                 0       
        BertPooler-1                                   [[1, 1024, 768]]                                    [1, 768]                 0       
        BertModel-1                                      [[1, 1024]]                              [[1, 1024, 768], [1, 768]]        0       
         Dropout-38                                    [[1, 1024, 768]]                                 [1, 1024, 768]              0       
         Linear-74                                     [[1, 1024, 768]]                                [1, 1024, 21128]        16,247,432   
BertForTokenClassification-1                             [[1, 1024]]                                   [1, 1024, 21128]             0       
==============================================================================================================================================
Total params: 118,515,080
Trainable params: 118,515,080
Non-trainable params: 0
----------------------------------------------------------------------------------------------------------------------------------------------
Input size (MB): 0.01
Forward/backward pass size (MB): 1752.15
Params size (MB): 452.10
Estimated Total Size (MB): 2204.26
----------------------------------------------------------------------------------------------------------------------------------------------
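
The size figures in the summary footer can be sanity-checked by hand. The sketch below assumes each parameter is stored as a 4-byte float32 and that 1 MB means 2**20 bytes. Both are assumptions about how the summary was produced, not something the output itself states. The input and forward/backward sizes are simply copied from the footer.

```python
# Sanity check of the summary footer.
# Assumptions: float32 parameters (4 bytes each), 1 MB = 2**20 bytes.
total_params = 118_515_080      # "Total params" from the summary
bytes_per_param = 4             # assumption: float32

params_mb = total_params * bytes_per_param / 2**20

# These two figures are copied from the summary as-is.
input_mb = 0.01
forward_backward_mb = 1752.15

estimated_total_mb = input_mb + forward_backward_mb + round(params_mb, 2)

print(f"Params size (MB): {params_mb:.2f}")
print(f"Estimated Total Size (MB): {estimated_total_mb:.2f}")
```

Under these assumptions the two printed values agree with the `Params size (MB)` and `Estimated Total Size (MB)` lines above. Activations (the forward/backward pass), not the weights, make up most of the estimate at sequence length 1024.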

----------------------------------------------------------------------------------------------------------------------------------------------
        Layer (type)                                     Input Shape                                     Output Shape            Param #    
==============================================================================================================================================
        Embedding-4                                      [[1, 1024]]                                    [1, 1024, 768]         16,226,304   
        Embedding-5                                      [[1, 1024]]                                    [1, 1024, 768]           393,216    
        Embedding-6                                      [[1, 1024]]                                    [1, 1024, 768]            1,536     
        LayerNorm-26                                   [[1, 1024, 768]]                                 [1, 1024, 768]            1,536     
         Dropout-39                                    [[1, 1024, 768]]                                 [1, 1024, 768]              0       
      BertEmbeddings-2                                        []                                        [1, 1024, 768]              0       
         Linear-75                                     [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
         Linear-76                                     [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
         Linear-77                                     [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
         Linear-78                                     [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
   MultiHeadAttention-13     [[1, 1024, 768], [1, 1024, 768], [1, 1024, 768], [1, 1, 1024, 1024]]       [1, 1024, 768]              0       
         Dropout-41                                    [[1, 1024, 768]]                                 [1, 1024, 768]              0       
        LayerNorm-27                                   [[1, 1024, 768]]                                 [1, 1024, 768]            1,536     
         Linear-79                                     [[1, 1024, 768]]                                [1, 1024, 3072]          2,362,368   
         Dropout-40                                   [[1, 1024, 3072]]                                [1, 1024, 3072]              0       
         Linear-80                                    [[1, 1024, 3072]]                                 [1, 1024, 768]          2,360,064   
         Dropout-42                                    [[1, 1024, 768]]                                 [1, 1024, 768]              0       
        LayerNorm-28                                   [[1, 1024, 768]]                                 [1, 1024, 768]            1,536     
 TransformerEncoderLayer-13                            [[1, 1024, 768]]                                 [1, 1024, 768]              0       
         Linear-81                                     [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
         Linear-82                                     [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
         Linear-83                                     [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
         Linear-84                                     [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
   MultiHeadAttention-14     [[1, 1024, 768], [1, 1024, 768], [1, 1024, 768], [1, 1, 1024, 1024]]       [1, 1024, 768]              0       
         Dropout-44                                    [[1, 1024, 768]]                                 [1, 1024, 768]              0       
        LayerNorm-29                                   [[1, 1024, 768]]                                 [1, 1024, 768]            1,536     
         Linear-85                                     [[1, 1024, 768]]                                [1, 1024, 3072]          2,362,368   
         Dropout-43                                   [[1, 1024, 3072]]                                [1, 1024, 3072]              0       
         Linear-86                                    [[1, 1024, 3072]]                                 [1, 1024, 768]          2,360,064   
         Dropout-45                                    [[1, 1024, 768]]                                 [1, 1024, 768]              0       
        LayerNorm-30                                   [[1, 1024, 768]]                                 [1, 1024, 768]            1,536     
 TransformerEncoderLayer-14                            [[1, 1024, 768]]                                 [1, 1024, 768]              0       
         Linear-87                                     [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
         Linear-88                                     [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
         Linear-89                                     [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
         Linear-90                                     [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
   MultiHeadAttention-15     [[1, 1024, 768], [1, 1024, 768], [1, 1024, 768], [1, 1, 1024, 1024]]       [1, 1024, 768]              0       
         Dropout-47                                    [[1, 1024, 768]]                                 [1, 1024, 768]              0       
        LayerNorm-31                                   [[1, 1024, 768]]                                 [1, 1024, 768]            1,536     
         Linear-91                                     [[1, 1024, 768]]                                [1, 1024, 3072]          2,362,368   
         Dropout-46                                   [[1, 1024, 3072]]                                [1, 1024, 3072]              0       
         Linear-92                                    [[1, 1024, 3072]]                                 [1, 1024, 768]          2,360,064   
         Dropout-48                                    [[1, 1024, 768]]                                 [1, 1024, 768]              0       
        LayerNorm-32                                   [[1, 1024, 768]]                                 [1, 1024, 768]            1,536     
 TransformerEncoderLayer-15                            [[1, 1024, 768]]                                 [1, 1024, 768]              0       
         Linear-93                                     [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
         Linear-94                                     [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
         Linear-95                                     [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
         Linear-96                                     [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
   MultiHeadAttention-16     [[1, 1024, 768], [1, 1024, 768], [1, 1024, 768], [1, 1, 1024, 1024]]       [1, 1024, 768]              0       
         Dropout-50                                    [[1, 1024, 768]]                                 [1, 1024, 768]              0       
        LayerNorm-33                                   [[1, 1024, 768]]                                 [1, 1024, 768]            1,536     
         Linear-97                                     [[1, 1024, 768]]                                [1, 1024, 3072]          2,362,368   
         Dropout-49                                   [[1, 1024, 3072]]                                [1, 1024, 3072]              0       
         Linear-98                                    [[1, 1024, 3072]]                                 [1, 1024, 768]          2,360,064   
         Dropout-51                                    [[1, 1024, 768]]                                 [1, 1024, 768]              0       
        LayerNorm-34                                   [[1, 1024, 768]]                                 [1, 1024, 768]            1,536     
 TransformerEncoderLayer-16                            [[1, 1024, 768]]                                 [1, 1024, 768]              0       
         Linear-99                                     [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
         Linear-100                                    [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
         Linear-101                                    [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
         Linear-102                                    [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
   MultiHeadAttention-17     [[1, 1024, 768], [1, 1024, 768], [1, 1024, 768], [1, 1, 1024, 1024]]       [1, 1024, 768]              0       
         Dropout-53                                    [[1, 1024, 768]]                                 [1, 1024, 768]              0       
        LayerNorm-35                                   [[1, 1024, 768]]                                 [1, 1024, 768]            1,536     
         Linear-103                                    [[1, 1024, 768]]                                [1, 1024, 3072]          2,362,368   
         Dropout-52                                   [[1, 1024, 3072]]                                [1, 1024, 3072]              0       
         Linear-104                                   [[1, 1024, 3072]]                                 [1, 1024, 768]          2,360,064   
         Dropout-54                                    [[1, 1024, 768]]                                 [1, 1024, 768]              0       
        LayerNorm-36                                   [[1, 1024, 768]]                                 [1, 1024, 768]            1,536     
 TransformerEncoderLayer-17                            [[1, 1024, 768]]                                 [1, 1024, 768]              0       
         Linear-105                                    [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
         Linear-106                                    [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
         Linear-107                                    [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
         Linear-108                                    [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
   MultiHeadAttention-18     [[1, 1024, 768], [1, 1024, 768], [1, 1024, 768], [1, 1, 1024, 1024]]       [1, 1024, 768]              0       
         Dropout-56                                    [[1, 1024, 768]]                                 [1, 1024, 768]              0       
        LayerNorm-37                                   [[1, 1024, 768]]                                 [1, 1024, 768]            1,536     
         Linear-109                                    [[1, 1024, 768]]                                [1, 1024, 3072]          2,362,368   
         Dropout-55                                   [[1, 1024, 3072]]                                [1, 1024, 3072]              0       
         Linear-110                                   [[1, 1024, 3072]]                                 [1, 1024, 768]          2,360,064   
         Dropout-57                                    [[1, 1024, 768]]                                 [1, 1024, 768]              0       
        LayerNorm-38                                   [[1, 1024, 768]]                                 [1, 1024, 768]            1,536     
 TransformerEncoderLayer-18                            [[1, 1024, 768]]                                 [1, 1024, 768]              0       
         Linear-111                                    [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
         Linear-112                                    [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
         Linear-113                                    [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
         Linear-114                                    [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
   MultiHeadAttention-19     [[1, 1024, 768], [1, 1024, 768], [1, 1024, 768], [1, 1, 1024, 1024]]       [1, 1024, 768]              0       
         Dropout-59                                    [[1, 1024, 768]]                                 [1, 1024, 768]              0       
        LayerNorm-39                                   [[1, 1024, 768]]                                 [1, 1024, 768]            1,536     
         Linear-115                                    [[1, 1024, 768]]                                [1, 1024, 3072]          2,362,368   
         Dropout-58                                   [[1, 1024, 3072]]                                [1, 1024, 3072]              0       
         Linear-116                                   [[1, 1024, 3072]]                                 [1, 1024, 768]          2,360,064   
         Dropout-60                                    [[1, 1024, 768]]                                 [1, 1024, 768]              0       
        LayerNorm-40                                   [[1, 1024, 768]]                                 [1, 1024, 768]            1,536     
 TransformerEncoderLayer-19                            [[1, 1024, 768]]                                 [1, 1024, 768]              0       
         Linear-117                                    [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
         Linear-118                                    [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
         Linear-119                                    [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
         Linear-120                                    [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
   MultiHeadAttention-20     [[1, 1024, 768], [1, 1024, 768], [1, 1024, 768], [1, 1, 1024, 1024]]       [1, 1024, 768]              0       
         Dropout-62                                    [[1, 1024, 768]]                                 [1, 1024, 768]              0       
        LayerNorm-41                                   [[1, 1024, 768]]                                 [1, 1024, 768]            1,536     
         Linear-121                                    [[1, 1024, 768]]                                [1, 1024, 3072]          2,362,368   
         Dropout-61                                   [[1, 1024, 3072]]                                [1, 1024, 3072]              0       
         Linear-122                                   [[1, 1024, 3072]]                                 [1, 1024, 768]          2,360,064   
         Dropout-63                                    [[1, 1024, 768]]                                 [1, 1024, 768]              0       
        LayerNorm-42                                   [[1, 1024, 768]]                                 [1, 1024, 768]            1,536     
 TransformerEncoderLayer-20                            [[1, 1024, 768]]                                 [1, 1024, 768]              0       
         Linear-123                                    [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
         Linear-124                                    [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
         Linear-125                                    [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
         Linear-126                                    [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
   MultiHeadAttention-21     [[1, 1024, 768], [1, 1024, 768], [1, 1024, 768], [1, 1, 1024, 1024]]       [1, 1024, 768]              0       
         Dropout-65                                    [[1, 1024, 768]]                                 [1, 1024, 768]              0       
        LayerNorm-43                                   [[1, 1024, 768]]                                 [1, 1024, 768]            1,536     
         Linear-127                                    [[1, 1024, 768]]                                [1, 1024, 3072]          2,362,368   
         Dropout-64                                   [[1, 1024, 3072]]                                [1, 1024, 3072]              0       
         Linear-128                                   [[1, 1024, 3072]]                                 [1, 1024, 768]          2,360,064   
         Dropout-66                                    [[1, 1024, 768]]                                 [1, 1024, 768]              0       
        LayerNorm-44                                   [[1, 1024, 768]]                                 [1, 1024, 768]            1,536     
 TransformerEncoderLayer-21                            [[1, 1024, 768]]                                 [1, 1024, 768]              0       
         Linear-129                                    [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
         Linear-130                                    [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
         Linear-131                                    [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
         Linear-132                                    [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
   MultiHeadAttention-22     [[1, 1024, 768], [1, 1024, 768], [1, 1024, 768], [1, 1, 1024, 1024]]       [1, 1024, 768]              0       
         Dropout-68                                    [[1, 1024, 768]]                                 [1, 1024, 768]              0       
        LayerNorm-45                                   [[1, 1024, 768]]                                 [1, 1024, 768]            1,536     
         Linear-133                                    [[1, 1024, 768]]                                [1, 1024, 3072]          2,362,368   
         Dropout-67                                   [[1, 1024, 3072]]                                [1, 1024, 3072]              0       
         Linear-134                                   [[1, 1024, 3072]]                                 [1, 1024, 768]          2,360,064   
         Dropout-69                                    [[1, 1024, 768]]                                 [1, 1024, 768]              0       
        LayerNorm-46                                   [[1, 1024, 768]]                                 [1, 1024, 768]            1,536     
 TransformerEncoderLayer-22                            [[1, 1024, 768]]                                 [1, 1024, 768]              0       
         Linear-135                                    [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
         Linear-136                                    [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
         Linear-137                                    [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
         Linear-138                                    [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
   MultiHeadAttention-23     [[1, 1024, 768], [1, 1024, 768], [1, 1024, 768], [1, 1, 1024, 1024]]       [1, 1024, 768]              0       
         Dropout-71                                    [[1, 1024, 768]]                                 [1, 1024, 768]              0       
        LayerNorm-47                                   [[1, 1024, 768]]                                 [1, 1024, 768]            1,536     
         Linear-139                                    [[1, 1024, 768]]                                [1, 1024, 3072]          2,362,368   
         Dropout-70                                   [[1, 1024, 3072]]                                [1, 1024, 3072]              0       
         Linear-140                                   [[1, 1024, 3072]]                                 [1, 1024, 768]          2,360,064   
         Dropout-72                                    [[1, 1024, 768]]                                 [1, 1024, 768]              0       
        LayerNorm-48                                   [[1, 1024, 768]]                                 [1, 1024, 768]            1,536     
 TransformerEncoderLayer-23                            [[1, 1024, 768]]                                 [1, 1024, 768]              0       
         Linear-141                                    [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
         Linear-142                                    [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
         Linear-143                                    [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
         Linear-144                                    [[1, 1024, 768]]                                 [1, 1024, 768]           590,592    
   MultiHeadAttention-24     [[1, 1024, 768], [1, 1024, 768], [1, 1024, 768], [1, 1, 1024, 1024]]       [1, 1024, 768]              0       
         Dropout-74                                    [[1, 1024, 768]]                                 [1, 1024, 768]              0       
        LayerNorm-49                                   [[1, 1024, 768]]                                 [1, 1024, 768]            1,536     
         Linear-145                                    [[1, 1024, 768]]                                [1, 1024, 3072]          2,362,368   
         Dropout-73                                   [[1, 1024, 3072]]                                [1, 1024, 3072]              0       
         Linear-146                                   [[1, 1024, 3072]]                                 [1, 1024, 768]          2,360,064   
         Dropout-75                                    [[1, 1024, 768]]                                 [1, 1024, 768]              0       
        LayerNorm-50                                   [[1, 1024, 768]]                                 [1, 1024, 768]            1,536     
 TransformerEncoderLayer-24                            [[1, 1024, 768]]                                 [1, 1024, 768]              0       
    TransformerEncoder-2                     [[1, 1024, 768], [1, 1, 1024, 1024]]                       [1, 1024, 768]              0       
         Linear-147                                       [[1, 768]]                                       [1, 768]              590,592    
           Tanh-3                                         [[1, 768]]                                       [1, 768]                 0       
        BertPooler-2                                   [[1, 1024, 768]]                                    [1, 768]                 0       
        BertModel-2                                      [[1, 1024]]                              [[1, 1024, 768], [1, 768]]        0       
         Dropout-76                                    [[1, 1024, 768]]                                 [1, 1024, 768]              0       
         Linear-148                                    [[1, 1024, 768]]                                [1, 1024, 21128]        16,247,432   
BertForTokenClassification-2                             [[1, 1024]]                                   [1, 1024, 21128]             0       
==============================================================================================================================================
Total params: 118,515,080
Trainable params: 118,515,080
Non-trainable params: 0
----------------------------------------------------------------------------------------------------------------------------------------------
Input size (MB): 0.01
Forward/backward pass size (MB): 1752.15
Params size (MB): 452.10
Estimated Total Size (MB): 2204.26
----------------------------------------------------------------------------------------------------------------------------------------------

{'total_params': 118515080, 'trainable_params': 118515080}
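
As a cross-check, the total parameter count can be rebuilt from the layer shapes listed in the table. This is a minimal sketch: the helper `linear` and the variable names are illustrative only, and the three embedding sizes are copied from the table rather than derived.

```python
# Recompute the parameter count from the layer shapes shown in the summary.
hidden, ffn, vocab, num_layers = 768, 3072, 21128, 12


def linear(n_in, n_out):
    """Weights plus bias of a fully connected layer."""
    return n_in * n_out + n_out


layer_norm = 2 * hidden  # scale and shift, 1,536 parameters

# Three embedding tables (copied from the summary) plus one LayerNorm.
embeddings = 16_226_304 + 393_216 + 1_536 + layer_norm

encoder_layer = (
    4 * linear(hidden, hidden)   # four 768 -> 768 Linear layers in the attention block
    + linear(hidden, ffn)        # feed-forward expansion, 768 -> 3072
    + linear(ffn, hidden)        # feed-forward contraction, 3072 -> 768
    + 2 * layer_norm             # two LayerNorms per encoder layer
)

pooler = linear(hidden, hidden)
classifier = linear(hidden, vocab)  # final 768 -> 21128 Linear layer

total = embeddings + num_layers * encoder_layer + pooler + classifier
print(f"{total:,}")
```

The result matches the `total_params` value reported above. Each of the 12 encoder layers contributes 7,087,872 parameters.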

```python
from paddle.io import DataLoader

# Preprocess the poem corpus, limiting sample length to `length`
poems = data_preprocess(len_limit=length)

# Batch the poems and train model_t for one epoch
train_loader = DataLoader(PoemData(poems, bert_tokenizer, length), batch_size=2, shuffle=True)
model_t.fit(train_data=train_loader, epochs=1, verbose=2)
The loss value printed in the log is the current step, and the metric is the average value of previous step.
Epoch 1/1


/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:77: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  return (isinstance(seq, collections.Sequence) and


step  10/500 - loss: 725.6817 - Perplexity: 10427.6664 - 245ms/step
step  20/500 - loss: 555.3558 - Perplexity: 7645.9100 - 238ms/step
step  30/500 - loss: 1065.7946 - Perplexity: 5479.9408 - 236ms/step
step  40/500 - loss: 397.7573 - Perplexity: 4748.2296 - 234ms/step
step  50/500 - loss: 740.5175 - Perplexity: 3929.7151 - 233ms/step
step  60/500 - loss: 565.6876 - Perplexity: 3378.1735 - 232ms/step
step  70/500 - loss: 346.3937 - Perplexity: 2928.5249 - 232ms/step
step  80/500 - loss: 429.1893 - Perplexity: 2516.5488 - 232ms/step
step  90/500 - loss: 289.0322 - Perplexity: 2337.0202 - 232ms/step
step 100/500 - loss: 226.7606 - Perplexity: 2125.5245 - 231ms/step
step 110/500 - loss: 1059.4875 - Perplexity: 1981.2700 - 231ms/step
step 120/500 - loss: 779.4556 - Perplexity: 1861.8594 - 231ms/step
step 130/500 - loss: 1293.4988 - Perplexity: 1777.7791 - 231ms/step
step 140/500 - loss: 365.6456 - Perplexity: 1721.5329 - 232ms/step
step 150/500 - loss: 329.3266 - Perplexity: 1633.8337 - 232ms/step
step 160/500 - loss: 394.9833 - Perplexity: 1568.7457 - 232ms/step
step 170/500 - loss: 760.9999 - Perplexity: 1511.2424 - 232ms/step
step 180/500 - loss: 319.8988 - Perplexity: 1451.3553 - 232ms/step
step 190/500 - loss: 457.5365 - Perplexity: 1433.0332 - 232ms/step
step 200/500 - loss: 322.3470 - Perplexity: 1399.3860 - 232ms/step
step 210/500 - loss: 290.7716 - Perplexity: 1369.4990 - 232ms/step
step 220/500 - loss: 243.7075 - Perplexity: 1340.5879 - 232ms/step
step 230/500 - loss: 425.2493 - Perplexity: 1315.3565 - 232ms/step
step 240/500 - loss: 300.5534 - Perplexity: 1284.2336 - 232ms/step
step 250/500 - loss: 334.7778 - Perplexity: 1256.7779 - 232ms/step
step 260/500 - loss: 449.5439 - Perplexity: 1233.9944 - 232ms/step
step 270/500 - loss: 273.1763 - Perplexity: 1212.2996 - 232ms/step
step 280/500 - loss: 521.5925 - Perplexity: 1196.3253 - 232ms/step
step 290/500 - loss: 833.1059 - Perplexity: 1181.9895 - 232ms/step
step 300/500 - loss: 1103.6512 - Perplexity: 1165.6223 - 232ms/step
step 310/500 - loss: 1067.5687 - Perplexity: 1170.9803 - 232ms/step
step 320/500 - loss: 232.1698 - Perplexity: 1154.2350 - 232ms/step
step 330/500 - loss: 446.2109 - Perplexity: 1142.6015 - 232ms/step
step 340/500 - loss: 406.2852 - Perplexity: 1150.5179 - 232ms/step
step 350/500 - loss: 355.0574 - Perplexity: 1138.8899 - 232ms/step
step 360/500 - loss: 443.1240 - Perplexity: 1130.0989 - 232ms/step
step 370/500 - loss: 450.0056 - Perplexity: 1125.7253 - 232ms/step
step 380/500 - loss: 209.2993 - Perplexity: 1116.8284 - 232ms/step
step 390/500 - loss: 260.5238 - Perplexity: 1103.2239 - 232ms/step
step 400/500 - loss: 326.5078 - Perplexity: 1094.5374 - 232ms/step
step 410/500 - loss: 613.3599 - Perplexity: 1085.8666 - 232ms/step
step 420/500 - loss: 311.0659 - Perplexity: 1077.0101 - 232ms/step
step 430/500 - loss: 552.8693 - Perplexity: 1078.9429 - 232ms/step
step 440/500 - loss: 349.0563 - Perplexity: 1070.4293 - 232ms/step
step 450/500 - loss: 322.9454 - Perplexity: 1066.5452 - 232ms/step
step 460/500 - loss: 576.1926 - Perplexity: 1059.5172 - 232ms/step
step 470/500 - loss: 316.3662 - Perplexity: 1055.9816 - 232ms/step
step 480/500 - loss: 270.8025 - Perplexity: 1048.0514 - 232ms/step
step 490/500 - loss: 829.6791 - Perplexity: 1039.9843 - 232ms/step
step 500/500 - loss: 495.9052 - Perplexity: 1034.8373 - 232ms/step

# Train the Performer-based Bert (model_p) on the same data loader
model_p.fit(train_data=train_loader, epochs=1, verbose=2)
The loss value printed in the log is the current step, and the metric is the average value of previous step.
Epoch 1/1
step  10/500 - loss: 551.1293 - Perplexity: 9490.1874 - 196ms/step
step  20/500 - loss: 952.8583 - Perplexity: 6732.4827 - 195ms/step
step  30/500 - loss: 501.8406 - Perplexity: 5721.9359 - 194ms/step
step  40/500 - loss: 244.4778 - Perplexity: 4413.8827 - 194ms/step
step  50/500 - loss: 650.4541 - Perplexity: 3721.9919 - 194ms/step
step  60/500 - loss: 756.1465 - Perplexity: 2999.6214 - 197ms/step
step  70/500 - loss: 273.4387 - Perplexity: 2500.5908 - 196ms/step
step  80/500 - loss: 704.1033 - Perplexity: 2296.5336 - 196ms/step
step  90/500 - loss: 400.3583 - Perplexity: 2160.8528 - 196ms/step
step 100/500 - loss: 216.9302 - Perplexity: 1938.2109 - 196ms/step
step 110/500 - loss: 170.1186 - Perplexity: 1837.5074 - 196ms/step
step 120/500 - loss: 323.2706 - Perplexity: 1732.3691 - 196ms/step
step 130/500 - loss: 466.0241 - Perplexity: 1634.3989 - 196ms/step
step 140/500 - loss: 297.2783 - Perplexity: 1570.9105 - 196ms/step
step 150/500 - loss: 418.7701 - Perplexity: 1513.4602 - 197ms/step
step 160/500 - loss: 342.0039 - Perplexity: 1468.8108 - 196ms/step
step 170/500 - loss: 463.7025 - Perplexity: 1428.2714 - 196ms/step
step 180/500 - loss: 1011.8646 - Perplexity: 1380.5736 - 196ms/step
step 190/500 - loss: 221.5213 - Perplexity: 1352.1810 - 196ms/step
step 200/500 - loss: 1533.4623 - Perplexity: 1317.0389 - 197ms/step
step 210/500 - loss: 336.1730 - Perplexity: 1287.4186 - 197ms/step
step 220/500 - loss: 431.8987 - Perplexity: 1265.3528 - 197ms/step
step 230/500 - loss: 536.3519 - Perplexity: 1248.9458 - 197ms/step
step 240/500 - loss: 1302.3909 - Perplexity: 1223.9964 - 197ms/step
step 250/500 - loss: 223.1607 - Perplexity: 1210.7413 - 198ms/step
step 260/500 - loss: 873.6428 - Perplexity: 1196.4195 - 201ms/step
step 270/500 - loss: 301.4405 - Perplexity: 1182.2666 - 205ms/step
step 280/500 - loss: 326.8961 - Perplexity: 1167.6853 - 208ms/step
step 290/500 - loss: 209.5225 - Perplexity: 1157.2032 - 211ms/step
step 300/500 - loss: 966.4078 - Perplexity: 1144.7167 - 215ms/step
step 310/500 - loss: 436.6230 - Perplexity: 1132.1078 - 216ms/step
step 320/500 - loss: 223.2067 - Perplexity: 1123.4431 - 218ms/step
step 330/500 - loss: 716.7662 - Perplexity: 1114.5825 - 219ms/step
step 340/500 - loss: 227.7101 - Perplexity: 1106.4880 - 221ms/step
step 350/500 - loss: 421.9973 - Perplexity: 1096.3944 - 222ms/step
step 360/500 - loss: 321.5515 - Perplexity: 1080.4534 - 222ms/step
step 370/500 - loss: 994.4940 - Perplexity: 1071.2373 - 221ms/step
step 380/500 - loss: 1063.0651 - Perplexity: 1064.0193 - 221ms/step
step 390/500 - loss: 375.4459 - Perplexity: 1058.5955 - 220ms/step
step 400/500 - loss: 341.6979 - Perplexity: 1048.0488 - 219ms/step
step 410/500 - loss: 528.9408 - Perplexity: 1039.0551 - 219ms/step
step 420/500 - loss: 420.7586 - Perplexity: 1033.6232 - 218ms/step
step 430/500 - loss: 260.0015 - Perplexity: 1028.3416 - 218ms/step
step 440/500 - loss: 453.0840 - Perplexity: 1022.2022 - 217ms/step
step 450/500 - loss: 461.1561 - Perplexity: 1018.8781 - 216ms/step
step 460/500 - loss: 586.7026 - Perplexity: 1012.4488 - 216ms/step
step 470/500 - loss: 557.7255 - Perplexity: 1007.6774 - 216ms/step
step 480/500 - loss: 213.4639 - Perplexity: 1002.1562 - 215ms/step
step 490/500 - loss: 493.5714 - Perplexity: 991.8984 - 215ms/step
step 500/500 - loss: 312.0637 - Perplexity: 986.7728 - 214ms/step

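These logs show that the Performer-based Bert trains faster at sequence length 1024: it finishes the epoch at about 214 ms/step, against 232 ms/step for the standard Bert, roughly an 8% reduction in step time. The Performer run held 194-197 ms/step for the first 250 steps and peaked at 222 ms/step later in the epoch. Training quality after one epoch is on par: the running perplexity at step 500 is 986.77 for the Performer version and 1034.84 for the standard version.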
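The comparison between the two runs can be reproduced from the final log line of each. The snippet below is a small sketch using only the Python standard library; the helper name `parse_step` and the regular expression are illustrative assumptions about the log format shown in the output, not part of Paddle or of this project.

```python
import re

# Pattern for lines such as:
# "step 500/500 - loss: 495.9052 - Perplexity: 1034.8373 - 232ms/step"
STEP_LINE = re.compile(
    r"step\s+(\d+)/(\d+) - loss: ([\d.]+) - Perplexity: ([\d.]+) - (\d+)ms/step"
)


def parse_step(line):
    """Parse one training log line into a dict (hypothetical helper)."""
    match = STEP_LINE.search(line)
    if match is None:
        raise ValueError(f"not a training log line: {line!r}")
    step, total, loss, perplexity, ms = match.groups()
    return {
        "step": int(step),
        "total": int(total),
        "loss": float(loss),
        "perplexity": float(perplexity),
        "ms_per_step": int(ms),
    }


# Final log lines copied from the two runs (model_t first, then model_p).
transformer_last = parse_step(
    "step 500/500 - loss: 495.9052 - Perplexity: 1034.8373 - 232ms/step"
)
performer_last = parse_step(
    "step 500/500 - loss: 312.0637 - Perplexity: 986.7728 - 214ms/step"
)

speedup = transformer_last["ms_per_step"] / performer_last["ms_per_step"]
ppl_gap = transformer_last["perplexity"] - performer_last["perplexity"]
print(f"speed-up at step 500: {speedup:.2f}x")
print(f"perplexity gap (Transformer - Performer): {ppl_gap:.2f}")
```

The end-to-end gain is much smaller than the speed-up ratio measured for a single attention block at length 1024, which is expected because the full training step also includes the feed-forward layers, the output projection, and the backward pass.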

Summary

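This project implements Performer on the Paddle framework and tests it with a Performer-based Bert. In practice, training is faster while training quality is preserved. Note that the project does not yet provide a solution for masks, nor does it handle causal sequences (that is, the case where the sequence is processed in one direction only). Both will be addressed later; if you are interested, follow the project for updates.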
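To make the two missing pieces concrete, here is a minimal pure-Python sketch of how causal and padded inputs are commonly handled in linear attention. It is not this project's code and not the FAVOR implementation: it assumes queries and keys have already been mapped to positive features (random positive numbers stand in for them here), and the names `causal_linear_attention`, `causal_quadratic_reference`, and `mask_padded_keys` are hypothetical. The causal version keeps running sums over positions `j <= i`, and the check against an explicit lower-triangular computation shows that both give the same output.

```python
import random


def causal_linear_attention(q_feat, k_feat, values):
    """Causal attention in O(L) using running sums over positive features."""
    m, d = len(k_feat[0]), len(values[0])
    kv_sum = [[0.0] * d for _ in range(m)]  # running sum of outer(k_j, v_j)
    k_sum = [0.0] * m                       # running sum of k_j
    outputs = []
    for q, k, v in zip(q_feat, k_feat, values):
        # Fold position i into the prefix sums before reading them (j <= i).
        for a in range(m):
            k_sum[a] += k[a]
            for b in range(d):
                kv_sum[a][b] += k[a] * v[b]
        denom = sum(q[a] * k_sum[a] for a in range(m))
        outputs.append(
            [sum(q[a] * kv_sum[a][b] for a in range(m)) / denom for b in range(d)]
        )
    return outputs


def causal_quadratic_reference(q_feat, k_feat, values):
    """Same result computed the O(L^2) way, one row of weights per query."""
    d = len(values[0])
    outputs = []
    for i, q in enumerate(q_feat):
        weights = [sum(x * y for x, y in zip(q, k_feat[j])) for j in range(i + 1)]
        denom = sum(weights)
        outputs.append(
            [sum(w * values[j][b] for j, w in enumerate(weights)) / denom
             for b in range(d)]
        )
    return outputs


def mask_padded_keys(k_feat, is_valid):
    """Zero the features of padded keys so they add nothing to either sum."""
    return [k if ok else [0.0] * len(k) for k, ok in zip(k_feat, is_valid)]


random.seed(0)
seq_len, m, d = 6, 4, 3
q_feat = [[random.random() + 0.1 for _ in range(m)] for _ in range(seq_len)]
k_feat = [[random.random() + 0.1 for _ in range(m)] for _ in range(seq_len)]
values = [[random.uniform(-1, 1) for _ in range(d)] for _ in range(seq_len)]

fast = causal_linear_attention(q_feat, k_feat, values)
slow = causal_quadratic_reference(q_feat, k_feat, values)
max_err = max(abs(x - y) for f, s in zip(fast, slow) for x, y in zip(f, s))
print("max abs difference:", max_err)

# Padding example: the last two positions are treated as padding.
is_valid = [True] * 4 + [False] * 2
masked = causal_linear_attention(q_feat, mask_padded_keys(k_feat, is_valid), values)
```

A tensor implementation would replace the Python loops with cumulative sums along the sequence axis; the loops here only make the prefix-sum structure explicit.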
