fairseq框架下模型训练-train

6 篇文章 0 订阅
1 篇文章 0 订阅

fairseq给出的训练框架中,包含几个部分,main()函数,train(), get_traning_stats(),validate(),get_valid_stats()。这个框架是不改变的,我们通过在fairseq/tasks中注册自己的task,fairseq/models中注册自己的model,fairseq/critirion中注册自己的critirion来完成基于fairseq框架的训练。我们来看一下在fairseq框架中如何调用我们自己定义的部分。

(1)首先建立task:

task = tasks.setup_task(args)

而调用的这个task正是我们自己定义的task。
setup_task的主要作用是读入src_dict和tgt_dict

@register_task('guided_translation')
class GuidedTranslationTask(FairseqTask):
    def __init__(self, args, src_dict, tgt_dict):
        super().__init__(args)
        self.src_dict = src_dict
        self.tgt_dict = tgt_dict
    def setup_task(cls, args, **kwargs):
            src_dict = cls.load_dictionary(os.path.join(paths[0], 'dict.{}.txt'.format(args.source_lang)))
            tgt_dict = cls.load_dictionary(os.path.join(paths[0], 'dict.{}.txt'.format(args.target_lang)))
    return cls(args, src_dict, tgt_dict)#返回task的实体

(2)建立model

model = task.build_model(args)

而在task.build_model()中有:

return models.build_model(args, self)

build_model在上一篇叙述model的文章中提及过,

@register_model("guided_transformer")
class GuidedTransformerModel(FairseqEncoderDecoderModel):
    def __init__(self, args, encoder, decoder):
        super().__init__(encoder, decoder)
        self.args = args
        self.supports_align_args = True
    def build_model(cls, args, task):
        encoder = cls.build_encoder(args, src_dict, encoder_embed_tokens)
        decoder = cls.build_decoder(args, tgt_dict, decoder_embed_tokens)
        return cls(args, encoder, decoder)

(3)建立criterion

criterion = task.build_criterion(args)

在guided transformer中,采用的是已经注册的label_smoothed_cross_entropy

(4)建立trainer
在fairseq中有trainer.py 可以根据需求提取其中的功能
trainer.py中class Trainer定义了 get_train_iterator,save_check_poinnt, load_check_point, train_step, valid_step等以及一些参数的接口。

from fairseq.trainer import Trainer
trainer = Trainer(args, task, model, criterion)

(5)读取、保存断点

    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer)

在load_checkpoint中,做了几项工作
首先,通过class trainer中的load_checkpoint读取checkpoint_last.pt

    extra_state = trainer.load_checkpoint(
        checkpoint_path,
        args.reset_optimizer,
        args.reset_lr_scheduler,
        eval(args.optimizer_overrides),
        reset_meters=args.reset_meters,
    )

其次,通过get_train_iterator获取下一epoch的训练数据

    if extra_state is not None and not args.reset_dataloader:
        # restore iterator from checkpoint
        itr_state = extra_state["train_iterator"]
        epoch_itr = trainer.get_train_iterator(
            epoch=itr_state["epoch"], load_dataset=True, **passthrough_args
        )
        epoch_itr.load_state_dict(itr_state)
    else:
        epoch_itr = trainer.get_train_iterator(
            epoch=0, load_dataset=True, **passthrough_args
        )

(6)使用fairseq.logging 中的meters来对训练进行计时

    train_meter = meters.StopwatchMeter() #"""Computes the sum/avg duration of some event in seconds"""
    train_meter.start()
    train过程
    train_meter.stop()
    logger.info('done training in {:.1f} seconds'.format(train_meter.sum))

(7)开始训练
在以下条件成立时循环进行训练

    while (
        lr > args.min_lr
        and (
            epoch_itr.epoch < max_epoch
            # allow resuming training from the final checkpoint
            or epoch_itr._next_epoch_itr is not None
        )
        and trainer.get_num_updates() < max_update
    ):
train(args, trainer, task, epoch_itr)

在train函数中做如下几步
调用trainer.train_step

    for samples in progress:
        log_output = trainer.train_step(samples)
        num_updates = trainer.get_num_updates()
        if log_output is None:
            continue

在fairseq.trainer.py的train_step中,重点做了以下几步

        self._set_seed()
        self.model.train()
        self.criterion.train()
        self.zero_grad()

调用fairseq task. train_step计算结果以及损失函数

        loss, sample_size_i, logging_output = self.task.train_step(
               sample=sample,
               model=self.model,
               criterion=self.criterion,
               optimizer=self.optimizer,
               update_num=self.get_num_updates(),
               ignore_grad=is_dummy_batch,
                    )

在验证集上评估模型并返回损失值

valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)

调用与train_step相似的valid_step完成valid_loss计算

调用fairseq中的trainer class中的lr_step来更新learning rate
only usue first validation loss to update the learing rate.

lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

保存checkpoint

if epoch_itr.epoch % args.save_interval == 0:
      checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

获取下一个train epoch

        epoch_itr = trainer.get_train_iterator(
            epoch_itr.epoch,
            # sharded data: get train iterator for next epoch
            load_dataset=(os.pathsep in getattr(args, 'data', '')),
        )
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值