Attention Is All You Need was first proposed by Google. It abandons RNNs and CNNs entirely and replaces them with the Transformer, an encoder-decoder model built from stacked attention layers.
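At the core of the Transformer is scaled dot-product attention. As a rough illustration only (a minimal sketch, not the code used in this walkthrough), it can be written as:

import torch
import torch.nn.functional as F

def scaled_dot_product_attention(q, k, v, mask=None):
    # q, k, v: (batch, n_head, seq_len, d_k); mask marks positions to be ignored
    d_k = q.size(-1)
    scores = torch.matmul(q, k.transpose(-2, -1)) / d_k ** 0.5
    if mask is not None:
        scores = scores.masked_fill(mask, -1e9)   # block padding / future positions
    attn = F.softmax(scores, dim=-1)              # attention weights over the keys
    return torch.matmul(attn, v), attn            # weighted sum of the values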
This walkthrough debugs the code with PyTorch and requires a Python 3 environment. There are three Python files in total, and the overall procedure is: download the data, preprocess the data, train the model, and test the model.
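For orientation, the typical invocations look roughly like the following. The script names preprocess.py and translate.py appear in the code later in this section; the training script's file name is not stated here and is assumed to be train.py, and only command-line flags defined in the code below are used (angle brackets are placeholders):

python train.py -data <preprocessed_data.pt> -save_model trained -save_mode best -label_smoothing
python translate.py -model trained.chkpt -vocab <preprocessed_data.pt> -src <test_source.txt> -output pred.txt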
In the data preprocessing stage we load the data, build the vocabulary index, and convert words to index sequences. The functions read_instances_from_file, build_vocab_index, convert_instance_to_idx_seq, and main are defined for this. Part of the code is shown below:
import transformer.Constants as Constants

def read_instances_from_file(inst_file, max_sent_len, keep_case):
''' Convert file into word seq lists and vocab '''
word_insts = []
trimmed_sent_count = 0
with open(inst_file) as f:
for sent in f:
if not keep_case:
sent = sent.lower()
words = sent.split()
if len(words) > max_sent_len:
trimmed_sent_count += 1
word_inst = words[:max_sent_len]
if word_inst:
word_insts += [[Constants.BOS_WORD] + word_inst + [Constants.EOS_WORD]]
else:
word_insts += [None]
print('[Info] Get {} instances from {}'.format(len(word_insts), inst_file))
if trimmed_sent_count > 0:
print('[Warning] {} instances are trimmed to the max sentence length {}.'
.format(trimmed_sent_count, max_sent_len))
return word_insts
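The remaining preprocessing helpers are not shown above. A minimal sketch of what build_vocab_index and convert_instance_to_idx_seq typically do is given below; the exact signatures, the frequency threshold, and the special-token constants beyond BOS_WORD, EOS_WORD, and PAD are assumptions, not the repo's code:

def build_vocab_index(word_insts, min_word_count=5):
    ''' Map every sufficiently frequent word to an integer index (sketch). '''
    word2idx = {
        Constants.BOS_WORD: Constants.BOS,
        Constants.EOS_WORD: Constants.EOS,
        Constants.PAD_WORD: Constants.PAD,   # assumed constant names
        Constants.UNK_WORD: Constants.UNK}
    word_count = {}
    for sent in word_insts:
        if not sent:                         # skip the None placeholders above
            continue
        for word in sent:
            word_count[word] = word_count.get(word, 0) + 1
    for word, count in word_count.items():
        if word not in word2idx and count >= min_word_count:
            word2idx[word] = len(word2idx)
    return word2idx

def convert_instance_to_idx_seq(word_insts, word2idx):
    ''' Replace each word with its index, falling back to the unknown-word index. '''
    return [[word2idx.get(w, Constants.UNK) for w in s] for s in word_insts if s]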
The model training stage starts by constructing the network. As described in the paper, the encoder and the decoder each have 6 layers, and every layer contains an attention sublayer. We also need to define the loss function and the evaluation function. As with other neural networks, training iterates over the data for many epochs. The training code is as follows:
import argparse
import math
import time
from tqdm import tqdm
import torch
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data
import transformer.Constants as Constants
from dataset import TranslationDataset, paired_collate_fn
from transformer.Models import Transformer
from transformer.Optim import ScheduledOptim
def cal_performance(pred, gold, smoothing=False):
''' Apply label smoothing if needed '''
loss = cal_loss(pred, gold, smoothing)
pred = pred.max(1)[1]
gold = gold.contiguous().view(-1)
non_pad_mask = gold.ne(Constants.PAD)
n_correct = pred.eq(gold)
n_correct = n_correct.masked_select(non_pad_mask).sum().item()
return loss, n_correct
def cal_loss(pred, gold, smoothing):
''' Calculate cross entropy loss, apply label smoothing if needed. '''
gold = gold.contiguous().view(-1)
if smoothing:
eps = 0.1
n_class = pred.size(1)
one_hot = torch.zeros_like(pred).scatter(1, gold.view(-1, 1), 1)
one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
log_prb = F.log_softmax(pred, dim=1)
non_pad_mask = gold.ne(Constants.PAD)
loss = -(one_hot * log_prb).sum(dim=1)
loss = loss.masked_select(non_pad_mask).sum() # average later
else:
loss = F.cross_entropy(pred, gold, ignore_index=Constants.PAD, reduction='sum')
return loss
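# Note on the smoothing branch above: with eps = 0.1 the soft target assigns
# probability 1 - eps to the gold class and eps / (n_class - 1) to every other
# class, so the loss is the cross entropy between this softened target and
# log_softmax(pred). Padding positions are masked out, and the summed loss is
# divided by the number of non-pad words later ("average later") in
# train_epoch / eval_epoch.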
def train_epoch(model, training_data, optimizer, device, smoothing):
''' Epoch operation in training phase'''
model.train()
total_loss = 0
n_word_total = 0
n_word_correct = 0
for batch in tqdm(
training_data, mininterval=2,
desc=' - (Training) ', leave=False):
# prepare data
src_seq, src_pos, tgt_seq, tgt_pos = map(lambda x: x.to(device), batch)
gold = tgt_seq[:, 1:]
# forward
optimizer.zero_grad()
pred = model(src_seq, src_pos, tgt_seq, tgt_pos)
# backward
loss, n_correct = cal_performance(pred, gold, smoothing=smoothing)
loss.backward()
# update parameters
optimizer.step_and_update_lr()
# note keeping
total_loss += loss.item()
non_pad_mask = gold.ne(Constants.PAD)
n_word = non_pad_mask.sum().item()
n_word_total += n_word
n_word_correct += n_correct
loss_per_word = total_loss/n_word_total
accuracy = n_word_correct/n_word_total
return loss_per_word, accuracy
def eval_epoch(model, validation_data, device):
''' Epoch operation in evaluation phase '''
model.eval()
total_loss = 0
n_word_total = 0
n_word_correct = 0
with torch.no_grad():
for batch in tqdm(
validation_data, mininterval=2,
desc=' - (Validation) ', leave=False):
# prepare data
src_seq, src_pos, tgt_seq, tgt_pos = map(lambda x: x.to(device), batch)
gold = tgt_seq[:, 1:]
# forward
pred = model(src_seq, src_pos, tgt_seq, tgt_pos)
loss, n_correct = cal_performance(pred, gold, smoothing=False)
# note keeping
total_loss += loss.item()
non_pad_mask = gold.ne(Constants.PAD)
n_word = non_pad_mask.sum().item()
n_word_total += n_word
n_word_correct += n_correct
loss_per_word = total_loss/n_word_total
accuracy = n_word_correct/n_word_total
return loss_per_word, accuracy
def train(model, training_data, validation_data, optimizer, device, opt):
''' Start training '''
log_train_file = None
log_valid_file = None
if opt.log:
log_train_file = opt.log + '.train.log'
log_valid_file = opt.log + '.valid.log'
print('[Info] Training performance will be written to file: {} and {}'.format(
log_train_file, log_valid_file))
with open(log_train_file, 'w') as log_tf, open(log_valid_file, 'w') as log_vf:
log_tf.write('epoch,loss,ppl,accuracy\n')
log_vf.write('epoch,loss,ppl,accuracy\n')
valid_accus = []
for epoch_i in range(opt.epoch):
print('[ Epoch', epoch_i, ']')
start = time.time()
train_loss, train_accu = train_epoch(
model, training_data, optimizer, device, smoothing=opt.label_smoothing)
print(' - (Training) ppl: {ppl: 8.5f}, accuracy: {accu:3.3f} %, '\
'elapse: {elapse:3.3f} min'.format(
ppl=math.exp(min(train_loss, 100)), accu=100*train_accu,
elapse=(time.time()-start)/60))
start = time.time()
valid_loss, valid_accu = eval_epoch(model, validation_data, device)
print(' - (Validation) ppl: {ppl: 8.5f}, accuracy: {accu:3.3f} %, '\
'elapse: {elapse:3.3f} min'.format(
ppl=math.exp(min(valid_loss, 100)), accu=100*valid_accu,
elapse=(time.time()-start)/60))
valid_accus += [valid_accu]
model_state_dict = model.state_dict()
checkpoint = {
'model': model_state_dict,
'settings': opt,
'epoch': epoch_i}
if opt.save_model:
if opt.save_mode == 'all':
model_name = opt.save_model + '_accu_{accu:3.3f}.chkpt'.format(accu=100*valid_accu)
torch.save(checkpoint, model_name)
elif opt.save_mode == 'best':
model_name = opt.save_model + '.chkpt'
if valid_accu >= max(valid_accus):
torch.save(checkpoint, model_name)
print(' - [Info] The checkpoint file has been updated.')
if log_train_file and log_valid_file:
with open(log_train_file, 'a') as log_tf, open(log_valid_file, 'a') as log_vf:
log_tf.write('{epoch},{loss: 8.5f},{ppl: 8.5f},{accu:3.3f}\n'.format(
epoch=epoch_i, loss=train_loss,
ppl=math.exp(min(train_loss, 100)), accu=100*train_accu))
log_vf.write('{epoch},{loss: 8.5f},{ppl: 8.5f},{accu:3.3f}\n'.format(
epoch=epoch_i, loss=valid_loss,
ppl=math.exp(min(valid_loss, 100)), accu=100*valid_accu))
def main():
''' Main function '''
parser = argparse.ArgumentParser()
parser.add_argument('-data', required=True)
parser.add_argument('-epoch', type=int, default=10)
parser.add_argument('-batch_size', type=int, default=64)
#parser.add_argument('-d_word_vec', type=int, default=512)
parser.add_argument('-d_model', type=int, default=512)
parser.add_argument('-d_inner_hid', type=int, default=2048)
parser.add_argument('-d_k', type=int, default=64)
parser.add_argument('-d_v', type=int, default=64)
parser.add_argument('-n_head', type=int, default=8)
parser.add_argument('-n_layers', type=int, default=6)
parser.add_argument('-n_warmup_steps', type=int, default=4000)
parser.add_argument('-dropout', type=float, default=0.1)
parser.add_argument('-embs_share_weight', action='store_true')
parser.add_argument('-proj_share_weight', action='store_true')
parser.add_argument('-log', default=None)
parser.add_argument('-save_model', default=None)
parser.add_argument('-save_mode', type=str, choices=['all', 'best'], default='best')
parser.add_argument('-no_cuda', action='store_true')
parser.add_argument('-label_smoothing', action='store_true')
opt = parser.parse_args()
opt.cuda = not opt.no_cuda
opt.d_word_vec = opt.d_model
#========= Loading Dataset =========#
data = torch.load(opt.data)
opt.max_token_seq_len = data['settings'].max_token_seq_len
training_data, validation_data = prepare_dataloaders(data, opt)
opt.src_vocab_size = training_data.dataset.src_vocab_size
opt.tgt_vocab_size = training_data.dataset.tgt_vocab_size
#========= Preparing Model =========#
if opt.embs_share_weight:
assert training_data.dataset.src_word2idx == training_data.dataset.tgt_word2idx, \
'The src/tgt word2idx table are different but asked to share word embedding.'
print(opt)
device = torch.device('cuda' if opt.cuda else 'cpu')
transformer = Transformer(
opt.src_vocab_size,
opt.tgt_vocab_size,
opt.max_token_seq_len,
tgt_emb_prj_weight_sharing=opt.proj_share_weight,
emb_src_tgt_weight_sharing=opt.embs_share_weight,
d_k=opt.d_k,
d_v=opt.d_v,
d_model=opt.d_model,
d_word_vec=opt.d_word_vec,
d_inner=opt.d_inner_hid,
n_layers=opt.n_layers,
n_head=opt.n_head,
dropout=opt.dropout).to(device)
optimizer = ScheduledOptim(
optim.Adam(
filter(lambda x: x.requires_grad, transformer.parameters()),
betas=(0.9, 0.98), eps=1e-09),
opt.d_model, opt.n_warmup_steps)
train(transformer, training_data, validation_data, optimizer, device, opt)
def prepare_dataloaders(data, opt):
# ========= Preparing DataLoader =========#
train_loader = torch.utils.data.DataLoader(
TranslationDataset(
src_word2idx=data['dict']['src'],
tgt_word2idx=data['dict']['tgt'],
src_insts=data['train']['src'],
tgt_insts=data['train']['tgt']),
num_workers=2,
batch_size=opt.batch_size,
collate_fn=paired_collate_fn,
shuffle=True)
valid_loader = torch.utils.data.DataLoader(
TranslationDataset(
src_word2idx=data['dict']['src'],
tgt_word2idx=data['dict']['tgt'],
src_insts=data['valid']['src'],
tgt_insts=data['valid']['tgt']),
num_workers=2,
batch_size=opt.batch_size,
collate_fn=paired_collate_fn)
return train_loader, valid_loader
if __name__ == '__main__':
main()
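The call optimizer.step_and_update_lr() comes from the ScheduledOptim wrapper imported from transformer.Optim. It applies the warmup learning-rate schedule from the paper, lr = d_model^(-0.5) * min(step^(-0.5), step * n_warmup_steps^(-1.5)). A minimal sketch of such a wrapper (an illustration of the idea, not necessarily the repo's exact class) could look like this:

class WarmupScheduledOptim:
    ''' Wrap an optimizer and scale its learning rate with the warmup schedule. '''
    def __init__(self, optimizer, d_model, n_warmup_steps):
        self._optimizer = optimizer
        self.init_lr = d_model ** -0.5
        self.n_warmup_steps = n_warmup_steps
        self.n_steps = 0

    def step_and_update_lr(self):
        self.n_steps += 1
        lr = self.init_lr * min(self.n_steps ** -0.5,
                                self.n_steps * self.n_warmup_steps ** -1.5)
        for group in self._optimizer.param_groups:
            group['lr'] = lr                 # overwrite the learning rate each step
        self._optimizer.step()

    def zero_grad(self):
        self._optimizer.zero_grad()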
After the model has been trained, it needs to be tested. The testing (translation) code is largely similar to the training code, as shown below:
import torch
import torch.utils.data
import argparse
from tqdm import tqdm
from dataset import collate_fn, TranslationDataset
from transformer.Translator import Translator
from preprocess import read_instances_from_file, convert_instance_to_idx_seq
def main():
'''Main Function'''
parser = argparse.ArgumentParser(description='translate.py')
parser.add_argument('-model', required=True,
help='Path to model .pt file')
parser.add_argument('-src', required=True,
help='Source sequence to decode (one line per sequence)')
parser.add_argument('-vocab', required=True,
help='Path to the preprocessed data file that contains the vocabulary')
parser.add_argument('-output', default='pred.txt',
help="""Path to output the predictions (each line will
be the decoded sequence)""")
parser.add_argument('-beam_size', type=int, default=5,
help='Beam size')
parser.add_argument('-batch_size', type=int, default=30,
help='Batch size')
parser.add_argument('-n_best', type=int, default=1,
help="""If verbose is set, will output the n_best
decoded sentences""")
parser.add_argument('-no_cuda', action='store_true')
opt = parser.parse_args()
opt.cuda = not opt.no_cuda
# Prepare DataLoader
preprocess_data = torch.load(opt.vocab)
preprocess_settings = preprocess_data['settings']
test_src_word_insts = read_instances_from_file(
opt.src,
preprocess_settings.max_word_seq_len,
preprocess_settings.keep_case)
test_src_insts = convert_instance_to_idx_seq(
test_src_word_insts, preprocess_data['dict']['src'])
test_loader = torch.utils.data.DataLoader(
TranslationDataset(
src_word2idx=preprocess_data['dict']['src'],
tgt_word2idx=preprocess_data['dict']['tgt'],
src_insts=test_src_insts),
num_workers=2,
batch_size=opt.batch_size,
collate_fn=collate_fn)
translator = Translator(opt)
with open(opt.output, 'w') as f:
for batch in tqdm(test_loader, mininterval=2, desc=' - (Test)', leave=False):
all_hyp, all_scores = translator.translate_batch(*batch)
for idx_seqs in all_hyp:
for idx_seq in idx_seqs:
pred_line = ' '.join([test_loader.dataset.tgt_idx2word[idx] for idx in idx_seq])
f.write(pred_line + '\n')
print('[Info] Finished.')
if __name__ == "__main__":
main()
On the first run the results were not very good. This may be related to how the random seed is set, or the code may not follow the paper's logic exactly; further debugging is needed.
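If the run-to-run variation is suspected to come from random initialization, a common first step is to fix the seeds before building the model, for example:

import random
import numpy as np
import torch

def set_seed(seed=42):
    ''' Make the Python, NumPy and PyTorch random number generators deterministic. '''
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)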