Fairseq训练Wav2vec2代码阅读记录

从train开始读代码

初次阅读语音模型相关代码,如有错误多多指正~

本文主要记录阅读解析fairseq训练wav2vec2的相关代码细节。在建立trainer的过程中,需要设置并建立task、model和criterion三部分,下面详细记录以训练wav2vec2为例相关的代码。
代码链接:https://github.com/facebookresearch/fairseq

train的主函数

examples/wav2vec/README.md 中的训练shell命令示例如下:

$ fairseq-hydra-train \
    task.data=/path/to/data \
    --config-dir /path/to/fairseq-py/examples/wav2vec/config/pretraining \
    --config-name wav2vec2_large_librivox

其本质上运行的是 fairseq_cli/hydra_train.py,其中_hydra_main中的distributed_utils.call_main(cfg, pre_main, **kwargs)显示训练主函数在pre_main,即fairseq_cli/train.py的main函数上。

hydra_train.py中利用hydra库用来管理ymal参数配置文件。使用hydra.main来读取yaml参数文件的,没用yaml.load。

train.py中的main函数,已经获得了完整的参数并打印到log。

Step1:setup task

task = tasks.setup_task(cfg.task)

tasks是从fairseq/tasks中import的,调用的setup_task在函数应在fairseq/tasks/__init__.py文件中。

def setup_task(cfg: FairseqDataclass, **kwargs):
    task = None
    task_name = getattr(cfg, "task", None)

    if isinstance(task_name, str):
        # legacy tasks
        task = TASK_REGISTRY[task_name]
        if task_name in TASK_DATACLASS_REGISTRY:
            dc = TASK_DATACLASS_REGISTRY[task_name]
            cfg = dc.from_namespace(cfg)
    else:
        task_name = getattr(cfg, "_name", None)

        if task_name and task_name in TASK_DATACLASS_REGISTRY:
            remove_missing = "from_checkpoint" in kwargs and kwargs["from_checkpoint"]
            dc = TASK_DATACLASS_REGISTRY[task_name]
            cfg = merge_with_parent(dc(), cfg, remove_missing=remove_missing)
            task = TASK_REGISTRY[task_name]

    assert (
        task is not None
    ), f"Could not infer task type from {cfg}. Available argparse tasks: {TASK_REGISTRY.keys()}. Available hydra tasks: {TASK_DATACLASS_REGISTRY.keys()}"

    return task.setup_task(cfg, **kwargs)

task_name必须在TASK_REGISTRY.keys()中,TASK_REGISTRY.keys()中包含fairseq/tasks中所有的python文件的文件名,如audio_classification、audio_pretraining、translation等(tasks/xx.py 中有如@register_task("audio_pretraining", dataclass=AudioPretrainingConfig)的语句来将task注册到TASK_REGISTRY列表中去)。由以下代码可知,task.setup_task返回的是指定task_name对应的类,比如audio_pretraining对应的类是AudioPretrainingTask。

	@classmethod
    def setup_task(cls, cfg: AudioPretrainingConfig, **kwargs):
        """Setup the task (e.g., load dictionaries).

        Args:
            cfg (AudioPretrainingConfig): configuration of this task
        """

        return cls(cfg)

小注:上述代码段中@classmethod修饰符是一种类方法,无需实例化对象,直接用类本身调用即可,即task.setup_task(cfg, **kwargs)

总结:tasks.setup_task得到的task就是一个类,由参数yaml文件中的task._name决定具体是fairseq/tasks中的哪个类,wav2vec2对应的就是AudioPretrainingTask任务类。

Step2:Build model

model = task.build_model(cfg.model)

根据Step1,task为类AudioPretrainingTask,该类中的build_model函数如下:

def build_model(self, model_cfg: FairseqDataclass, from_checkpoint=False):
	model = super().build_model(model_cfg, from_checkpoint)
	
	actualized_cfg = getattr(model, "cfg", None)
	if actualized_cfg is not None:
	    # if "w2v_args" in actualized_cfg:
	    if hasattr(actualized_cfg, "w2v_args"):
	        model_cfg.w2v_args = actualized_cfg.w2v_args
	
	return model

继承了父类FairseqTask的build_model函数(在fairseq/tasks/fairseq_task.py中定义):

def build_model(self, cfg: FairseqDataclass, from_checkpoint=False):
	"""
	Build the :class:`~fairseq.models.BaseFairseqModel` instance for this
	task.
	Args:
	    cfg (FairseqDataclass): configuration object
	Returns:
	    a :class:`~fairseq.models.BaseFairseqModel` instance
	"""
	from fairseq import models, quantization_utils
	
	model = models.build_model(cfg, self, from_checkpoint)
	model = quantization_utils.quantize_model_scalar(model, cfg)
	return model

其中又调用了fairseq/models/__init__.py中的build_model函数:

def build_model(cfg: FairseqDataclass, task, from_checkpoint=False):

    model = None
    model_type = getattr(cfg, "_name", None) or getattr(cfg, "arch", None)

    if not model_type and len(cfg) == 1:
        # this is hit if config object is nested in directory that is named after model type

        model_type = next(iter(cfg))
        if model_type in MODEL_DATACLASS_REGISTRY:
            cfg = cfg[model_type]
        else:
            raise Exception(
                "Could not infer model type from directory. Please add _name field to indicate model type. "
                "Available models: "
                + str(MODEL_DATACLASS_REGISTRY.keys())
                + " Requested model type: "
                + model_type
            )

    if model_type in ARCH_MODEL_REGISTRY:
        # case 1: legacy models
        model = ARCH_MODEL_REGISTRY[model_type]
    elif model_type in MODEL_DATACLASS_REGISTRY:
        # case 2: config-driven models
        model = MODEL_REGISTRY[model_type]

    if model_type in MODEL_DATACLASS_REGISTRY:
        # set defaults from dataclass. note that arch name and model name can be the same
        dc = MODEL_DATACLASS_REGISTRY[model_type]

        if isinstance(cfg, argparse.Namespace):
            cfg = dc.from_namespace(cfg)
        else:
            cfg = merge_with_parent(dc(), cfg, from_checkpoint)
    else:
        if model_type in ARCH_CONFIG_REGISTRY:
            with open_dict(cfg) if OmegaConf.is_config(cfg) else ExitStack():
                # this calls the different "arch" functions (like base_architecture()) that you indicate
                # if you specify --arch on the command line. this is only applicable to the old argparse based models
                # hydra models should expose different architectures via different config files
                # it will modify the cfg object and default parameters according to the arch
                ARCH_CONFIG_REGISTRY[model_type](cfg)

    assert model is not None, (
        f"Could not infer model type from {cfg}. "
        "Available models: {}".format(MODEL_DATACLASS_REGISTRY.keys())
        + f" Requested model type: {model_type}"
    )

    return model.build_model(cfg, task)

models/__init__.py中,与tasks/__init__.py中类似,有register_model函数来注册模型类别到MODEL_REGISTRY列表中,比如models/wav2vec/wav2vec2.py中的调用部分@register_model("wav2vec2", dataclass=Wav2Vec2Config)。上述代码段中的model即类Wav2Vec2Model,返回的是Wav2Vec2Model.build_model()的结果,即一个新的Wav2Vec2Model模型实例。

总结:task.build_model()根据参数文件中设定的model._name参数,建立相对应的模型类实例,wav2vec2对应的就是Wav2Vec2Model模型实例。

Step3:Build criterion

criterion = task.build_criterion(cfg.criterion)

AudioPretrainingTask类中没有build_criterion函数,遂知其调用的是其父类FairseqTask中的该函数。

def build_criterion(self, cfg: DictConfig, from_checkpoint=False):
	"""
	Build the :class:`~fairseq.criterions.FairseqCriterion` instance for
	this task.
	
	Args:
	    cfg (omegaconf.DictConfig): configration object
	
	Returns:
	    a :class:`~fairseq.criterions.FairseqCriterion` instance
	"""
	from fairseq import criterions
	return criterions.build_criterion(cfg, self, from_checkpoint=from_checkpoint)

由以上代码段可知,调用了fairseq/criterions/__init__.py中的build_criterion函数:

(
    build_criterion_,
    register_criterion,
    CRITERION_REGISTRY,
    CRITERION_DATACLASS_REGISTRY,
) = registry.setup_registry(
    "--criterion", base_class=FairseqCriterion, default="cross_entropy"
)
def build_criterion(cfg: DictConfig, task, from_checkpoint=False):
    return build_criterion_(cfg, task, from_checkpoint=from_checkpoint)

其中build_criterion_又来自fairseq/registry.py中setup_registry函数的第一个产出,register_criterion是第二个产出,用于注册criterion到列表REGISTRY中。对于wav2vec2而言,yaml参数文件中criterion._name=wav2vec,在fairseq/criterion/wav2vec_criterion.py通过@register_criterion("wav2vec", dataclass=Wav2VecCriterionConfig)进行注册,build_criterion_的返回是类Wav2vecCriterion的一个实例。

总结:task.build_criterion()根据设定的criterion._name参数,建立相对应的criterion类实例,wav2vec2对应的就是Wav2vecCriterion的一个实例。

Step4:Build trainer

if cfg.common.model_parallel_size == 1:
        trainer = Trainer(cfg, task, model, criterion, quantizer)
    else:
        trainer = MegatronTrainer(cfg, task, model, criterion)

Trainer用于同步分布式数据并行训练,多个works每个都有整个model replica,并且在每次更新前在workers之间累积梯度。 MegatronTrainer是一种模型并行的训练方法,与Trainer中有区别的是, Megatron的模型并行将模型进行切分,放在不同的GPU上进行模型的训练。Trainer是 MegatronTrainer的父类。trainer作为train函数的输入,训练一个epoch的代码:

# train for one epoch
valid_losses, should_stop = train(cfg, trainer, task, epoch_itr)

train函数训练一个epoch,并且输出验证集loss。其中epoch_itr是由fairseq/trainer.py中get_train_iterator函数得到,它返回在训练集上一个指定epoch的EpochBatchIterator,包括load_dataset、get_batch_iterator等。train函数中的训练部分集中在:

log_output = trainer.train_step(samples)

其中的训练步骤又主要集中于task.train_step,即:

# forward and backward
loss, sample_size_i, logging_output = self.task.train_step(
    sample=sample,
    model=self.model,
    criterion=self.criterion,
    optimizer=self.optimizer,
    update_num=self.get_num_updates(),
    ignore_grad=is_dummy_batch,
    **extra_kwargs,
)

task.train_step对应于tasks/fairseq_task.py中的train_step函数:

def train_step(
        self, sample, model, criterion, optimizer, update_num, ignore_grad=False
    ):
        model.train()
        model.set_num_updates(update_num)
        with torch.autograd.profiler.record_function("forward"):
            with torch.cuda.amp.autocast(enabled=(isinstance(optimizer, AMPOptimizer))):
                loss, sample_size, logging_output = criterion(model, sample)
        if ignore_grad:
            loss *= 0
        with torch.autograd.profiler.record_function("backward"):
            optimizer.backward(loss)
        return loss, sample_size, logging_output

从中可见,三个输出项主要来自于criterion,也就是说criterion在训练过程中承担计算loss的功能。详见Wav2vecCriterion类的forward函数,训练过程中打印的log也大多是在该函数中产生。

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值