The training framework that fairseq provides consists of several parts: the main() function, train(), get_training_stats(), validate(), and get_valid_stats(). This framework itself stays unchanged; we build training on top of it by registering our own task in fairseq/tasks, our own model in fairseq/models, and our own criterion in fairseq/criterions. Let's walk through how the fairseq framework calls the parts we define ourselves.
(1) First, set up the task:
task = tasks.setup_task(args)
The task built here is exactly the one we defined ourselves. The main job of setup_task is to load src_dict and tgt_dict:
@register_task('guided_translation')
class GuidedTranslationTask(FairseqTask):

    def __init__(self, args, src_dict, tgt_dict):
        super().__init__(args)
        self.src_dict = src_dict
        self.tgt_dict = tgt_dict

    @classmethod
    def setup_task(cls, args, **kwargs):
        paths = utils.split_paths(args.data)  # data directories listed in args.data
        src_dict = cls.load_dictionary(os.path.join(paths[0], 'dict.{}.txt'.format(args.source_lang)))
        tgt_dict = cls.load_dictionary(os.path.join(paths[0], 'dict.{}.txt'.format(args.target_lang)))
        return cls(args, src_dict, tgt_dict)  # return a task instance
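The @register_task decorator works by recording the class in a name-to-class registry that tasks.setup_task later consults. A minimal sketch of this mechanism (names like TASK_REGISTRY and ToyTask are illustrative, not fairseq's actual internals, which also validate duplicate names and base classes):

```python
# Simplified sketch of fairseq-style decorator registration.
TASK_REGISTRY = {}

def register_task(name):
    """Decorator that stores a task class under a string name."""
    def wrapper(cls):
        TASK_REGISTRY[name] = cls
        return cls
    return wrapper

def setup_task(args):
    # Look up the class registered under args.task and build an instance.
    return TASK_REGISTRY[args.task].setup_task(args)

@register_task('toy_task')
class ToyTask:
    def __init__(self, args):
        self.args = args

    @classmethod
    def setup_task(cls, args, **kwargs):
        return cls(args)
```

With this in place, `--task toy_task` on the command line maps straight to the registered class.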
(2) Build the model
model = task.build_model(args)
Inside task.build_model() this in turn calls:
return models.build_model(args, self)
build_model was discussed in the previous post on models:
@register_model("guided_transformer")
class GuidedTransformerModel(FairseqEncoderDecoderModel):

    def __init__(self, args, encoder, decoder):
        super().__init__(encoder, decoder)
        self.args = args
        self.supports_align_args = True

    @classmethod
    def build_model(cls, args, task):
        src_dict, tgt_dict = task.source_dictionary, task.target_dictionary
        # ... build encoder_embed_tokens / decoder_embed_tokens from the dictionaries ...
        encoder = cls.build_encoder(args, src_dict, encoder_embed_tokens)
        decoder = cls.build_decoder(args, tgt_dict, decoder_embed_tokens)
        return cls(args, encoder, decoder)
(3) Build the criterion
criterion = task.build_criterion(args)
guided transformer uses the already-registered label_smoothed_cross_entropy criterion.
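The loss that label_smoothed_cross_entropy computes mixes the ordinary NLL loss with a uniform-over-vocabulary term. A simplified pure-Python sketch for a single target position (fairseq's real criterion works on batched tensors and also handles padding and normalization details, so this is only illustrative):

```python
import math

def label_smoothed_nll(log_probs, target, eps):
    """Sketch of label-smoothed loss for ONE position (illustrative only).

    log_probs: list of log-probabilities over the vocabulary
    target:    index of the gold token
    eps:       smoothing weight (the --label-smoothing flag)

    loss = (1 - eps) * nll_loss + eps * smooth_loss, where smooth_loss
    spreads probability mass uniformly over the vocabulary.
    """
    nll_loss = -log_probs[target]
    smooth_loss = -sum(log_probs) / len(log_probs)
    return (1.0 - eps) * nll_loss + eps * smooth_loss
```

Setting eps = 0 recovers plain cross entropy; eps = 1 ignores the gold label entirely.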
(4) Build the trainer
fairseq provides trainer.py, and you can reuse its functionality as needed. The Trainer class in trainer.py defines get_train_iterator, save_checkpoint, load_checkpoint, train_step, valid_step, and so on, along with accessors for various parameters.
from fairseq.trainer import Trainer
trainer = Trainer(args, task, model, criterion)
(5) Loading and saving checkpoints
extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer)
load_checkpoint does several things. First, it reads checkpoint_last.pt via the Trainer's own load_checkpoint method:
extra_state = trainer.load_checkpoint(
checkpoint_path,
args.reset_optimizer,
args.reset_lr_scheduler,
eval(args.optimizer_overrides),
reset_meters=args.reset_meters,
)
Next, it obtains the training data for the next epoch via get_train_iterator:
if extra_state is not None and not args.reset_dataloader:
# restore iterator from checkpoint
itr_state = extra_state["train_iterator"]
epoch_itr = trainer.get_train_iterator(
epoch=itr_state["epoch"], load_dataset=True, **passthrough_args
)
epoch_itr.load_state_dict(itr_state)
else:
epoch_itr = trainer.get_train_iterator(
epoch=0, load_dataset=True, **passthrough_args
)
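The reason epoch_itr can resume mid-training is that it exposes state_dict/load_state_dict, just like a model. A much-simplified, hypothetical iterator showing that contract (fairseq's real EpochBatchIterator also handles shuffling, sharding, and batching):

```python
class EpochIteratorSketch:
    """Hypothetical, minimal epoch iterator that can checkpoint its
    position, in the spirit of fairseq's EpochBatchIterator."""

    def __init__(self, data, epoch=0):
        self.data = list(data)
        self.epoch = epoch
        self.pos = 0  # index of the next batch to yield

    def state_dict(self):
        # This is what gets stored as extra_state["train_iterator"].
        return {'epoch': self.epoch, 'pos': self.pos}

    def load_state_dict(self, state):
        self.epoch = state['epoch']
        self.pos = state['pos']

    def __iter__(self):
        while self.pos < len(self.data):
            item = self.data[self.pos]
            self.pos += 1
            yield item
```

Saving this small dict alongside the model weights is what lets training resume from the middle of an epoch instead of restarting it.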
(6) Timing training with the meters in fairseq.logging
train_meter = meters.StopwatchMeter()  # """Computes the sum/avg duration of some event in seconds"""
train_meter.start()
# ... training runs here ...
train_meter.stop()
logger.info('done training in {:.1f} seconds'.format(train_meter.sum))
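A StopwatchMeter simply accumulates the durations of start()/stop() intervals. A minimal sketch consistent with the docstring quoted above (the attribute names sum and n match the meter's interface; implementation details are simplified):

```python
import time

class StopwatchMeter:
    """Minimal sketch: accumulates the total duration (in seconds)
    of possibly repeated start()/stop() intervals."""

    def __init__(self):
        self.sum = 0.0        # total elapsed seconds across intervals
        self.n = 0            # number of completed intervals
        self.start_time = None

    def start(self):
        self.start_time = time.perf_counter()

    def stop(self):
        if self.start_time is not None:
            self.sum += time.perf_counter() - self.start_time
            self.n += 1
            self.start_time = None
```

This is why the log line above reads train_meter.sum: it reports the accumulated seconds between start() and stop().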
(7) Start training
Training loops as long as all of the following conditions hold:
while (
lr > args.min_lr
and (
epoch_itr.epoch < max_epoch
# allow resuming training from the final checkpoint
or epoch_itr._next_epoch_itr is not None
)
and trainer.get_num_updates() < max_update
):
train(args, trainer, task, epoch_itr)
The train function performs the following steps.
First, it calls trainer.train_step for each batch:
for samples in progress:
log_output = trainer.train_step(samples)
num_updates = trainer.get_num_updates()
if log_output is None:
continue
Inside train_step in fairseq/trainer.py, the key steps are:
self._set_seed()
self.model.train()
self.criterion.train()
self.zero_grad()
It then calls the fairseq task's train_step to run the forward pass and compute the loss:
loss, sample_size_i, logging_output = self.task.train_step(
sample=sample,
model=self.model,
criterion=self.criterion,
optimizer=self.optimizer,
update_num=self.get_num_updates(),
ignore_grad=is_dummy_batch,
)
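For reference, the task-side train_step is quite small: the criterion runs the model forward and returns the loss, and the task backpropagates through the trainer-provided optimizer wrapper. A simplified sketch of that method (fairseq's actual version has a few extra details, e.g. profiling hooks):

```python
def train_step(self, sample, model, criterion, optimizer, update_num, ignore_grad=False):
    """Sketch of a FairseqTask.train_step (simplified)."""
    model.train()
    # The criterion calls model(**sample['net_input']) internally and
    # returns (loss, sample_size, logging_output).
    loss, sample_size, logging_output = criterion(model, sample)
    if ignore_grad:
        loss *= 0  # dummy batch: contribute nothing to the gradient
    # Backprop through the optimizer wrapper (handles fp16 scaling etc.).
    optimizer.backward(loss)
    return loss, sample_size, logging_output
```

This is the natural override point if a custom task needs a non-standard update, e.g. multiple forward passes per step.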
Next, the model is evaluated on the validation set, returning the validation losses:
valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
which internally calls valid_step, a close analogue of train_step, to compute valid_loss.
Then lr_step of the Trainer class updates the learning rate; only the first validation loss is used for the update:
lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])
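trainer.lr_step forwards the epoch number and the validation loss to the configured lr scheduler. For plateau-style schedulers the idea is to cut the learning rate when validation loss stops improving; a hypothetical, much-simplified scheduler illustrating this (ReduceLROnPlateauSketch is not a fairseq class; the real one lives in fairseq.optim.lr_scheduler):

```python
class ReduceLROnPlateauSketch:
    """Hypothetical plateau scheduler: halve lr when val loss stalls."""

    def __init__(self, lr, factor=0.5, patience=1):
        self.lr = lr
        self.factor = factor          # multiplier applied on plateau
        self.patience = patience      # bad epochs tolerated before cutting
        self.best = float('inf')
        self.bad_epochs = 0

    def step(self, epoch, val_loss):
        if val_loss < self.best:
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
            if self.bad_epochs > self.patience:
                self.lr *= self.factor
                self.bad_epochs = 0
        return self.lr
```

The returned lr is what feeds the `lr > args.min_lr` condition in the outer training loop, so a fully decayed learning rate eventually stops training.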
Save a checkpoint:
if epoch_itr.epoch % args.save_interval == 0:
checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0])
Finally, get the iterator for the next training epoch:
epoch_itr = trainer.get_train_iterator(
epoch_itr.epoch,
# sharded data: get train iterator for next epoch
load_dataset=(os.pathsep in getattr(args, 'data', '')),
)