A New Transformer Architecture: Performer
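Transformer makes full use of the attention mechanism, and models built on it (BERT, GPT, PLATO, ERNIE, etc.) have been highly successful in NLP. However, a Transformer has to build an attention matrix whose side length equals the input sequence length, so its memory and compute grow quadratically with sequence length, which makes long sequences expensive to process. Performer uses the FAVOR+ algorithm (Fast Attention Via positive Orthogonal Random features) proposed by Google to provide a scalable, low-variance, unbiased estimate of the attention mechanism while keeping space and time complexity close to linear.
Interested readers can refer to the original paper. This project provides a Paddle implementation of the architecture (the original paper comes with source code).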
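To make the idea concrete, here is a minimal NumPy sketch of softmax attention approximated with positive random features. It is not the code in fast_attention.py: the function names, the feature count, the input sizes, and the use of plain Gaussian projections (rather than orthogonal ones) are all assumptions chosen to keep the example short.

```python
import numpy as np

def positive_random_features(x, projection):
    """Map rows of x to positive features so that phi(q) . phi(k) estimates exp(q . k)."""
    # x: (length, dim); projection: (num_features, dim) with rows drawn from N(0, I)
    num_features = projection.shape[0]
    projected = x @ projection.T                       # (length, num_features)
    half_sq_norm = 0.5 * np.sum(x ** 2, axis=-1, keepdims=True)
    return np.exp(projected - half_sq_norm) / np.sqrt(num_features)

def softmax_attention(q, k, v):
    """Standard attention: builds the full (length, length) score matrix."""
    scores = np.exp(q @ k.T / np.sqrt(q.shape[-1]))
    return (scores / scores.sum(axis=-1, keepdims=True)) @ v

def linear_attention(q, k, v, projection):
    """Approximate attention without ever forming a (length, length) matrix."""
    scale = q.shape[-1] ** -0.25                       # splits the 1/sqrt(dim) factor between q and k
    q_prime = positive_random_features(q * scale, projection)
    k_prime = positive_random_features(k * scale, projection)
    kv = k_prime.T @ v                                 # (num_features, dim_v)
    normalizer = q_prime @ k_prime.sum(axis=0)         # (length,)
    return (q_prime @ kv) / normalizer[:, None]

# Small demo with made-up sizes (assumptions, not values from the project).
rng = np.random.default_rng(0)
q = 0.3 * rng.standard_normal((16, 8))
k = 0.3 * rng.standard_normal((16, 8))
v = rng.standard_normal((16, 8))
projection = rng.standard_normal((4096, 8))

exact = softmax_attention(q, k, v)
approx = linear_attention(q, k, v, projection)
print("max abs difference:", np.max(np.abs(exact - approx)))
```

Because `k_prime.T @ v` is computed first, the cost grows with length times the number of random features instead of with the square of the length, which is where the near-linear complexity comes from.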
Code and tests
Unzip Performers.zip to get the Paddle code. The code is organized as follows:
-----Performers
-------paddlepaddle // Paddle code
----------fast_attention.py // core Performer code
----------fast_attention_test.py // simple test code
----------transformer.py // Transformer implemented on top of Performer
----------bert // BERT model implemented on top of Performer
-------tf2 // original Performer source code
```python
!unzip -o Performers.zip
```
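As shown below, calling fast_attention_test.py lets us see how much faster Performer is than the standard Transformer. The speed-up is reported as the speed up ratio: the ratio of the time needed to run a single attention block (time_transformer / time_performer). The results show that once the sequence length exceeds 1000, Performer gives a substantial speed-up: about 3.2x at length 1024, rising to about 28x at 8192. Below 1000 there is no gain. The ratio is 0.91 at length 512 and 0.15 at length 128, so on short sequences Performer is actually slower than the Transformer. Whether this model is worth using for acceleration therefore depends on the actual workload, and in practice sequences longer than 1000 are not that common.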
from Performers.paddlepaddle import fast_attention_test
import prettytable

test = fast_attention_test.TransformerLayersTest()
seq_lengths = [128, 256, 512, 1024, 2048, 4096, 8192]
ratios = []
for seq_len in seq_lengths:
    # Collect the speed up ratio measured for one attention block at this length
    ratios.append(test.test_softmax_noncausal_attention_block_output(seq_len))

table = prettytable.PrettyTable()
table.field_names = ['length'] + [str(seq_len) for seq_len in seq_lengths]
table.add_row(['speed up ratio'] + ['%.5f' % r for r in ratios])
print(table)
+----------------+---------+---------+---------+---------+---------+----------+----------+
| length | 128 | 256 | 512 | 1024 | 2048 | 4096 | 8192 |
+----------------+---------+---------+---------+---------+---------+----------+----------+
| speed up ratio | 0.15373 | 0.35420 | 0.91381 | 3.22711 | 7.37718 | 14.36226 | 28.28287 |
+----------------+---------+---------+---------+---------+---------+----------+----------+
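If standard attention grows quadratically with length and Performer grows roughly linearly, the speed up ratio should about double each time the length doubles. The snippet below is only arithmetic on the values reported in the table above, with no new measurements; the variable names are illustrative.

```python
# Values copied from the benchmark table above (time_transformer / time_performer).
lengths = [128, 256, 512, 1024, 2048, 4096, 8192]
speedups = [0.15373, 0.35420, 0.91381, 3.22711, 7.37718, 14.36226, 28.28287]

# Shortest benchmarked length at which Performer is faster (ratio above 1).
break_even = next(n for n, s in zip(lengths, speedups) if s > 1.0)

# How much the ratio grows each time the length doubles.
growth = [later / earlier for earlier, later in zip(speedups, speedups[1:])]

print("break-even length:", break_even)
for n, g in zip(lengths[1:], growth):
    print(f"{n // 2:>5} -> {n:<5} ratio grows by {g:.2f}x")
```

For the two largest doublings the growth factor is close to 2, which is consistent with a quadratic cost divided by a roughly linear one. At shorter lengths the factor is less regular, presumably because fixed overheads still dominate.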
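Example
The following uses an earlier project, automatic poetry writing with BERT, to show how to use a Performer-based model.
Data preparation
# Download the poetry dataset (cloning from the mirror github.com.cnpmjs.org can speed up the download)
!git clone https://github.com.cnpmjs.org/chinese-poetry/chinese-poetry
# Download the traditional-to-simplified Chinese conversion tool
!git clone https://github.com.cnpmjs.org/fiyen/cht2chs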
fatal: destination path 'chinese-poetry' already exists and is not an empty directory.
fatal: destination path 'cht2chs' already exists and is not an empty directory.
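import os
import json
import re
from cht2chs.langconv import cht_to_chs


def sentenceParse(para):
    """
    Strip non-text symbols and digits from a poem string.
    """
    # Remove annotations in parentheses (full-width or ASCII), braces, and book-title marks
    result, number = re.subn(r"[((].*[))]", "", para)
    result, number = re.subn(r"\{.*\}", "", result)
    result, number = re.subn(r"《.*》", "", result)
    # Remove square brackets (raw string avoids the invalid-escape warning)
    result, number = re.subn(r"[\[\]]", "", result)
    # Drop digits and hyphens
    r = ""
    for s in result:
        if s not in set('0123456789-'):
            r += s
    # Collapse a doubled full stop into one
    r, number = re.subn(r"。。", "。", r)
    return r


def data_preprocess(poem_dir='./chinese-poetry/json', len_limit=120):
    """
    Preprocess the poem data and return the list of poems that meet the length limit.
    """
    poems = []
    # Only the first entry of the directory listing is processed
    for f in os.listdir(poem_dir)[:1]:
        if f.endswith('.json'):
            with open(os.path.join(poem_dir, f), encoding='utf-8') as fp:
                json_data = json.load(fp)
            for d in json_data:
                try:
                    poem = ''.join(d['paragraphs'])
                    poem = sentenceParse(poem)
                    # Enforce the length limit and convert traditional to simplified characters
                    if len(poem) <= len_limit:
                        poems.append(cht_to_chs(poem))
                except Exception:
                    # Skip entries that cannot be parsed or converted
                    continue
    return poems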
from paddlenlp.transformers import BertTokenizer
bert_tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
[2021-04-22 20:51:45,286] [ INFO] - Found /home/aistudio/.paddlenlp/models/bert-base-chinese/bert-base-chinese-vocab.txt
import paddle
from paddle.io import Dataset
import numpy as np


class PoemData(Dataset):
    """
    Poem dataset built on paddle.io.Dataset.
    Parameters:
        poems (list): list of poems, one un-encoded poem string per element
        tokenizer: tokenizer used to encode each poem
        max_len: maximum accepted poem length
    """
    def __init__(self, poems, tokenizer, max_len=128):
        super(PoemData, self).__init__()
        self.poems = poems
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __getitem__(self, idx):
        # Read from self.poems rather than a global variable
        line = self.poems[idx]
        token_line = self.tokenizer.encode(line)
        token, token_type = token_line['input_ids'], token_line['token_type_ids']
        if len(token) > self.max_len + 1:
            # Truncate but keep the final token; slice with [-1:] so both operands are lists
            token = token[:self.max_len] + token[-1:]
            token_type = token_type[:self.max_len] + token_type[-1:]
        # Inputs drop the last token; labels are the same sequence shifted left by one
        input_token, input_token_type = token[:-1], token_type[:-1]
        label_token = np.array((token[1:] + [0] * self.max_len)[:self.max_len], dtype='int64')
        # Pad the inputs to max_len
        input_token = np.array((input_token + [0] * self.max_len)[:self.max_len], dtype='int64')
        input_token_type = np.array((input_token_type + [0] * self.max_len)[:self.max_len], dtype='int64')
        input_pad_mask = (input_token != 0).astype('float32')
        # The mask is returned twice: once as a model input and once for the loss
        return input_token, input_token_type, input_pad_mask, label_token, input_pad_mask

    def __len__(self):
        return len(self.poems)
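The way `__getitem__` turns one encoded poem into an input, a label, and a padding mask can be seen on a toy sequence. This is a plain-Python sketch with a hypothetical helper `make_example` and made-up token ids; it does not call the tokenizer.

```python
def make_example(token_ids, max_len, pad_id=0):
    """Build (input, label, mask) for next-token prediction from one encoded poem."""
    input_ids = token_ids[:-1]              # drop the last token
    label_ids = token_ids[1:]               # targets are the inputs shifted left by one
    pad = [pad_id] * max_len
    input_ids = (input_ids + pad)[:max_len]
    label_ids = (label_ids + pad)[:max_len]
    mask = [1.0 if t != pad_id else 0.0 for t in input_ids]
    return input_ids, label_ids, mask

# Illustrative ids only: 101 and 102 stand for the start and end markers,
# 11, 12, 13 stand for three characters of a poem.
inputs, labels, mask = make_example([101, 11, 12, 13, 102], max_len=6)
print(inputs)
print(labels)
print(mask)
```

At every position the label is the token that follows the input token, so the model learns to predict the next character.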
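Building the models
Below, 'T' stands for Transformer and 'P' for Performer. A poetry generation model is built with each so that the results can be compared.
BertP is the BERT model built on the Performer-based Transformer. Because Performer does not change the structure of the attention model, the pretrained model's parameters can still be used (both models report 118,515,080 parameters in the summaries below).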
from paddle.nn import Layer, Linear, Softmax
from paddlenlp.transformers import BertModel as BertT
from paddlenlp.transformers import BertForTokenClassification as BertClassT


class PoetryBertModelT(Layer):
    """
    Poetry generation model based on the pretrained BERT model (standard Transformer attention)
    """
    def __init__(self, pretrained_bert_model: str, input_length: int):
        super(PoetryBertModelT, self).__init__()
        bert_model = BertT.from_pretrained(pretrained_bert_model)
        self.vocab_size, self.hidden_size = bert_model.embeddings.word_embeddings.parameters()[0].shape
        self.bert_for_class = BertClassT(bert_model, self.vocab_size)
        # Lower-triangular matrix used to mask the tokens that come later in the sentence
        self.sequence_length = input_length
        self.lower_triangle_mask = paddle.tril(paddle.tensor.full((input_length, input_length), 1, 'float32'))

    def forward(self, token, token_type, input_mask, input_length=None):
        # Build the attention mask
        mask_left = paddle.reshape(input_mask, input_mask.shape + [1])
        mask_right = paddle.reshape(input_mask, [input_mask.shape[0], 1, input_mask.shape[1]])
        # Valid (non-padding) positions of the input sentence
        mask_left = paddle.cast(mask_left, 'float32')
        mask_right = paddle.cast(mask_right, 'float32')
        attention_mask = paddle.matmul(mask_left, mask_right)
        # Positions that are valid in the attention computation
        if input_length is not None:
            lower_triangle_mask = paddle.tril(paddle.tensor.full((input_length, input_length), 1, 'float32'))
        else:
            lower_triangle_mask = self.lower_triangle_mask
        attention_mask = attention_mask * lower_triangle_mask
        # Set invalid positions to a very large negative value
        attention_mask = (1 - paddle.unsqueeze(attention_mask, axis=[1])) * -1e10
        attention_mask = paddle.cast(attention_mask, self.bert_for_class.parameters()[0].dtype)
        output_logits = self.bert_for_class(token, token_type_ids=token_type, attention_mask=attention_mask)
        return output_logits
from paddle.nn import Layer, Linear, Softmax
from Performers.paddlepaddle.bert.modeling import BertModel as BertP
from Performers.paddlepaddle.bert.modeling import BertForTokenClassification as BertClassP


class PoetryBertModelP(Layer):
    """
    Poetry generation model based on the pretrained BERT model (Performer attention)
    """
    def __init__(self, pretrained_bert_model: str, input_length: int):
        super(PoetryBertModelP, self).__init__()
        bert_model = BertP.from_pretrained(pretrained_bert_model)
        self.vocab_size, self.hidden_size = bert_model.embeddings.word_embeddings.parameters()[0].shape
        self.bert_for_class = BertClassP(bert_model, self.vocab_size)
        # Lower-triangular matrix used to mask the tokens that come later in the sentence
        self.sequence_length = input_length
        self.lower_triangle_mask = paddle.tril(paddle.tensor.full((input_length, input_length), 1, 'float32'))

    def forward(self, token, token_type, input_mask, input_length=None):
        # Build the attention mask
        mask_left = paddle.reshape(input_mask, input_mask.shape + [1])
        mask_right = paddle.reshape(input_mask, [input_mask.shape[0], 1, input_mask.shape[1]])
        # Valid (non-padding) positions of the input sentence
        mask_left = paddle.cast(mask_left, 'float32')
        mask_right = paddle.cast(mask_right, 'float32')
        attention_mask = paddle.matmul(mask_left, mask_right)
        # Positions that are valid in the attention computation
        if input_length is not None:
            lower_triangle_mask = paddle.tril(paddle.tensor.full((input_length, input_length), 1, 'float32'))
        else:
            lower_triangle_mask = self.lower_triangle_mask
        attention_mask = attention_mask * lower_triangle_mask
        # Set invalid positions to a very large negative value
        attention_mask = (1 - paddle.unsqueeze(attention_mask, axis=[1])) * -1e10
        attention_mask = paddle.cast(attention_mask, self.bert_for_class.parameters()[0].dtype)
        output_logits = self.bert_for_class(token, token_type_ids=token_type, attention_mask=attention_mask)
        return output_logits
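The attention mask that both `forward` methods build combines a padding mask with a lower-triangular (causal) mask. The NumPy sketch below reproduces that construction on a tiny input so the result can be inspected; the helper name `causal_padding_mask` is hypothetical and the code is an illustration, not part of the project.

```python
import numpy as np

def causal_padding_mask(pad_mask, neg=-1e10):
    """pad_mask: (batch, length) of 1/0. Returns an additive mask of shape (batch, 1, length, length)."""
    pad_mask = np.asarray(pad_mask, dtype="float32")
    # Outer product: entry (i, j) is 1 only if both position i and position j are real tokens
    valid = pad_mask[:, :, None] * pad_mask[:, None, :]
    length = pad_mask.shape[1]
    # Lower triangle: position i may only attend to positions j <= i
    causal = np.tril(np.ones((length, length), dtype="float32"))
    keep = valid * causal
    # 0 where attention is allowed, a very large negative number where it is not
    return (1.0 - keep[:, None, :, :]) * neg

# One sequence of length 4 whose last position is padding
mask = causal_padding_mask([[1, 1, 1, 0]])
print(mask[0, 0])
```

Adding this mask to the attention scores before the softmax drives the weight of future and padded positions to essentially zero.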
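class PoetryBertModelLossCriterion(Layer):
    def forward(self, pred_logits, label, input_mask):
        # Per-token cross entropy; positions whose label is 0 (padding) are ignored
        loss = paddle.nn.functional.cross_entropy(pred_logits, label, ignore_index=0, reduction='none')
        # Zero out padded positions, average over the batch axis, then sum over positions
        masked_loss = paddle.mean(loss * input_mask, axis=0)
        return paddle.sum(masked_loss)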
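The reduction used by the loss criterion (mask, mean over the batch, sum over positions) can be checked by hand on a small case. The NumPy sketch below mirrors that reduction under the assumption that the per-token loss has shape (batch, length); `masked_token_loss` is a hypothetical helper, not a Paddle API.

```python
import numpy as np

def masked_token_loss(logits, labels, mask, ignore_index=0):
    """Cross entropy per token, masked, averaged over the batch axis, summed over positions."""
    logits = np.asarray(logits, dtype="float64")
    labels = np.asarray(labels)
    # Numerically stable log-softmax over the vocabulary axis
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # Negative log-probability of the target token at each position
    token_loss = -np.take_along_axis(log_probs, labels[..., None], axis=-1)[..., 0]
    # Positions labelled with ignore_index contribute nothing
    token_loss = np.where(labels == ignore_index, 0.0, token_loss)
    masked = token_loss * np.asarray(mask, dtype="float64")
    return masked.mean(axis=0).sum()

# Batch of 2, length 3, vocabulary of 4; uniform logits give log(4) per real token
logits = np.zeros((2, 3, 4))
labels = [[1, 2, 0], [3, 0, 0]]
mask = [[1, 1, 0], [1, 0, 0]]
loss = masked_token_loss(logits, labels, mask)
print(loss, 1.5 * np.log(4))
```

Because the result is a sum over positions rather than a mean, its magnitude scales with the number of real tokens in the batch, which is why the loss values in the training log are in the hundreds.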
Training and testing
The two poetry generation models are trained separately below to compare their training behavior.
from paddle.static import InputSpec
from paddlenlp.metrics import Perplexity
from paddle.optimizer import AdamW
length = 1024
net_t = PoetryBertModelT('bert-base-chinese', length)
net_p = PoetryBertModelP('bert-base-chinese', length)
token_ids = InputSpec((-1, length), 'int64', 'token')
token_type_ids = InputSpec((-1, length), 'int64', 'token_type')
input_mask = InputSpec((-1, length), 'float32', 'input_mask')
label = InputSpec((-1, length), 'int64', 'label')
inputs = [token_ids, token_type_ids, input_mask]
labels = [label, input_mask]
model_t = paddle.Model(net_t, inputs, labels)
model_t.prepare(optimizer=AdamW(learning_rate=0.0001, parameters=model_t.parameters()), loss=PoetryBertModelLossCriterion(), metrics=[Perplexity()])
model_t.summary(inputs, [input.dtype for input in inputs])
model_p = paddle.Model(net_p, inputs, labels)
model_p.prepare(optimizer=AdamW(learning_rate=0.0001, parameters=model_p.parameters()), loss=PoetryBertModelLossCriterion(), metrics=[Perplexity()])
model_p.summary(inputs, [input.dtype for input in inputs])
[2021-04-22 20:51:45,354] [ INFO] - Already cached /home/aistudio/.paddlenlp/models/bert-base-chinese/bert-base-chinese.pdparams
[2021-04-22 20:51:49,461] [ INFO] - Weights from pretrained model not used in BertModel: ['cls.predictions.decoder_weight', 'cls.predictions.decoder_bias', 'cls.predictions.transform.weight', 'cls.predictions.transform.bias', 'cls.predictions.layer_norm.weight', 'cls.predictions.layer_norm.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']
[2021-04-22 20:51:50,310] [ INFO] - Already cached /home/aistudio/.paddlenlp/models/bert-base-chinese/bert-base-chinese.pdparams
[2021-04-22 20:51:54,397] [ INFO] - Weights from pretrained model not used in BertModel: ['cls.predictions.decoder_weight', 'cls.predictions.decoder_bias', 'cls.predictions.transform.weight', 'cls.predictions.transform.bias', 'cls.predictions.layer_norm.weight', 'cls.predictions.layer_norm.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']
----------------------------------------------------------------------------------------------------------------------------------------------
Layer (type) Input Shape Output Shape Param #
==============================================================================================================================================
Embedding-1 [[1, 1024]] [1, 1024, 768] 16,226,304
Embedding-2 [[1, 1024]] [1, 1024, 768] 393,216
Embedding-3 [[1, 1024]] [1, 1024, 768] 1,536
LayerNorm-1 [[1, 1024, 768]] [1, 1024, 768] 1,536
Dropout-1 [[1, 1024, 768]] [1, 1024, 768] 0
BertEmbeddings-1 [] [1, 1024, 768] 0
Linear-1 [[1, 1024, 768]] [1, 1024, 768] 590,592
Linear-2 [[1, 1024, 768]] [1, 1024, 768] 590,592
Linear-3 [[1, 1024, 768]] [1, 1024, 768] 590,592
Linear-4 [[1, 1024, 768]] [1, 1024, 768] 590,592
MultiHeadAttention-1 [[1, 1024, 768], [1, 1024, 768], [1, 1024, 768], [1, 1, 1024, 1024]] [1, 1024, 768] 0
Dropout-3 [[1, 1024, 768]] [1, 1024, 768] 0
LayerNorm-2 [[1, 1024, 768]] [1, 1024, 768] 1,536
Linear-5 [[1, 1024, 768]] [1, 1024, 3072] 2,362,368
Dropout-2 [[1, 1024, 3072]] [1, 1024, 3072] 0
Linear-6 [[1, 1024, 3072]] [1, 1024, 768] 2,360,064
Dropout-4 [[1, 1024, 768]] [1, 1024, 768] 0
LayerNorm-3 [[1, 1024, 768]] [1, 1024, 768] 1,536
TransformerEncoderLayer-1 [[1, 1024, 768]] [1, 1024, 768] 0
Linear-7 [[1, 1024, 768]] [1, 1024, 768] 590,592
Linear-8 [[1, 1024, 768]] [1, 1024, 768] 590,592
Linear-9 [[1, 1024, 768]] [1, 1024, 768] 590,592
Linear-10 [[1, 1024, 768]] [1, 1024, 768] 590,592
MultiHeadAttention-2 [[1, 1024, 768], [1, 1024, 768], [1, 1024, 768], [1, 1, 1024, 1024]] [1, 1024, 768] 0
Dropout-6 [[1, 1024, 768]] [1, 1024, 768] 0
LayerNorm-4 [[1, 1024, 768]] [1, 1024, 768] 1,536
Linear-11 [[1, 1024, 768]] [1, 1024, 3072] 2,362,368
Dropout-5 [[1, 1024, 3072]] [1, 1024, 3072] 0
Linear-12 [[1, 1024, 3072]] [1, 1024, 768] 2,360,064
Dropout-7 [[1, 1024, 768]] [1, 1024, 768] 0
LayerNorm-5 [[1, 1024, 768]] [1, 1024, 768] 1,536
TransformerEncoderLayer-2 [[1, 1024, 768]] [1, 1024, 768] 0
Linear-13 [[1, 1024, 768]] [1, 1024, 768] 590,592
Linear-14 [[1, 1024, 768]] [1, 1024, 768] 590,592
Linear-15 [[1, 1024, 768]] [1, 1024, 768] 590,592
Linear-16 [[1, 1024, 768]] [1, 1024, 768] 590,592
MultiHeadAttention-3 [[1, 1024, 768], [1, 1024, 768], [1, 1024, 768], [1, 1, 1024, 1024]] [1, 1024, 768] 0
Dropout-9 [[1, 1024, 768]] [1, 1024, 768] 0
LayerNorm-6 [[1, 1024, 768]] [1, 1024, 768] 1,536
Linear-17 [[1, 1024, 768]] [1, 1024, 3072] 2,362,368
Dropout-8 [[1, 1024, 3072]] [1, 1024, 3072] 0
Linear-18 [[1, 1024, 3072]] [1, 1024, 768] 2,360,064
Dropout-10 [[1, 1024, 768]] [1, 1024, 768] 0
LayerNorm-7 [[1, 1024, 768]] [1, 1024, 768] 1,536
TransformerEncoderLayer-3 [[1, 1024, 768]] [1, 1024, 768] 0
Linear-19 [[1, 1024, 768]] [1, 1024, 768] 590,592
Linear-20 [[1, 1024, 768]] [1, 1024, 768] 590,592
Linear-21 [[1, 1024, 768]] [1, 1024, 768] 590,592
Linear-22 [[1, 1024, 768]] [1, 1024, 768] 590,592
MultiHeadAttention-4 [[1, 1024, 768], [1, 1024, 768], [1, 1024, 768], [1, 1, 1024, 1024]] [1, 1024, 768] 0
Dropout-12 [[1, 1024, 768]] [1, 1024, 768] 0
LayerNorm-8 [[1, 1024, 768]] [1, 1024, 768] 1,536
Linear-23 [[1, 1024, 768]] [1, 1024, 3072] 2,362,368
Dropout-11 [[1, 1024, 3072]] [1, 1024, 3072] 0
Linear-24 [[1, 1024, 3072]] [1, 1024, 768] 2,360,064
Dropout-13 [[1, 1024, 768]] [1, 1024, 768] 0
LayerNorm-9 [[1, 1024, 768]] [1, 1024, 768] 1,536
TransformerEncoderLayer-4 [[1, 1024, 768]] [1, 1024, 768] 0
Linear-25 [[1, 1024, 768]] [1, 1024, 768] 590,592
Linear-26 [[1, 1024, 768]] [1, 1024, 768] 590,592
Linear-27 [[1, 1024, 768]] [1, 1024, 768] 590,592
Linear-28 [[1, 1024, 768]] [1, 1024, 768] 590,592
MultiHeadAttention-5 [[1, 1024, 768], [1, 1024, 768], [1, 1024, 768], [1, 1, 1024, 1024]] [1, 1024, 768] 0
Dropout-15 [[1, 1024, 768]] [1, 1024, 768] 0
LayerNorm-10 [[1, 1024, 768]] [1, 1024, 768] 1,536
Linear-29 [[1, 1024, 768]] [1, 1024, 3072] 2,362,368
Dropout-14 [[1, 1024, 3072]] [1, 1024, 3072] 0
Linear-30 [[1, 1024, 3072]] [1, 1024, 768] 2,360,064
Dropout-16 [[1, 1024, 768]] [1, 1024, 768] 0
LayerNorm-11 [[1, 1024, 768]] [1, 1024, 768] 1,536
TransformerEncoderLayer-5 [[1, 1024, 768]] [1, 1024, 768] 0
Linear-31 [[1, 1024, 768]] [1, 1024, 768] 590,592
Linear-32 [[1, 1024, 768]] [1, 1024, 768] 590,592
Linear-33 [[1, 1024, 768]] [1, 1024, 768] 590,592
Linear-34 [[1, 1024, 768]] [1, 1024, 768] 590,592
MultiHeadAttention-6 [[1, 1024, 768], [1, 1024, 768], [1, 1024, 768], [1, 1, 1024, 1024]] [1, 1024, 768] 0
Dropout-18 [[1, 1024, 768]] [1, 1024, 768] 0
LayerNorm-12 [[1, 1024, 768]] [1, 1024, 768] 1,536
Linear-35 [[1, 1024, 768]] [1, 1024, 3072] 2,362,368
Dropout-17 [[1, 1024, 3072]] [1, 1024, 3072] 0
Linear-36 [[1, 1024, 3072]] [1, 1024, 768] 2,360,064
Dropout-19 [[1, 1024, 768]] [1, 1024, 768] 0
LayerNorm-13 [[1, 1024, 768]] [1, 1024, 768] 1,536
TransformerEncoderLayer-6 [[1, 1024, 768]] [1, 1024, 768] 0
Linear-37 [[1, 1024, 768]] [1, 1024, 768] 590,592
Linear-38 [[1, 1024, 768]] [1, 1024, 768] 590,592
Linear-39 [[1, 1024, 768]] [1, 1024, 768] 590,592
Linear-40 [[1, 1024, 768]] [1, 1024, 768] 590,592
MultiHeadAttention-7 [[1, 1024, 768], [1, 1024, 768], [1, 1024, 768], [1, 1, 1024, 1024]] [1, 1024, 768] 0
Dropout-21 [[1, 1024, 768]] [1, 1024, 768] 0
LayerNorm-14 [[1, 1024, 768]] [1, 1024, 768] 1,536
Linear-41 [[1, 1024, 768]] [1, 1024, 3072] 2,362,368
Dropout-20 [[1, 1024, 3072]] [1, 1024, 3072] 0
Linear-42 [[1, 1024, 3072]] [1, 1024, 768] 2,360,064
Dropout-22 [[1, 1024, 768]] [1, 1024, 768] 0
LayerNorm-15 [[1, 1024, 768]] [1, 1024, 768] 1,536
TransformerEncoderLayer-7 [[1, 1024, 768]] [1, 1024, 768] 0
Linear-43 [[1, 1024, 768]] [1, 1024, 768] 590,592
Linear-44 [[1, 1024, 768]] [1, 1024, 768] 590,592
Linear-45 [[1, 1024, 768]] [1, 1024, 768] 590,592
Linear-46 [[1, 1024, 768]] [1, 1024, 768] 590,592
MultiHeadAttention-8 [[1, 1024, 768], [1, 1024, 768], [1, 1024, 768], [1, 1, 1024, 1024]] [1, 1024, 768] 0
Dropout-24 [[1, 1024, 768]] [1, 1024, 768] 0
LayerNorm-16 [[1, 1024, 768]] [1, 1024, 768] 1,536
Linear-47 [[1, 1024, 768]] [1, 1024, 3072] 2,362,368
Dropout-23 [[1, 1024, 3072]] [1, 1024, 3072] 0
Linear-48 [[1, 1024, 3072]] [1, 1024, 768] 2,360,064
Dropout-25 [[1, 1024, 768]] [1, 1024, 768] 0
LayerNorm-17 [[1, 1024, 768]] [1, 1024, 768] 1,536
TransformerEncoderLayer-8 [[1, 1024, 768]] [1, 1024, 768] 0
Linear-49 [[1, 1024, 768]] [1, 1024, 768] 590,592
Linear-50 [[1, 1024, 768]] [1, 1024, 768] 590,592
Linear-51 [[1, 1024, 768]] [1, 1024, 768] 590,592
Linear-52 [[1, 1024, 768]] [1, 1024, 768] 590,592
MultiHeadAttention-9 [[1, 1024, 768], [1, 1024, 768], [1, 1024, 768], [1, 1, 1024, 1024]] [1, 1024, 768] 0
Dropout-27 [[1, 1024, 768]] [1, 1024, 768] 0
LayerNorm-18 [[1, 1024, 768]] [1, 1024, 768] 1,536
Linear-53 [[1, 1024, 768]] [1, 1024, 3072] 2,362,368
Dropout-26 [[1, 1024, 3072]] [1, 1024, 3072] 0
Linear-54 [[1, 1024, 3072]] [1, 1024, 768] 2,360,064
Dropout-28 [[1, 1024, 768]] [1, 1024, 768] 0
LayerNorm-19 [[1, 1024, 768]] [1, 1024, 768] 1,536
TransformerEncoderLayer-9 [[1, 1024, 768]] [1, 1024, 768] 0
Linear-55 [[1, 1024, 768]] [1, 1024, 768] 590,592
Linear-56 [[1, 1024, 768]] [1, 1024, 768] 590,592
Linear-57 [[1, 1024, 768]] [1, 1024, 768] 590,592
Linear-58 [[1, 1024, 768]] [1, 1024, 768] 590,592
MultiHeadAttention-10 [[1, 1024, 768], [1, 1024, 768], [1, 1024, 768], [1, 1, 1024, 1024]] [1, 1024, 768] 0
Dropout-30 [[1, 1024, 768]] [1, 1024, 768] 0
LayerNorm-20 [[1, 1024, 768]] [1, 1024, 768] 1,536
Linear-59 [[1, 1024, 768]] [1, 1024, 3072] 2,362,368
Dropout-29 [[1, 1024, 3072]] [1, 1024, 3072] 0
Linear-60 [[1, 1024, 3072]] [1, 1024, 768] 2,360,064
Dropout-31 [[1, 1024, 768]] [1, 1024, 768] 0
LayerNorm-21 [[1, 1024, 768]] [1, 1024, 768] 1,536
TransformerEncoderLayer-10 [[1, 1024, 768]] [1, 1024, 768] 0
Linear-61 [[1, 1024, 768]] [1, 1024, 768] 590,592
Linear-62 [[1, 1024, 768]] [1, 1024, 768] 590,592
Linear-63 [[1, 1024, 768]] [1, 1024, 768] 590,592
Linear-64 [[1, 1024, 768]] [1, 1024, 768] 590,592
MultiHeadAttention-11 [[1, 1024, 768], [1, 1024, 768], [1, 1024, 768], [1, 1, 1024, 1024]] [1, 1024, 768] 0
Dropout-33 [[1, 1024, 768]] [1, 1024, 768] 0
LayerNorm-22 [[1, 1024, 768]] [1, 1024, 768] 1,536
Linear-65 [[1, 1024, 768]] [1, 1024, 3072] 2,362,368
Dropout-32 [[1, 1024, 3072]] [1, 1024, 3072] 0
Linear-66 [[1, 1024, 3072]] [1, 1024, 768] 2,360,064
Dropout-34 [[1, 1024, 768]] [1, 1024, 768] 0
LayerNorm-23 [[1, 1024, 768]] [1, 1024, 768] 1,536
TransformerEncoderLayer-11 [[1, 1024, 768]] [1, 1024, 768] 0
Linear-67 [[1, 1024, 768]] [1, 1024, 768] 590,592
Linear-68 [[1, 1024, 768]] [1, 1024, 768] 590,592
Linear-69 [[1, 1024, 768]] [1, 1024, 768] 590,592
Linear-70 [[1, 1024, 768]] [1, 1024, 768] 590,592
MultiHeadAttention-12 [[1, 1024, 768], [1, 1024, 768], [1, 1024, 768], [1, 1, 1024, 1024]] [1, 1024, 768] 0
Dropout-36 [[1, 1024, 768]] [1, 1024, 768] 0
LayerNorm-24 [[1, 1024, 768]] [1, 1024, 768] 1,536
Linear-71 [[1, 1024, 768]] [1, 1024, 3072] 2,362,368
Dropout-35 [[1, 1024, 3072]] [1, 1024, 3072] 0
Linear-72 [[1, 1024, 3072]] [1, 1024, 768] 2,360,064
Dropout-37 [[1, 1024, 768]] [1, 1024, 768] 0
LayerNorm-25 [[1, 1024, 768]] [1, 1024, 768] 1,536
TransformerEncoderLayer-12 [[1, 1024, 768]] [1, 1024, 768] 0
TransformerEncoder-1 [[1, 1024, 768], [1, 1, 1024, 1024]] [1, 1024, 768] 0
Linear-73 [[1, 768]] [1, 768] 590,592
Tanh-2 [[1, 768]] [1, 768] 0
BertPooler-1 [[1, 1024, 768]] [1, 768] 0
BertModel-1 [[1, 1024]] [[1, 1024, 768], [1, 768]] 0
Dropout-38 [[1, 1024, 768]] [1, 1024, 768] 0
Linear-74 [[1, 1024, 768]] [1, 1024, 21128] 16,247,432
BertForTokenClassification-1 [[1, 1024]] [1, 1024, 21128] 0
==============================================================================================================================================
Total params: 118,515,080
Trainable params: 118,515,080
Non-trainable params: 0
----------------------------------------------------------------------------------------------------------------------------------------------
Input size (MB): 0.01
Forward/backward pass size (MB): 1752.15
Params size (MB): 452.10
Estimated Total Size (MB): 2204.26
----------------------------------------------------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------------------------------------------------
Layer (type) Input Shape Output Shape Param #
==============================================================================================================================================
Embedding-4 [[1, 1024]] [1, 1024, 768] 16,226,304
Embedding-5 [[1, 1024]] [1, 1024, 768] 393,216
Embedding-6 [[1, 1024]] [1, 1024, 768] 1,536
LayerNorm-26 [[1, 1024, 768]] [1, 1024, 768] 1,536
Dropout-39 [[1, 1024, 768]] [1, 1024, 768] 0
BertEmbeddings-2 [] [1, 1024, 768] 0
Linear-75 [[1, 1024, 768]] [1, 1024, 768] 590,592
Linear-76 [[1, 1024, 768]] [1, 1024, 768] 590,592
Linear-77 [[1, 1024, 768]] [1, 1024, 768] 590,592
Linear-78 [[1, 1024, 768]] [1, 1024, 768] 590,592
MultiHeadAttention-13 [[1, 1024, 768], [1, 1024, 768], [1, 1024, 768], [1, 1, 1024, 1024]] [1, 1024, 768] 0
Dropout-41 [[1, 1024, 768]] [1, 1024, 768] 0
LayerNorm-27 [[1, 1024, 768]] [1, 1024, 768] 1,536
Linear-79 [[1, 1024, 768]] [1, 1024, 3072] 2,362,368
Dropout-40 [[1, 1024, 3072]] [1, 1024, 3072] 0
Linear-80 [[1, 1024, 3072]] [1, 1024, 768] 2,360,064
Dropout-42 [[1, 1024, 768]] [1, 1024, 768] 0
LayerNorm-28 [[1, 1024, 768]] [1, 1024, 768] 1,536
TransformerEncoderLayer-13 [[1, 1024, 768]] [1, 1024, 768] 0
Linear-81 [[1, 1024, 768]] [1, 1024, 768] 590,592
Linear-82 [[1, 1024, 768]] [1, 1024, 768] 590,592
Linear-83 [[1, 1024, 768]] [1, 1024, 768] 590,592
Linear-84 [[1, 1024, 768]] [1, 1024, 768] 590,592
MultiHeadAttention-14 [[1, 1024, 768], [1, 1024, 768], [1, 1024, 768], [1, 1, 1024, 1024]] [1, 1024, 768] 0
Dropout-44 [[1, 1024, 768]] [1, 1024, 768] 0
LayerNorm-29 [[1, 1024, 768]] [1, 1024, 768] 1,536
Linear-85 [[1, 1024, 768]] [1, 1024, 3072] 2,362,368
Dropout-43 [[1, 1024, 3072]] [1, 1024, 3072] 0
Linear-86 [[1, 1024, 3072]] [1, 1024, 768] 2,360,064
Dropout-45 [[1, 1024, 768]] [1, 1024, 768] 0
LayerNorm-30 [[1, 1024, 768]] [1, 1024, 768] 1,536
TransformerEncoderLayer-14 [[1, 1024, 768]] [1, 1024, 768] 0
Linear-87 [[1, 1024, 768]] [1, 1024, 768] 590,592
Linear-88 [[1, 1024, 768]] [1, 1024, 768] 590,592
Linear-89 [[1, 1024, 768]] [1, 1024, 768] 590,592
Linear-90 [[1, 1024, 768]] [1, 1024, 768] 590,592
MultiHeadAttention-15 [[1, 1024, 768], [1, 1024, 768], [1, 1024, 768], [1, 1, 1024, 1024]] [1, 1024, 768] 0
Dropout-47 [[1, 1024, 768]] [1, 1024, 768] 0
LayerNorm-31 [[1, 1024, 768]] [1, 1024, 768] 1,536
Linear-91 [[1, 1024, 768]] [1, 1024, 3072] 2,362,368
Dropout-46 [[1, 1024, 3072]] [1, 1024, 3072] 0
Linear-92 [[1, 1024, 3072]] [1, 1024, 768] 2,360,064
Dropout-48 [[1, 1024, 768]] [1, 1024, 768] 0
LayerNorm-32 [[1, 1024, 768]] [1, 1024, 768] 1,536
TransformerEncoderLayer-15 [[1, 1024, 768]] [1, 1024, 768] 0
Linear-93 [[1, 1024, 768]] [1, 1024, 768] 590,592
Linear-94 [[1, 1024, 768]] [1, 1024, 768] 590,592
Linear-95 [[1, 1024, 768]] [1, 1024, 768] 590,592
Linear-96 [[1, 1024, 768]] [1, 1024, 768] 590,592
MultiHeadAttention-16 [[1, 1024, 768], [1, 1024, 768], [1, 1024, 768], [1, 1, 1024, 1024]] [1, 1024, 768] 0
Dropout-50 [[1, 1024, 768]] [1, 1024, 768] 0
LayerNorm-33 [[1, 1024, 768]] [1, 1024, 768] 1,536
Linear-97 [[1, 1024, 768]] [1, 1024, 3072] 2,362,368
Dropout-49 [[1, 1024, 3072]] [1, 1024, 3072] 0
Linear-98 [[1, 1024, 3072]] [1, 1024, 768] 2,360,064
Dropout-51 [[1, 1024, 768]] [1, 1024, 768] 0
LayerNorm-34 [[1, 1024, 768]] [1, 1024, 768] 1,536
TransformerEncoderLayer-16 [[1, 1024, 768]] [1, 1024, 768] 0
Linear-99 [[1, 1024, 768]] [1, 1024, 768] 590,592
Linear-100 [[1, 1024, 768]] [1, 1024, 768] 590,592
Linear-101 [[1, 1024, 768]] [1, 1024, 768] 590,592
Linear-102 [[1, 1024, 768]] [1, 1024, 768] 590,592
MultiHeadAttention-17 [[1, 1024, 768], [1, 1024, 768], [1, 1024, 768], [1, 1, 1024, 1024]] [1, 1024, 768] 0
Dropout-53 [[1, 1024, 768]] [1, 1024, 768] 0
LayerNorm-35 [[1, 1024, 768]] [1, 1024, 768] 1,536
Linear-103 [[1, 1024, 768]] [1, 1024, 3072] 2,362,368
Dropout-52 [[1, 1024, 3072]] [1, 1024, 3072] 0
Linear-104 [[1, 1024, 3072]] [1, 1024, 768] 2,360,064
Dropout-54 [[1, 1024, 768]] [1, 1024, 768] 0
LayerNorm-36 [[1, 1024, 768]] [1, 1024, 768] 1,536
TransformerEncoderLayer-17 [[1, 1024, 768]] [1, 1024, 768] 0
Linear-105 [[1, 1024, 768]] [1, 1024, 768] 590,592
Linear-106 [[1, 1024, 768]] [1, 1024, 768] 590,592
Linear-107 [[1, 1024, 768]] [1, 1024, 768] 590,592
Linear-108 [[1, 1024, 768]] [1, 1024, 768] 590,592
MultiHeadAttention-18 [[1, 1024, 768], [1, 1024, 768], [1, 1024, 768], [1, 1, 1024, 1024]] [1, 1024, 768] 0
Dropout-56 [[1, 1024, 768]] [1, 1024, 768] 0
LayerNorm-37 [[1, 1024, 768]] [1, 1024, 768] 1,536
Linear-109 [[1, 1024, 768]] [1, 1024, 3072] 2,362,368
Dropout-55 [[1, 1024, 3072]] [1, 1024, 3072] 0
Linear-110 [[1, 1024, 3072]] [1, 1024, 768] 2,360,064
Dropout-57 [[1, 1024, 768]] [1, 1024, 768] 0
LayerNorm-38 [[1, 1024, 768]] [1, 1024, 768] 1,536
TransformerEncoderLayer-18 [[1, 1024, 768]] [1, 1024, 768] 0
Linear-111 [[1, 1024, 768]] [1, 1024, 768] 590,592
Linear-112 [[1, 1024, 768]] [1, 1024, 768] 590,592
Linear-113 [[1, 1024, 768]] [1, 1024, 768] 590,592
Linear-114 [[1, 1024, 768]] [1, 1024, 768] 590,592
MultiHeadAttention-19 [[1, 1024, 768], [1, 1024, 768], [1, 1024, 768], [1, 1, 1024, 1024]] [1, 1024, 768] 0
Dropout-59 [[1, 1024, 768]] [1, 1024, 768] 0
LayerNorm-39 [[1, 1024, 768]] [1, 1024, 768] 1,536
Linear-115 [[1, 1024, 768]] [1, 1024, 3072] 2,362,368
Dropout-58 [[1, 1024, 3072]] [1, 1024, 3072] 0
Linear-116 [[1, 1024, 3072]] [1, 1024, 768] 2,360,064
Dropout-60 [[1, 1024, 768]] [1, 1024, 768] 0
LayerNorm-40 [[1, 1024, 768]] [1, 1024, 768] 1,536
TransformerEncoderLayer-19 [[1, 1024, 768]] [1, 1024, 768] 0
Linear-117 [[1, 1024, 768]] [1, 1024, 768] 590,592
Linear-118 [[1, 1024, 768]] [1, 1024, 768] 590,592
Linear-119 [[1, 1024, 768]] [1, 1024, 768] 590,592
Linear-120 [[1, 1024, 768]] [1, 1024, 768] 590,592
MultiHeadAttention-20 [[1, 1024, 768], [1, 1024, 768], [1, 1024, 768], [1, 1, 1024, 1024]] [1, 1024, 768] 0
Dropout-62 [[1, 1024, 768]] [1, 1024, 768] 0
LayerNorm-41 [[1, 1024, 768]] [1, 1024, 768] 1,536
Linear-121 [[1, 1024, 768]] [1, 1024, 3072] 2,362,368
Dropout-61 [[1, 1024, 3072]] [1, 1024, 3072] 0
Linear-122 [[1, 1024, 3072]] [1, 1024, 768] 2,360,064
Dropout-63 [[1, 1024, 768]] [1, 1024, 768] 0
LayerNorm-42 [[1, 1024, 768]] [1, 1024, 768] 1,536
TransformerEncoderLayer-20 [[1, 1024, 768]] [1, 1024, 768] 0
Linear-123 [[1, 1024, 768]] [1, 1024, 768] 590,592
Linear-124 [[1, 1024, 768]] [1, 1024, 768] 590,592
Linear-125 [[1, 1024, 768]] [1, 1024, 768] 590,592
Linear-126 [[1, 1024, 768]] [1, 1024, 768] 590,592
MultiHeadAttention-21 [[1, 1024, 768], [1, 1024, 768], [1, 1024, 768], [1, 1, 1024, 1024]] [1, 1024, 768] 0
Dropout-65 [[1, 1024, 768]] [1, 1024, 768] 0
LayerNorm-43 [[1, 1024, 768]] [1, 1024, 768] 1,536
Linear-127 [[1, 1024, 768]] [1, 1024, 3072] 2,362,368
Dropout-64 [[1, 1024, 3072]] [1, 1024, 3072] 0
Linear-128 [[1, 1024, 3072]] [1, 1024, 768] 2,360,064
Dropout-66 [[1, 1024, 768]] [1, 1024, 768] 0
LayerNorm-44 [[1, 1024, 768]] [1, 1024, 768] 1,536
TransformerEncoderLayer-21 [[1, 1024, 768]] [1, 1024, 768] 0
Linear-129 [[1, 1024, 768]] [1, 1024, 768] 590,592
Linear-130 [[1, 1024, 768]] [1, 1024, 768] 590,592
Linear-131 [[1, 1024, 768]] [1, 1024, 768] 590,592
Linear-132 [[1, 1024, 768]] [1, 1024, 768] 590,592
MultiHeadAttention-22 [[1, 1024, 768], [1, 1024, 768], [1, 1024, 768], [1, 1, 1024, 1024]] [1, 1024, 768] 0
Dropout-68 [[1, 1024, 768]] [1, 1024, 768] 0
LayerNorm-45 [[1, 1024, 768]] [1, 1024, 768] 1,536
Linear-133 [[1, 1024, 768]] [1, 1024, 3072] 2,362,368
Dropout-67 [[1, 1024, 3072]] [1, 1024, 3072] 0
Linear-134 [[1, 1024, 3072]] [1, 1024, 768] 2,360,064
Dropout-69 [[1, 1024, 768]] [1, 1024, 768] 0
LayerNorm-46 [[1, 1024, 768]] [1, 1024, 768] 1,536
TransformerEncoderLayer-22 [[1, 1024, 768]] [1, 1024, 768] 0
Linear-135 [[1, 1024, 768]] [1, 1024, 768] 590,592
Linear-136 [[1, 1024, 768]] [1, 1024, 768] 590,592
Linear-137 [[1, 1024, 768]] [1, 1024, 768] 590,592
Linear-138 [[1, 1024, 768]] [1, 1024, 768] 590,592
MultiHeadAttention-23 [[1, 1024, 768], [1, 1024, 768], [1, 1024, 768], [1, 1, 1024, 1024]] [1, 1024, 768] 0
Dropout-71 [[1, 1024, 768]] [1, 1024, 768] 0
LayerNorm-47 [[1, 1024, 768]] [1, 1024, 768] 1,536
Linear-139 [[1, 1024, 768]] [1, 1024, 3072] 2,362,368
Dropout-70 [[1, 1024, 3072]] [1, 1024, 3072] 0
Linear-140 [[1, 1024, 3072]] [1, 1024, 768] 2,360,064
Dropout-72 [[1, 1024, 768]] [1, 1024, 768] 0
LayerNorm-48 [[1, 1024, 768]] [1, 1024, 768] 1,536
TransformerEncoderLayer-23 [[1, 1024, 768]] [1, 1024, 768] 0
Linear-141 [[1, 1024, 768]] [1, 1024, 768] 590,592
Linear-142 [[1, 1024, 768]] [1, 1024, 768] 590,592
Linear-143 [[1, 1024, 768]] [1, 1024, 768] 590,592
Linear-144 [[1, 1024, 768]] [1, 1024, 768] 590,592
MultiHeadAttention-24 [[1, 1024, 768], [1, 1024, 768], [1, 1024, 768], [1, 1, 1024, 1024]] [1, 1024, 768] 0
Dropout-74 [[1, 1024, 768]] [1, 1024, 768] 0
LayerNorm-49 [[1, 1024, 768]] [1, 1024, 768] 1,536
Linear-145 [[1, 1024, 768]] [1, 1024, 3072] 2,362,368
Dropout-73 [[1, 1024, 3072]] [1, 1024, 3072] 0
Linear-146 [[1, 1024, 3072]] [1, 1024, 768] 2,360,064
Dropout-75 [[1, 1024, 768]] [1, 1024, 768] 0
LayerNorm-50 [[1, 1024, 768]] [1, 1024, 768] 1,536
TransformerEncoderLayer-24 [[1, 1024, 768]] [1, 1024, 768] 0
TransformerEncoder-2 [[1, 1024, 768], [1, 1, 1024, 1024]] [1, 1024, 768] 0
Linear-147 [[1, 768]] [1, 768] 590,592
Tanh-3 [[1, 768]] [1, 768] 0
BertPooler-2 [[1, 1024, 768]] [1, 768] 0
BertModel-2 [[1, 1024]] [[1, 1024, 768], [1, 768]] 0
Dropout-76 [[1, 1024, 768]] [1, 1024, 768] 0
Linear-148 [[1, 1024, 768]] [1, 1024, 21128] 16,247,432
BertForTokenClassification-2 [[1, 1024]] [1, 1024, 21128] 0
==============================================================================================================================================
Total params: 118,515,080
Trainable params: 118,515,080
Non-trainable params: 0
----------------------------------------------------------------------------------------------------------------------------------------------
Input size (MB): 0.01
Forward/backward pass size (MB): 1752.15
Params size (MB): 452.10
Estimated Total Size (MB): 2204.26
----------------------------------------------------------------------------------------------------------------------------------------------
{'total_params': 118515080, 'trainable_params': 118515080}
```python
from paddle.io import DataLoader

# Preprocess the poems, keeping only those that fit within `length` tokens
poems = data_preprocess(len_limit=length)
train_loader = DataLoader(PoemData(poems, bert_tokenizer, length), batch_size=2, shuffle=True)
# Train the baseline model for one epoch
model_t.fit(train_data=train_loader, epochs=1, verbose=2)
```
The loss value printed in the log is the current step, and the metric is the average value of previous step.
Epoch 1/1
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:77: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
return (isinstance(seq, collections.Sequence) and
step 10/500 - loss: 725.6817 - Perplexity: 10427.6664 - 245ms/step
step 20/500 - loss: 555.3558 - Perplexity: 7645.9100 - 238ms/step
step 30/500 - loss: 1065.7946 - Perplexity: 5479.9408 - 236ms/step
step 40/500 - loss: 397.7573 - Perplexity: 4748.2296 - 234ms/step
step 50/500 - loss: 740.5175 - Perplexity: 3929.7151 - 233ms/step
step 60/500 - loss: 565.6876 - Perplexity: 3378.1735 - 232ms/step
step 70/500 - loss: 346.3937 - Perplexity: 2928.5249 - 232ms/step
step 80/500 - loss: 429.1893 - Perplexity: 2516.5488 - 232ms/step
step 90/500 - loss: 289.0322 - Perplexity: 2337.0202 - 232ms/step
step 100/500 - loss: 226.7606 - Perplexity: 2125.5245 - 231ms/step
step 110/500 - loss: 1059.4875 - Perplexity: 1981.2700 - 231ms/step
step 120/500 - loss: 779.4556 - Perplexity: 1861.8594 - 231ms/step
step 130/500 - loss: 1293.4988 - Perplexity: 1777.7791 - 231ms/step
step 140/500 - loss: 365.6456 - Perplexity: 1721.5329 - 232ms/step
step 150/500 - loss: 329.3266 - Perplexity: 1633.8337 - 232ms/step
step 160/500 - loss: 394.9833 - Perplexity: 1568.7457 - 232ms/step
step 170/500 - loss: 760.9999 - Perplexity: 1511.2424 - 232ms/step
step 180/500 - loss: 319.8988 - Perplexity: 1451.3553 - 232ms/step
step 190/500 - loss: 457.5365 - Perplexity: 1433.0332 - 232ms/step
step 200/500 - loss: 322.3470 - Perplexity: 1399.3860 - 232ms/step
step 210/500 - loss: 290.7716 - Perplexity: 1369.4990 - 232ms/step
step 220/500 - loss: 243.7075 - Perplexity: 1340.5879 - 232ms/step
step 230/500 - loss: 425.2493 - Perplexity: 1315.3565 - 232ms/step
step 240/500 - loss: 300.5534 - Perplexity: 1284.2336 - 232ms/step
step 250/500 - loss: 334.7778 - Perplexity: 1256.7779 - 232ms/step
step 260/500 - loss: 449.5439 - Perplexity: 1233.9944 - 232ms/step
step 270/500 - loss: 273.1763 - Perplexity: 1212.2996 - 232ms/step
step 280/500 - loss: 521.5925 - Perplexity: 1196.3253 - 232ms/step
step 290/500 - loss: 833.1059 - Perplexity: 1181.9895 - 232ms/step
step 300/500 - loss: 1103.6512 - Perplexity: 1165.6223 - 232ms/step
step 310/500 - loss: 1067.5687 - Perplexity: 1170.9803 - 232ms/step
step 320/500 - loss: 232.1698 - Perplexity: 1154.2350 - 232ms/step
step 330/500 - loss: 446.2109 - Perplexity: 1142.6015 - 232ms/step
step 340/500 - loss: 406.2852 - Perplexity: 1150.5179 - 232ms/step
step 350/500 - loss: 355.0574 - Perplexity: 1138.8899 - 232ms/step
step 360/500 - loss: 443.1240 - Perplexity: 1130.0989 - 232ms/step
step 370/500 - loss: 450.0056 - Perplexity: 1125.7253 - 232ms/step
step 380/500 - loss: 209.2993 - Perplexity: 1116.8284 - 232ms/step
step 390/500 - loss: 260.5238 - Perplexity: 1103.2239 - 232ms/step
step 400/500 - loss: 326.5078 - Perplexity: 1094.5374 - 232ms/step
step 410/500 - loss: 613.3599 - Perplexity: 1085.8666 - 232ms/step
step 420/500 - loss: 311.0659 - Perplexity: 1077.0101 - 232ms/step
step 430/500 - loss: 552.8693 - Perplexity: 1078.9429 - 232ms/step
step 440/500 - loss: 349.0563 - Perplexity: 1070.4293 - 232ms/step
step 450/500 - loss: 322.9454 - Perplexity: 1066.5452 - 232ms/step
step 460/500 - loss: 576.1926 - Perplexity: 1059.5172 - 232ms/step
step 470/500 - loss: 316.3662 - Perplexity: 1055.9816 - 232ms/step
step 480/500 - loss: 270.8025 - Perplexity: 1048.0514 - 232ms/step
step 490/500 - loss: 829.6791 - Perplexity: 1039.9843 - 232ms/step
step 500/500 - loss: 495.9052 - Perplexity: 1034.8373 - 232ms/step
model_p.fit(train_data=train_loader, epochs=1, verbose=2)
The loss value printed in the log is the current step, and the metric is the average value of previous step.
Epoch 1/1
step 10/500 - loss: 551.1293 - Perplexity: 9490.1874 - 196ms/step
step 20/500 - loss: 952.8583 - Perplexity: 6732.4827 - 195ms/step
step 30/500 - loss: 501.8406 - Perplexity: 5721.9359 - 194ms/step
step 40/500 - loss: 244.4778 - Perplexity: 4413.8827 - 194ms/step
step 50/500 - loss: 650.4541 - Perplexity: 3721.9919 - 194ms/step
step 60/500 - loss: 756.1465 - Perplexity: 2999.6214 - 197ms/step
step 70/500 - loss: 273.4387 - Perplexity: 2500.5908 - 196ms/step
step 80/500 - loss: 704.1033 - Perplexity: 2296.5336 - 196ms/step
step 90/500 - loss: 400.3583 - Perplexity: 2160.8528 - 196ms/step
step 100/500 - loss: 216.9302 - Perplexity: 1938.2109 - 196ms/step
step 110/500 - loss: 170.1186 - Perplexity: 1837.5074 - 196ms/step
step 120/500 - loss: 323.2706 - Perplexity: 1732.3691 - 196ms/step
step 130/500 - loss: 466.0241 - Perplexity: 1634.3989 - 196ms/step
step 140/500 - loss: 297.2783 - Perplexity: 1570.9105 - 196ms/step
step 150/500 - loss: 418.7701 - Perplexity: 1513.4602 - 197ms/step
step 160/500 - loss: 342.0039 - Perplexity: 1468.8108 - 196ms/step
step 170/500 - loss: 463.7025 - Perplexity: 1428.2714 - 196ms/step
step 180/500 - loss: 1011.8646 - Perplexity: 1380.5736 - 196ms/step
step 190/500 - loss: 221.5213 - Perplexity: 1352.1810 - 196ms/step
step 200/500 - loss: 1533.4623 - Perplexity: 1317.0389 - 197ms/step
step 210/500 - loss: 336.1730 - Perplexity: 1287.4186 - 197ms/step
step 220/500 - loss: 431.8987 - Perplexity: 1265.3528 - 197ms/step
step 230/500 - loss: 536.3519 - Perplexity: 1248.9458 - 197ms/step
step 240/500 - loss: 1302.3909 - Perplexity: 1223.9964 - 197ms/step
step 250/500 - loss: 223.1607 - Perplexity: 1210.7413 - 198ms/step
step 260/500 - loss: 873.6428 - Perplexity: 1196.4195 - 201ms/step
step 270/500 - loss: 301.4405 - Perplexity: 1182.2666 - 205ms/step
step 280/500 - loss: 326.8961 - Perplexity: 1167.6853 - 208ms/step
step 290/500 - loss: 209.5225 - Perplexity: 1157.2032 - 211ms/step
step 300/500 - loss: 966.4078 - Perplexity: 1144.7167 - 215ms/step
step 310/500 - loss: 436.6230 - Perplexity: 1132.1078 - 216ms/step
step 320/500 - loss: 223.2067 - Perplexity: 1123.4431 - 218ms/step
step 330/500 - loss: 716.7662 - Perplexity: 1114.5825 - 219ms/step
step 340/500 - loss: 227.7101 - Perplexity: 1106.4880 - 221ms/step
step 350/500 - loss: 421.9973 - Perplexity: 1096.3944 - 222ms/step
step 360/500 - loss: 321.5515 - Perplexity: 1080.4534 - 222ms/step
step 370/500 - loss: 994.4940 - Perplexity: 1071.2373 - 221ms/step
step 380/500 - loss: 1063.0651 - Perplexity: 1064.0193 - 221ms/step
step 390/500 - loss: 375.4459 - Perplexity: 1058.5955 - 220ms/step
step 400/500 - loss: 341.6979 - Perplexity: 1048.0488 - 219ms/step
step 410/500 - loss: 528.9408 - Perplexity: 1039.0551 - 219ms/step
step 420/500 - loss: 420.7586 - Perplexity: 1033.6232 - 218ms/step
step 430/500 - loss: 260.0015 - Perplexity: 1028.3416 - 218ms/step
step 440/500 - loss: 453.0840 - Perplexity: 1022.2022 - 217ms/step
step 450/500 - loss: 461.1561 - Perplexity: 1018.8781 - 216ms/step
step 460/500 - loss: 586.7026 - Perplexity: 1012.4488 - 216ms/step
step 470/500 - loss: 557.7255 - Perplexity: 1007.6774 - 216ms/step
step 480/500 - loss: 213.4639 - Perplexity: 1002.1562 - 215ms/step
step 490/500 - loss: 493.5714 - Perplexity: 991.8984 - 215ms/step
step 500/500 - loss: 312.0637 - Perplexity: 986.7728 - 214ms/step
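The results above show that, at a sequence length of 1024, the Performer-based Bert trains faster (about 214 ms/step for `model_p` versus 232 ms/step for `model_t` at the end of the epoch), while its training quality after one epoch is on par with the standard Bert (final running perplexity 986.77 versus 1034.84 in this single run).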
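To make the comparison concrete, the short sketch below turns the final log lines of the two runs into relative figures. The numbers are copied from the step 500/500 entries above. Treating `model_t` as the standard Bert and `model_p` as the Performer-based one is an assumption based on the variable names, and `relative_change` is a hypothetical helper introduced only for this illustration. A single one-epoch run is a rough indication, not a benchmark.

```python
# Final-step figures copied from the two training logs (step 500/500).
# Assumption: model_t is the standard Bert, model_p the Performer-based Bert.
logs = {
    "model_t": {"ms_per_step": 232.0, "perplexity": 1034.8373},
    "model_p": {"ms_per_step": 214.0, "perplexity": 986.7728},
}


def relative_change(baseline, candidate):
    """Fractional change of `candidate` relative to `baseline` (negative = lower)."""
    return (candidate - baseline) / baseline


# Same ratio definition as the earlier attention benchmark: time_transformer / time_performer
speedup = logs["model_t"]["ms_per_step"] / logs["model_p"]["ms_per_step"]
step_time_change = relative_change(
    logs["model_t"]["ms_per_step"], logs["model_p"]["ms_per_step"]
)
ppl_change = relative_change(
    logs["model_t"]["perplexity"], logs["model_p"]["perplexity"]
)

print(f"speed up ratio:      {speedup:.3f}")
print(f"step time change:    {step_time_change:+.1%}")
print(f"perplexity change:   {ppl_change:+.1%}")
```

The end-to-end gain is much smaller than the attention-only ratio measured earlier at length 1024, which is expected: the embedding, feed-forward, and output layers are identical in both models and account for a large share of each step.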
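Summary
This project implements Performer on the Paddle framework and tests it with a Performer-based Bert. In this test, training is faster while training quality is preserved. Note that the project does not yet provide a solution for attention masks, nor does it handle causal sequences (that is, the case where a sequence is processed in one direction only). Both will be addressed later; if you are interested, please follow the project's future updates.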
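For readers wondering what the missing causal case involves, here is a minimal sketch of the usual prefix-sum formulation of causal linear attention. It is not part of this project's code and not the original FAVOR implementation. It assumes the queries and keys have already been mapped to strictly positive feature vectors (`q_feat`, `k_feat`), and all function names are hypothetical. Pure Python lists are used so the running sums stay visible; a real implementation would use tensor operations.

```python
import random


def causal_linear_attention(q_feat, k_feat, values):
    """Causal attention via running prefix sums, without an L x L matrix.

    q_feat, k_feat: L rows of m strictly positive features each
    values:         L rows of d floats each
    """
    m, d = len(k_feat[0]), len(values[0])
    kv_sum = [[0.0] * d for _ in range(m)]  # running sum of outer(k_j, v_j)
    k_sum = [0.0] * m                       # running sum of k_j (normalizer)
    outputs = []
    for q, k, v in zip(q_feat, k_feat, values):
        # Fold position i into the sums before reading them, so position i
        # attends to positions 0..i and never to later ones.
        for a in range(m):
            k_sum[a] += k[a]
            for b in range(d):
                kv_sum[a][b] += k[a] * v[b]
        denom = sum(q[a] * k_sum[a] for a in range(m))
        outputs.append(
            [sum(q[a] * kv_sum[a][b] for a in range(m)) / denom for b in range(d)]
        )
    return outputs


def causal_quadratic_reference(q_feat, k_feat, values):
    """Same result computed from explicit lower-triangular attention weights."""
    d = len(values[0])
    outputs = []
    for i, q in enumerate(q_feat):
        weights = [
            sum(qa * ka for qa, ka in zip(q, k_feat[j])) for j in range(i + 1)
        ]
        total = sum(weights)
        outputs.append(
            [sum(w * values[j][b] for j, w in enumerate(weights)) / total
             for b in range(d)]
        )
    return outputs


random.seed(0)
L, m, d = 6, 4, 3  # toy sizes: sequence length, feature count, value dimension
q_feat = [[random.random() + 0.1 for _ in range(m)] for _ in range(L)]
k_feat = [[random.random() + 0.1 for _ in range(m)] for _ in range(L)]
values = [[random.uniform(-1.0, 1.0) for _ in range(d)] for _ in range(L)]

fast = causal_linear_attention(q_feat, k_feat, values)
slow = causal_quadratic_reference(q_feat, k_feat, values)
max_diff = max(abs(x - y) for fr, sr in zip(fast, slow) for x, y in zip(fr, sr))
print("max abs difference:", max_diff)
```

The prefix-sum version does work proportional to L * m * d, while the reference builds all L * (L + 1) / 2 pairwise weights. The two agree up to floating-point rounding, which is why the causal case can in principle keep the near-linear cost.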