前言
本文是在阅读fairseq官方文档的基础上做一些翻译以及自身理解的记录。水平有限,如有错误,敬请指正!
一、插件
Fairseq 支持五种通过用户提供的插件进行扩展:
- Models:定义了神经网络的架构,并封装了所有可学习的参数。
- Criterions:根据模型的输出和目标值计算损失函数。
- Tasks:存储字典并提供用于加载/迭代数据集的辅助功能,初始化模型/损失函数,并计算损失。
- Optimizers:根据梯度更新模型参数。
- Learning Rate Schedulers:在训练过程中更新学习率。
二、训练流程
2.1 概述
给定物种插件的情况下,fairseq 实现了以下高级训练流程:
for epoch in range(num_epochs):
itr = task.get_batch_iterator(task.dataset('train'))
for num_updates, batch in enumerate(itr):
task.train_step(batch, model, criterion, optimizer)
average_and_clip_gradients()
optimizer.step()
lr_scheduler.step_update(num_updates)
lr_scheduler.step(epoch)
即:
- 循环迭代每个训练周期 (epoch):外层循环是针对训练周期的迭代,通常在整个训练数据集上多次迭代以便模型能够学习到数据的不同特征。
- 获取批量迭代器 (Batch Iterator):在每个训练周期内,内层循环使用任务 (task) 的 get_batch_iterator 方法来获取一个数据批量的迭代器。这个迭代器(itr)用于逐批次地处理训练数据,每次迭代会生成一个 batch 的数据。
- 逐批次进行训练:内层循环逐批次处理训练数据,对每个批次执行以下操作:
- task.train_step(batch, model, criterion, optimizer):执行模型的前向传播,计算损失函数,反向传播并且计算梯度。
- average_and_clip_gradients():对计算得到的梯度进行平均和剪裁。这通常有助于稳定训练过程。
- optimizer.step():使用梯度来更新模型的参数。这一步是训练中的关键步骤,通过梯度下降来更新模型,以使损失函数最小化。
- lr_scheduler.step_update(num_updates):更新学习率调度器,通常是在每个迭代步骤中更新以适应训练进程。
- 更新学习率调度器 (Learning Rate Scheduler):在每个训练周期结束后,通过调用 lr_scheduler.step(epoch) 来更新学习率调度器的状态,以便下一个训练周期使用不同的学习率。
其中默认task.train_step的实现大致是:
def train_step(self, batch, model, criterion, optimizer, **unused):
loss = criterion(model, batch)
optimizer.backward(loss)
return loss
2.2 注册新插件
新插件是通过一组@register
函数装饰器注册的,例如:
关于装饰器,可以去这篇文章看一看,写的很好!
@register_model('my_lstm')
class MyLSTM(FairseqEncoderDecoderModel):
(...)
一经注册,新插件可以与现有的命令行工具一起使用。
2.3 从其他目录导入插件
新的插件可以在用户系统中存储的自定义模块中定义。为了导入模块并使插件可用于fairseq,命令行支持--user-dir
以用于指定附加模块加载到fairseq 的自定义位置的标志。
示例:目录结构树如下
/home/user/my-module/
└── __init__.py
__init__.py
(当一个文件夹被当成模块调用时,其__init__.py会被执行。一个py文件执行的方式就是自上而下的过程,会把import区,装饰器,和函数外的语句全部执行):
from fairseq.models import register_model_architecture
from fairseq.models.transformer import transformer_vaswani_wmt_en_de_big
@register_model_architecture('transformer', 'my_transformer')
def transformer_mmt_big(args):
transformer_vaswani_wmt_en_de_big(args)
可以使用新架构调用fairseq-train脚本:
fairseq-train ... --user-dir /home/user/my-module -a my_transformer --task translation