前言
Pytorch
在目前(2019年12月19)已是比较火热的一个深度学习框架,在学术界的使用率甚至已经超过了Tensorflow
. 在深度学习中,序列到序列(seq2seq)是一个比较重要,也适应领域广泛的问题,比如:机器翻译、语音合成与识别和手写识别等领域都有使用。此类模型结构也日益增多,LSTM、CONV、Transformer。因此,推出一个序列到序列模型的框架,就显得很有必要。google
以Tensorflow
为基础推出了tensor2tensor
。FaceBook以Pytorch
为基础,推出了Fairseq
。本文旨在深度剖析Fairseq
框架。
因为我在平时使用fairseq
框架较多,一直想对其有个深入了解,但始终没有机会和大段的空余时间来进项这项工作。我想不如就已博客的形式,对框架分小项进行剖析。一来方便自己利用琐碎时间,二来方便总结和回顾,再者也能够和这个框架的爱好者进行交流,即使纠正自己的一些问题。
Task
是Fairseq框架中比较重要的一个概念,从训练到推理阶段都离不开它。详细了解Task
的细节,对理解Fairseq
整体,或者拓展Fairseq都会有所帮助。因为从其出发几乎可以延伸到Fairseq
的各个部分。这也是为什么,我决定从task
部分开始剖析Fairseq
框架。
我使用的代码是其Github上截止目前最新(20191218)的代码。前几天Fairseq 团队刚刚发布了0.9
版本。因为fairseq
处于快速迭代过程中,不同版本之间还是会有一些差异。但我相信整体的设计,应该不会有太大的改变。
因为这些文章是写给自己的,脉络可能不会很清晰,也不会采用总分,或先抽象再具体的结构,所以阅读时可能会比较跳跃~~。
初识Task
Task 是什么
Task
直译过来就是任务。那么什么是任务?翻译
是任务,语言模型
是任务,文本分类
也是任务。以上提到的几类任务,都已被包含在fairseq
中。当然,除此之外我们也可以定义自己的任务
。比如,假设我的任务是翻译的同时进行文本分类,那可能需要对任务进行一些拓展。
Task
可以存储词典,提供加载
和迭代
数据集的帮助,初始化模型
和准则(Criterion)
,以及计算模型损失等等。
Tasks store dictionaries and provide helpers for loading/iterating over Datasets, initializing the Model/Criterion and calculating the loss.
从上面的内容,可以看出 Task
在框架中扮演着重要角色。
Train 中的 Task
使用Fairseq
重要的一步就是要训练模型。Fairseq
提供了方便的终端操作方法:
fairseq-train --arch transformer ...
具体细节,可以访问其官网的full-document查看,这里就不再赘述。
这一命令相当于执行源码中的train.py
文件,我们来看一下,在这个文件中,fairseq 使用task
做了些什么。
- 实例化task
task = tasks.setup_task(args)
- 加载验证集
for valid_sub_split in args.valid_subset.split(','):
task.load_dataset(valid_sub_split, combine=False, epoch=0)
- 加载训练集
trainer = Trainer(args, task, model, criterion)
extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer)
epoch_itr = trainer.get_train_iterator(epoch_itr.epoch, load_dataset=reload_dataset)
- 加载model 和 criterion
model = task.build_model(args)
criterion = task.build_criterion(args)
- 计算loss
log_output = trainer.train_step(samples) #train.py
loss, sample_size, logging_output = self.task.train_step(
sample, self.model, self.criterion, self.optimizer, ignore_grad #trainer.py )
Task 的结构
Task
均继承 fairseq_task.py
中的 FairseqTask
类。所以我们详细了解FairseqTask
基类后,其他的子类则比较好理解。接下来深入了解一下fairseq_task.py
,会列举一些比较重要的方法。
- 类函数
比较重要的几个类函数load_dictionary
,build_dictionary
,setup_task
.
load_dictionary
和build_dictionary
分别是加载和创建词典的功能。
setup_task
该类方法则会返回一个task实例,具体代码如下:
@classmethod
def setup_task(cls, args, **kwargs):
"""Setup the task (e.g., load dictionaries).
Args:
args (argparse.Namespace): parsed command-line arguments
"""
return cls(args, **kwargs)
args
为我们传入的各种超参数。
该函数参数中的cls
为已注册的FairseqTask
一个具体子类,具体为哪一个,是我们可以使用--task
选项指定的。
如果我们看一下 tasks
文件下的__init__.py
文件,会发现如下代码:
TASK_REGISTRY = {}
TASK_CLASS_NAMES = set()
def setup_task(args, **kwargs):
return TASK_REGISTRY[args.task].setup_task(args, **kwargs)
# 用于注册task的装饰器
def register_task(name):
def register_task_cls(cls):
...
return register_task_cls
# 自动import tasks文件夹下的所有python文件
for file in os.listdir(os.path.dirname(__file__)):
...
def get_task(name):
return TASK_REGISTRY[name]
此__init__.py
为我们提供了使用task的接口,其实Fairseq
下的其他模块也大概类似的通过各个模块下的__init__.py
暴露功能给我们。
通过这个文件,我们能够知道,如果我们要定义自己的task,需要在自己定义的类前面加上register_task
装饰器,然后将自己的文件放到 tasks目录下,即可通过注册时的name进行实例化调用。
使用tasks代码:
from fairseq import tasks
import argparse
args = argparse.Namespace(task="task_name") # 定义task类别
task = tasks.setup_task(args) # 实例化 task
- 其他方法
build_model(self, args)
: 创建模型
build_criterion(self, args)
: 创建损失函数
def build_generator(self, args)
: 创建生成器,用于产生预测结果
train_step(self, sample, model, criterion, optimizer, ignore_grad=False)
: 一次训练
source_dictionary(self)
: 源数据词典
target_dictionary(self)
: 目标数据词典
get_batch_iterator
:传入数据集获取迭代器