Sequence-to-Sequence modeling with nn.Transformer and torchtext
Train a sequence-to-sequence model with the nn.Transformer module.
PyTorch 1.2 and later include a standard transformer module based on the paper "Attention Is All You Need". The nn.Transformer
module relies entirely on an attention mechanism to draw global dependencies between input and output.
nn.Transformer(d_model=512, nhead=8, num_encoder_layers=6, num_decoder_layers=6, dim_feedforward=2048, dropout=0.1, activation='relu', custom_encoder=None, custom_decoder=None)
A transformer model. Its parameters can be configured as needed; the architecture is based on the paper "Attention Is All You Need". A minimal usage sketch follows the parameter list.
Parameters:
- d_model: the number of expected features in the encoder/decoder inputs, i.e. the embedding dimension (default 512).
- nhead: the number of heads in the multi-head attention models (default 8).
- num_encoder_layers: the number of sub-encoder-layers in the encoder (default 6).
- num_decoder_layers: the number of sub-decoder-layers in the decoder (default 6).
- dim_feedforward: the dimension of the feedforward network model (default 2048).
- dropout: the dropout value (default 0.1).
- activation: the activation function of the encoder/decoder intermediate layer, relu or gelu (default relu).
- custom_encoder: a custom encoder.
- custom_decoder: a custom decoder.
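As a quick illustration (assumed usage, not part of the tutorial's training code), the sketch below instantiates nn.Transformer with its defaults and runs one forward pass on random tensors; shapes follow the default (seq_len, batch, d_model) layout:
import torch
import torch.nn as nn
transformer_model = nn.Transformer(d_model=512, nhead=8, num_encoder_layers=6, num_decoder_layers=6, dim_feedforward=2048, dropout=0.1)
src = torch.rand(10, 32, 512)  # (source length, batch size, d_model)
tgt = torch.rand(20, 32, 512)  # (target length, batch size, d_model)
out = transformer_model(src, tgt)
print(out.shape)  # torch.Size([20, 32, 512])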
Encoder: the encoder is a stack of N = 6 identical layers, each containing two sub-layers: the first is a multi-head self-attention mechanism and the second is a fully connected feed-forward network. A residual connection and layer normalization are applied around each of the two sub-layers, i.e. LayerNorm(x + Sublayer(x)). To support these residual connections, all sub-layers in the model, as well as the embedding layers, produce outputs of dimension d_{model} = 512. A minimal sketch of this pattern follows.
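The sketch below (an illustration only, with a single nn.Linear standing in for a sub-layer; it is not the library's internal implementation) shows the LayerNorm(x + Sublayer(x)) pattern:
import torch
import torch.nn as nn
d_model = 512
layer_norm = nn.LayerNorm(d_model)
sublayer = nn.Linear(d_model, d_model)  # stand-in for self-attention or the feed-forward sub-layer
x = torch.rand(10, 32, d_model)  # (seq_len, batch, d_model)
out = layer_norm(x + sublayer(x))  # residual connection followed by layer normalization
print(out.shape)  # torch.Size([10, 32, 512])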
1. Define the model
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class TransformerModel(nn.Module):
def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
super(TransformerModel, self).__init__()
from torch.nn import TransformerEncoder, TransformerEncoderLayer
self.model_type = 'Transformer'
self.pos_encoder = PositionalEncoding(ninp, dropout)
encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout)
self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
self.encoder = nn.Embedding(ntoken, ninp)
self.ninp = ninp
self.decoder = nn.Linear(ninp, ntoken)
self.init_weights()
def generate_square_subsequent_mask(self, sz):
# torch.triu keeps the diagonal and the elements above it and zeroes out the rest; == 1 plus transpose turns this into a lower-triangular boolean mask.
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
# masked_fill sets positions where mask == 0 to -inf and positions where mask == 1 to 0.
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
# the returned tensor looks like:
"""
tensor([[0., -inf, -inf, ..., -inf, -inf, -inf],
[0., 0., -inf, ..., -inf, -inf, -inf],
[0., 0., 0., ..., -inf, -inf, -inf],
...,
[0., 0., 0., ..., 0., -inf, -inf],
[0., 0., 0., ..., 0., 0., -inf],
[0., 0., 0., ..., 0., 0., 0.]])
"""
return mask
def init_weights(self):
initrange = 0.1
self.encoder.weight.data.uniform_(-initrange, initrange)
self.decoder.bias.data.zero_()
self.decoder.weight.data.uniform_(-initrange, initrange)
def forward(self, src, src_mask):
src = self.encoder(src) * math.sqrt(self.ninp) # torch.Size([35, 20, 200])
src = self.pos_encoder(src)
output = self.transformer_encoder(src, src_mask)
output = self.decoder(output) # torch.Size([35, 20, 28783])
return output
The PositionalEncoding module injects information about the relative and absolute position of the tokens in the sequence. The positional encodings have the same dimension as the embeddings so that the two can be summed. (A quick shape check follows the class definition.)
PE_{(pos, 2i)} = \sin(pos / 10000^{2i/d_{model}})
PE_{(pos, 2i+1)} = \cos(pos / 10000^{2i/d_{model}})
import math
import torch
import torch.nn as nn
class PositionalEncoding(nn.Module):
def __init__(self, d_model, dropout=0.1, max_len=5000):
super(PositionalEncoding, self).__init__()
self.dropout = nn.Dropout(p=dropout)
# build the positional encoding matrix of shape (max_len, d_model),
# where max_len = 5000 is the assumed maximum sequence length and d_model = 512 is the token embedding dimension
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0).transpose(0, 1)
# register pe as a persistent buffer (saved in the state_dict but not a learnable parameter)
self.register_buffer('pe', pe)
def forward(self, x):
x = x + self.pe[:x.size(0), :]
return self.dropout(x)
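A quick check (illustrative, not part of the original tutorial): the registered pe buffer has shape (max_len, 1, d_model), so it broadcasts over the batch dimension of an input with shape (seq_len, batch, d_model). Note that div_term above is simply 1 / 10000^{2i/d_model}, computed via exp and log for numerical stability.
pos_encoder = PositionalEncoding(d_model=200, dropout=0.1)
x = torch.zeros(35, 20, 200)  # (seq_len, batch, d_model)
print(pos_encoder(x).shape)  # torch.Size([35, 20, 200])
print(pos_encoder.pe.shape)  # torch.Size([5000, 1, 200])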
2. Load and batch data
The vocab is built from the training dataset, and the batchify function arranges the token stream into bsz equal-length columns, trimming any leftover tokens, as shown in the code below:
import io
import torch
from torchtext.utils import download_from_url, extract_archive
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
url = 'https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip'
# download and extract the archive from the URL, producing three files: wiki.test.tokens, wiki.train.tokens and wiki.valid.tokens
test_filepath, valid_filepath, train_filepath = extract_archive(download_from_url(url))
# the tokenizer preprocesses a sequence (removes special characters, splits, etc.) and produces tokens
tokenizer = get_tokenizer('basic_english')
# build the vocab from the training set, i.e. a mapping such as {'I': 1, ... 'am': 1732, ...}
vocab = build_vocab_from_iterator(map(tokenizer,
iter(io.open(train_filepath,
encoding='utf-8'))))
def data_process(raw_text_iter):
"""
raw_text_iter: an iterator over the lines of the text file.
tokenizer: preprocesses each sequence (line), e.g. removing special characters and splitting, yielding iterables like ['I', 'love', 'pytorch', ...].
vocab: a dictionary-like object that maps every token to a unique integer id, so each sequence becomes a sequence of numbers.
Returns: the per-line tensors, after filtering out the empty ones, concatenated with cat into a single tensor such as tensor([ 10, 3850, 3870, ..., 2443]). (A toy check follows the data-loading code.)
"""
data = [torch.tensor([vocab[token] for token in tokenizer(item)],
dtype=torch.long) for item in raw_text_iter]
return torch.cat(tuple(filter(lambda t: t.numel()>0, data)))
train_data = data_process(iter(io.open(train_filepath, encoding="utf8")))
val_data = data_process(iter(io.open(valid_filepath, encoding="utf8")))
test_data = data_process(iter(io.open(test_filepath, encoding="utf8")))
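As a toy check (illustrative; the exact ids depend on the vocab built above), data_process turns an iterator of raw lines into a single 1-D LongTensor of token ids:
sample = data_process(iter(["the quick brown fox", "jumps over the lazy dog"]))
print(sample.shape, sample.dtype)  # torch.Size([9]) torch.int64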
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def batchify(data, bsz):
"""
If data has length 26 and bsz = 4, it is trimmed to 24 elements and split into 4 columns of length 6 (see the toy check after this function).
"""
# length of each of the bsz columns
nbatch = data.size(0) // bsz
# narrow(dim, start, length) slices dimension dim to [start: start+length]; here it trims off the tokens that do not fit evenly
data = data.narrow(0, 0, nbatch*bsz)
# contiguous() is commonly paired with transpose/permute/view so that the tensor is stored contiguously in memory
data = data.view(bsz, -1).t().contiguous()
return data.to(device)
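A toy check of batchify (illustrative, matching the length-26 example in the docstring):
toy = torch.arange(26)  # a length-26 "alphabet"
print(batchify(toy, 4).shape)  # torch.Size([6, 4]); the 2 leftover elements are dropped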
batch_size = 20
eval_batch_size = 10
train_data = batchify(train_data, batch_size)
val_data = batchify(val_data, eval_batch_size)
test_data = batchify(test_data, eval_batch_size)
print(train_data.size())
36718lines [00:01, 25799.69lines/s]
torch.Size([102499, 20])
3. Functions to generate input and target sequence
The get_batch() function builds the input and target for the transformer model, subdividing the source data into chunks of length bptt. Continuing with the sequence [A, B, C, ..., X, Y, Z] and batch_size = 4, setting bptt = 2 would make each input chunk two rows and the target the same chunk shifted by one position (a toy check follows the function).
bptt = 35
def get_batch(source, i):
seq_len = min(bptt, len(source) - 1 - i)
data = source[i: i+seq_len] # torch.Size([35, 20])
target = source[i+1: i+1+seq_len].reshape(-1) # torch.Size([700])
return data, target
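A toy check (illustrative): on the (6, 4) batchified alphabet from above, only 5 rows fit under the bptt = 35 cap, so the first chunk has shape (5, 4) and its shifted target flattens to 20 elements.
toy_source = batchify(torch.arange(26), 4)  # shape (6, 4)
toy_data, toy_target = get_batch(toy_source, 0)
print(toy_data.shape, toy_target.shape)  # torch.Size([5, 4]) torch.Size([20])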
4. Initiate an instance
The hyper-parameters are set as shown below; the vocabulary size equals the length of the vocab object. (A quick shape check follows the instantiation.)
ntokens = len(vocab.stoi) # the size of vocabulary
emsize = 200 # embedding dimension
nhid = 200 # the dimension of the feedforward network model in nn.TransformerEncoder
nlayers = 2 # the number of nn.TransformerEncoderLayer in nn.TransformerEncoder
nhead = 2 # the number of heads in the multiheadattention models
dropout = 0.2 # the dropout value
model = TransformerModel(ntokens, emsize, nhead, nhid, nlayers, dropout).to(device)
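A quick shape check (illustrative, not part of the tutorial's training loop): one forward pass over the first training chunk yields per-position logits over the vocabulary.
example_data, _ = get_batch(train_data, 0)  # torch.Size([35, 20])
example_mask = model.generate_square_subsequent_mask(example_data.size(0)).to(device)
print(model(example_data, example_mask).shape)  # torch.Size([35, 20, ntokens]), i.e. [35, 20, 28783]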
5. Run the model
# loss object
criterion = nn.CrossEntropyLoss()
# learning rate
lr = 5.0
# optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
# the initial lr is 5.0; StepLR automatically decays it as the epochs progress
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)
import time
def train():
# turn on the train mode
model.train()
total_loss = 0.0
start_time = time.time()
src_mask = model.generate_square_subsequent_mask(bptt).to(device)
for batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)):
data, targets = get_batch(train_data, i)
optimizer.zero_grad()
if data.size(0) != bptt:
src_mask = model.generate_square_subsequent_mask(data.size(0)).to(device)
output = model(data, src_mask)
# print(output.size()) # torch.Size([35, 20, 28783])
loss = criterion(output.view(-1, ntokens), targets)
loss.backward()
# gradient clipping, to prevent exploding gradients
torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
optimizer.step()
total_loss += loss.item()
log_interval = 200
if batch % log_interval == 0 and batch > 0:
cur_loss = total_loss / log_interval
elapsed = time.time() - start_time
print('| epoch {:3d} | {:5d}/{:5d} batches | '
'lr {:02.2f} | ms/batch {:5.2f} | '
'loss {:5.2f} | ppl {:8.2f}'.format(
epoch, batch, len(train_data) // bptt, scheduler.get_lr()[0],
elapsed * 1000 / log_interval,
cur_loss, math.exp(cur_loss)))
total_loss = 0
start_time = time.time()
def evaluate(eval_model, data_source):
eval_model.eval() # turn on the evaluation mode
total_loss = 0.0
src_mask = model.generate_square_subsequent_mask(bptt).to(device)
with torch.no_grad():
for i in range(0, data_source.size(0) - 1, bptt):
data, targets = get_batch(data_source, i)
if data.size(0) != bptt:
src_mask = model.generate_square_subsequent_mask(data.size(0)).to(device)
output = eval_model(data, src_mask)
output_flat = output.view(-1, ntokens)
total_loss += len(data) * criterion(output_flat, targets).item()
return total_loss / (len(data_source)-1)
Loop over the epochs and save the model whenever the validation loss is the best seen so far. The learning rate is adjusted by the scheduler after each epoch.
best_val_loss = float("inf")
epochs = 3
best_model = None
for epoch in range(1, epochs+1):
epoch_start_time = time.time()
train()
val_loss = evaluate(model, val_data)
print('-'*89)
print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),
val_loss, math.exp(val_loss)))
print('-'*89)
if val_loss < best_val_loss:
best_val_loss = val_loss
best_model = model
# Adjust the learning rate after each epoch
scheduler.step()
| epoch 1 | 200/ 2928 batches | lr 5.00 | ms/batch 12.63 | loss 5.48 | ppl 240.28
| epoch 1 | 400/ 2928 batches | lr 5.00 | ms/batch 12.38 | loss 5.51 | ppl 246.10
| epoch 1 | 600/ 2928 batches | lr 5.00 | ms/batch 12.42 | loss 5.31 | ppl 202.03
| epoch 1 | 800/ 2928 batches | lr 5.00 | ms/batch 12.63 | loss 5.38 | ppl 216.16
| epoch 1 | 1000/ 2928 batches | lr 5.00 | ms/batch 12.61 | loss 5.35 | ppl 209.95
| epoch 1 | 1200/ 2928 batches | lr 5.00 | ms/batch 12.59 | loss 5.39 | ppl 218.51
| epoch 1 | 1400/ 2928 batches | lr 5.00 | ms/batch 12.48 | loss 5.41 | ppl 223.10
| epoch 1 | 1600/ 2928 batches | lr 5.00 | ms/batch 12.50 | loss 5.44 | ppl 230.27
| epoch 1 | 1800/ 2928 batches | lr 5.00 | ms/batch 14.89 | loss 5.39 | ppl 218.99
| epoch 1 | 2000/ 2928 batches | lr 5.00 | ms/batch 13.27 | loss 5.42 | ppl 226.92
| epoch 1 | 2200/ 2928 batches | lr 5.00 | ms/batch 12.62 | loss 5.27 | ppl 195.07
| epoch 1 | 2400/ 2928 batches | lr 5.00 | ms/batch 12.56 | loss 5.40 | ppl 221.41
| epoch 1 | 2600/ 2928 batches | lr 5.00 | ms/batch 12.62 | loss 5.40 | ppl 222.03
| epoch 1 | 2800/ 2928 batches | lr 5.00 | ms/batch 12.68 | loss 5.34 | ppl 208.62
-----------------------------------------------------------------------------------------
| end of epoch 1 | time: 39.47s | valid loss 5.58 | valid ppl 264.86
-----------------------------------------------------------------------------------------
| epoch 2 | 200/ 2928 batches | lr 4.51 | ms/batch 12.75 | loss 5.37 | ppl 214.61
| epoch 2 | 400/ 2928 batches | lr 4.51 | ms/batch 13.61 | loss 5.40 | ppl 220.96
| epoch 2 | 600/ 2928 batches | lr 4.51 | ms/batch 13.16 | loss 5.20 | ppl 181.73
| epoch 2 | 800/ 2928 batches | lr 4.51 | ms/batch 13.10 | loss 5.28 | ppl 196.55
| epoch 2 | 1000/ 2928 batches | lr 4.51 | ms/batch 13.23 | loss 5.27 | ppl 193.49
| epoch 2 | 1200/ 2928 batches | lr 4.51 | ms/batch 13.56 | loss 5.28 | ppl 196.04
| epoch 2 | 1400/ 2928 batches | lr 4.51 | ms/batch 12.68 | loss 5.30 | ppl 200.68
| epoch 2 | 1600/ 2928 batches | lr 4.51 | ms/batch 12.65 | loss 5.34 | ppl 208.51
| epoch 2 | 1800/ 2928 batches | lr 4.51 | ms/batch 13.01 | loss 5.29 | ppl 199.06
| epoch 2 | 2000/ 2928 batches | lr 4.51 | ms/batch 14.39 | loss 5.29 | ppl 199.01
| epoch 2 | 2200/ 2928 batches | lr 4.51 | ms/batch 12.66 | loss 5.17 | ppl 175.10
| epoch 2 | 2400/ 2928 batches | lr 4.51 | ms/batch 12.67 | loss 5.28 | ppl 195.85
| epoch 2 | 2600/ 2928 batches | lr 4.51 | ms/batch 12.67 | loss 5.31 | ppl 202.56
| epoch 2 | 2800/ 2928 batches | lr 4.51 | ms/batch 12.73 | loss 5.22 | ppl 184.74
-----------------------------------------------------------------------------------------
| end of epoch 2 | time: 40.29s | valid loss 5.58 | valid ppl 266.35
-----------------------------------------------------------------------------------------
| epoch 3 | 200/ 2928 batches | lr 4.29 | ms/batch 12.78 | loss 5.25 | ppl 191.13
| epoch 3 | 400/ 2928 batches | lr 4.29 | ms/batch 12.69 | loss 5.28 | ppl 196.60
| epoch 3 | 600/ 2928 batches | lr 4.29 | ms/batch 12.68 | loss 5.10 | ppl 163.39
| epoch 3 | 800/ 2928 batches | lr 4.29 | ms/batch 13.14 | loss 5.17 | ppl 176.29
| epoch 3 | 1000/ 2928 batches | lr 4.29 | ms/batch 14.30 | loss 5.13 | ppl 168.89
| epoch 3 | 1200/ 2928 batches | lr 4.29 | ms/batch 12.64 | loss 5.17 | ppl 176.30
| epoch 3 | 1400/ 2928 batches | lr 4.29 | ms/batch 12.99 | loss 5.20 | ppl 180.48
| epoch 3 | 1600/ 2928 batches | lr 4.29 | ms/batch 12.67 | loss 5.23 | ppl 186.36
| epoch 3 | 1800/ 2928 batches | lr 4.29 | ms/batch 12.91 | loss 5.19 | ppl 179.06
| epoch 3 | 2000/ 2928 batches | lr 4.29 | ms/batch 12.73 | loss 5.20 | ppl 181.16
| epoch 3 | 2200/ 2928 batches | lr 4.29 | ms/batch 12.68 | loss 5.06 | ppl 157.73
| epoch 3 | 2400/ 2928 batches | lr 4.29 | ms/batch 12.81 | loss 5.18 | ppl 178.23
| epoch 3 | 2600/ 2928 batches | lr 4.29 | ms/batch 12.82 | loss 5.20 | ppl 180.67
| epoch 3 | 2800/ 2928 batches | lr 4.29 | ms/batch 12.62 | loss 5.14 | ppl 169.99
-----------------------------------------------------------------------------------------
| end of epoch 3 | time: 39.80s | valid loss 5.55 | valid ppl 256.71
-----------------------------------------------------------------------------------------
6. Evaluate the model with the test dataset
test_loss = evaluate(best_model, test_data)
print('='*89)
print('| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format(
test_loss, math.exp(test_loss)))
print('=' * 89)
=========================================================================================
| End of training | test loss 5.46 | test ppl 234.11
=========================================================================================