Fairseq 框架剖析(一)之task

2 篇文章 0 订阅
1 篇文章 1 订阅

前言

Pytorch 在目前(2019年12月19)已是比较火热的一个深度学习框架,在学术界的使用率甚至已经超过了Tensorflow. 在深度学习中,序列到序列(seq2seq)是一个比较重要,也适应领域广泛的问题,比如:机器翻译、语音合成与识别和手写识别等领域都有使用。此类模型结构也日益增多,LSTM、CONV、Transformer。因此,推出一个序列到序列模型的框架,就显得很有必要。googleTensorflow为基础推出了tensor2tensor。FaceBook以Pytorch为基础,推出了Fairseq。本文旨在深度剖析Fairseq框架。

因为我在平时使用fairseq框架较多,一直想对其有个深入了解,但始终没有机会和大段的空余时间来进项这项工作。我想不如就已博客的形式,对框架分小项进行剖析。一来方便自己利用琐碎时间,二来方便总结和回顾,再者也能够和这个框架的爱好者进行交流,即使纠正自己的一些问题。

Task 是Fairseq框架中比较重要的一个概念,从训练到推理阶段都离不开它。详细了解Task的细节,对理解Fairseq整体,或者拓展Fairseq都会有所帮助。因为从其出发几乎可以延伸到Fairseq的各个部分。这也是为什么,我决定从task部分开始剖析Fairseq框架。

我使用的代码是其Github上截止目前最新(20191218)的代码。前几天Fairseq 团队刚刚发布了0.9版本。因为fairseq处于快速迭代过程中,不同版本之间还是会有一些差异。但我相信整体的设计,应该不会有太大的改变。

因为这些文章是写给自己的,脉络可能不会很清晰,也不会采用总分,或先抽象再具体的结构,所以阅读时可能会比较跳跃~~。


初识Task

Task 是什么

Task直译过来就是任务。那么什么是任务?翻译是任务,语言模型是任务,文本分类也是任务。以上提到的几类任务,都已被包含在fairseq中。当然,除此之外我们也可以定义自己的任务。比如,假设我的任务是翻译的同时进行文本分类,那可能需要对任务进行一些拓展。

Task 可以存储词典,提供加载迭代数据集的帮助,初始化模型准则(Criterion),以及计算模型损失等等。

Tasks store dictionaries and provide helpers for loading/iterating over Datasets, initializing the Model/Criterion and calculating the loss.

从上面的内容,可以看出 Task在框架中扮演着重要角色。

Train 中的 Task

使用Fairseq重要的一步就是要训练模型。Fairseq提供了方便的终端操作方法:

fairseq-train --arch transformer ...

具体细节,可以访问其官网的full-document查看,这里就不再赘述。

这一命令相当于执行源码中的train.py 文件,我们来看一下,在这个文件中,fairseq 使用task 做了些什么。

  • 实例化task
    task = tasks.setup_task(args)
  • 加载验证集
    for valid_sub_split in args.valid_subset.split(','):
        task.load_dataset(valid_sub_split, combine=False, epoch=0)
  • 加载训练集
    trainer = Trainer(args, task, model, criterion)
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer)
    epoch_itr = trainer.get_train_iterator(epoch_itr.epoch, load_dataset=reload_dataset)
  • 加载model 和 criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
  • 计算loss
    log_output = trainer.train_step(samples) #train.py
    loss, sample_size, logging_output = self.task.train_step(
                        sample, self.model, self.criterion, self.optimizer,     ignore_grad  #trainer.py                  )
Task 的结构

Task 均继承 fairseq_task.py 中的 FairseqTask类。所以我们详细了解FairseqTask基类后,其他的子类则比较好理解。接下来深入了解一下fairseq_task.py,会列举一些比较重要的方法。

  • 类函数
    比较重要的几个类函数load_dictionary,build_dictionary,setup_task.
    load_dictionarybuild_dictionary 分别是加载和创建词典的功能。
    setup_task该类方法则会返回一个task实例,具体代码如下:
    @classmethod
    def setup_task(cls, args, **kwargs):
        """Setup the task (e.g., load dictionaries).

        Args:
            args (argparse.Namespace): parsed command-line arguments
        """
        return cls(args, **kwargs)

args 为我们传入的各种超参数。

该函数参数中的cls 为已注册的FairseqTask 一个具体子类,具体为哪一个,是我们可以使用--task 选项指定的。
如果我们看一下 tasks 文件下的__init__.py 文件,会发现如下代码:


TASK_REGISTRY = {}
TASK_CLASS_NAMES = set()

def setup_task(args, **kwargs):
    return TASK_REGISTRY[args.task].setup_task(args, **kwargs)
    
# 用于注册task的装饰器
def register_task(name):
        def register_task_cls(cls):
        ...
    return register_task_cls

# 自动import tasks文件夹下的所有python文件
for file in os.listdir(os.path.dirname(__file__)):
    ...

def get_task(name):
    return TASK_REGISTRY[name]

__init__.py 为我们提供了使用task的接口,其实Fairseq下的其他模块也大概类似的通过各个模块下的__init__.py 暴露功能给我们。
通过这个文件,我们能够知道,如果我们要定义自己的task,需要在自己定义的类前面加上register_task装饰器,然后将自己的文件放到 tasks目录下,即可通过注册时的name进行实例化调用。
使用tasks代码:

from fairseq import tasks
import argparse
args = argparse.Namespace(task="task_name")  # 定义task类别
task = tasks.setup_task(args) # 实例化 task
  • 其他方法

build_model(self, args): 创建模型
build_criterion(self, args): 创建损失函数
def build_generator(self, args): 创建生成器,用于产生预测结果
train_step(self, sample, model, criterion, optimizer, ignore_grad=False) : 一次训练
source_dictionary(self): 源数据词典
target_dictionary(self): 目标数据词典
get_batch_iterator:传入数据集获取迭代器

  • 10
    点赞
  • 33
    收藏
    觉得还不错? 一键收藏
  • 12
    评论
评论 12
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值