Notes on Reading the Fairseq Wav2vec2 Training Code
Reading the code starting from train
This is my first time reading speech-model code, so please point out any mistakes~
This post records the details of the fairseq code for training wav2vec2. Building the trainer requires setting up three parts: the task, the model, and the criterion. Below I walk through the relevant code, using wav2vec2 training as the example.
Code: https://github.com/facebookresearch/fairseq
The main train function
The example training shell command in examples/wav2vec/README.md:
$ fairseq-hydra-train \
task.data=/path/to/data \
--config-dir /path/to/fairseq-py/examples/wav2vec/config/pretraining \
--config-name wav2vec2_large_librivox
What this actually runs is fairseq_cli/hydra_train.py, where the call distributed_utils.call_main(cfg, pre_main, **kwargs) inside _hydra_main shows that the real training entry point is pre_main, i.e. the main function in fairseq_cli/train.py.
hydra_train.py uses the hydra library to manage the yaml configuration files: the yaml parameters are read through hydra.main rather than yaml.load.
By the time main in train.py runs, the complete configuration has already been assembled and printed to the log.
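For orientation, the config that fairseq-hydra-train composes has roughly this shape (an abridged, illustrative excerpt in the spirit of the wav2vec2 pretraining yaml files; field values here are not copied from any specific config):

```yaml
task:
  _name: audio_pretraining
  data: ???            # overridden via task.data=/path/to/data on the CLI

model:
  _name: wav2vec2

criterion:
  _name: wav2vec
```

The three _name fields are what drive the task, model, and criterion lookups described in the steps below.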
Step 1: Setup task
task = tasks.setup_task(cfg.task)
tasks is imported from fairseq/tasks, so the setup_task being called is the one defined in fairseq/tasks/__init__.py:
def setup_task(cfg: FairseqDataclass, **kwargs):
    task = None
    task_name = getattr(cfg, "task", None)

    if isinstance(task_name, str):
        # legacy tasks
        task = TASK_REGISTRY[task_name]
        if task_name in TASK_DATACLASS_REGISTRY:
            dc = TASK_DATACLASS_REGISTRY[task_name]
            cfg = dc.from_namespace(cfg)
    else:
        task_name = getattr(cfg, "_name", None)

        if task_name and task_name in TASK_DATACLASS_REGISTRY:
            remove_missing = "from_checkpoint" in kwargs and kwargs["from_checkpoint"]
            dc = TASK_DATACLASS_REGISTRY[task_name]
            cfg = merge_with_parent(dc(), cfg, remove_missing=remove_missing)
            task = TASK_REGISTRY[task_name]

    assert (
        task is not None
    ), f"Could not infer task type from {cfg}. Available argparse tasks: {TASK_REGISTRY.keys()}. Available hydra tasks: {TASK_DATACLASS_REGISTRY.keys()}"

    return task.setup_task(cfg, **kwargs)
task_name must be one of TASK_REGISTRY.keys(), which contains the file names of all the Python files under fairseq/tasks, such as audio_classification, audio_pretraining, translation, etc. (each tasks/xx.py registers its task into TASK_REGISTRY with a statement like @register_task("audio_pretraining", dataclass=AudioPretrainingConfig)). As the code below shows, task.setup_task returns an instance of the class registered under the given task_name; for example, audio_pretraining corresponds to the class AudioPretrainingTask.
@classmethod
def setup_task(cls, cfg: AudioPretrainingConfig, **kwargs):
    """Setup the task (e.g., load dictionaries).

    Args:
        cfg (AudioPretrainingConfig): configuration of this task
    """

    return cls(cfg)
Note: the @classmethod decorator in the snippet above marks a class method, which can be called on the class itself without instantiating an object, hence task.setup_task(cfg, **kwargs).
Summary: tasks.setup_task yields an instance of a task class; task._name in the yaml config determines which class under fairseq/tasks it is. For wav2vec2 this is the AudioPretrainingTask task class.
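The registration machinery above can be condensed into a short, self-contained sketch (the names mirror fairseq's, but the implementation here is illustrative, not fairseq's actual code):

```python
# Illustrative sketch of the decorator-based registry behind @register_task.
TASK_REGISTRY = {}

def register_task(name):
    """Return a class decorator that files the class under `name`."""
    def decorator(cls):
        TASK_REGISTRY[name] = cls  # map task name -> task class
        return cls
    return decorator

@register_task("audio_pretraining")
class AudioPretrainingTask:
    def __init__(self, cfg):
        self.cfg = cfg

    @classmethod
    def setup_task(cls, cfg, **kwargs):
        # a classmethod: called on the class itself, no instance needed
        return cls(cfg)

# what tasks.setup_task effectively does: look up the class by name,
# then call its setup_task classmethod to get an instance
task_cls = TASK_REGISTRY["audio_pretraining"]
task = task_cls.setup_task({"_name": "audio_pretraining"})
```

Running this leaves `task` as an AudioPretrainingTask instance, matching the summary above.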
Step 2: Build model
model = task.build_model(cfg.model)
Per Step 1, task is an AudioPretrainingTask, whose build_model function is:
def build_model(self, model_cfg: FairseqDataclass, from_checkpoint=False):
    model = super().build_model(model_cfg, from_checkpoint)

    actualized_cfg = getattr(model, "cfg", None)
    if actualized_cfg is not None:
        # if "w2v_args" in actualized_cfg:
        if hasattr(actualized_cfg, "w2v_args"):
            model_cfg.w2v_args = actualized_cfg.w2v_args
    return model
This calls the parent class FairseqTask's build_model via super() (defined in fairseq/tasks/fairseq_task.py):
def build_model(self, cfg: FairseqDataclass, from_checkpoint=False):
    """
    Build the :class:`~fairseq.models.BaseFairseqModel` instance for this
    task.

    Args:
        cfg (FairseqDataclass): configuration object

    Returns:
        a :class:`~fairseq.models.BaseFairseqModel` instance
    """
    from fairseq import models, quantization_utils

    model = models.build_model(cfg, self, from_checkpoint)
    model = quantization_utils.quantize_model_scalar(model, cfg)
    return model
which in turn calls the build_model function in fairseq/models/__init__.py:
def build_model(cfg: FairseqDataclass, task, from_checkpoint=False):
    model = None
    model_type = getattr(cfg, "_name", None) or getattr(cfg, "arch", None)

    if not model_type and len(cfg) == 1:
        # this is hit if config object is nested in directory that is named after model type
        model_type = next(iter(cfg))
        if model_type in MODEL_DATACLASS_REGISTRY:
            cfg = cfg[model_type]
        else:
            raise Exception(
                "Could not infer model type from directory. Please add _name field to indicate model type. "
                "Available models: "
                + str(MODEL_DATACLASS_REGISTRY.keys())
                + " Requested model type: "
                + model_type
            )

    if model_type in ARCH_MODEL_REGISTRY:
        # case 1: legacy models
        model = ARCH_MODEL_REGISTRY[model_type]
    elif model_type in MODEL_DATACLASS_REGISTRY:
        # case 2: config-driven models
        model = MODEL_REGISTRY[model_type]

    if model_type in MODEL_DATACLASS_REGISTRY:
        # set defaults from dataclass. note that arch name and model name can be the same
        dc = MODEL_DATACLASS_REGISTRY[model_type]

        if isinstance(cfg, argparse.Namespace):
            cfg = dc.from_namespace(cfg)
        else:
            cfg = merge_with_parent(dc(), cfg, from_checkpoint)
    else:
        if model_type in ARCH_CONFIG_REGISTRY:
            with open_dict(cfg) if OmegaConf.is_config(cfg) else ExitStack():
                # this calls the different "arch" functions (like base_architecture()) that you indicate
                # if you specify --arch on the command line. this is only applicable to the old argparse based models
                # hydra models should expose different architectures via different config files
                # it will modify the cfg object and default parameters according to the arch
                ARCH_CONFIG_REGISTRY[model_type](cfg)

    assert model is not None, (
        f"Could not infer model type from {cfg}. "
        "Available models: {}".format(MODEL_DATACLASS_REGISTRY.keys())
        + f" Requested model type: {model_type}"
    )

    return model.build_model(cfg, task)
In models/__init__.py, much as in tasks/__init__.py, a register_model function registers model classes into MODEL_REGISTRY; for example, models/wav2vec/wav2vec2.py calls @register_model("wav2vec2", dataclass=Wav2Vec2Config). In the snippet above, model is then the class Wav2Vec2Model, and the return value is the result of Wav2Vec2Model.build_model(), i.e. a fresh Wav2Vec2Model instance.
Summary: task.build_model() builds an instance of the model class named by the model._name parameter in the config file; for wav2vec2 that is a Wav2Vec2Model instance.
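One detail worth isolating is how build_model decides the model type: it prefers cfg._name (hydra configs) and falls back to cfg.arch (legacy argparse models). A small sketch of just that resolution step (the helper name resolve_model_type is mine, not fairseq's):

```python
import argparse

def resolve_model_type(cfg):
    # mirrors: getattr(cfg, "_name", None) or getattr(cfg, "arch", None)
    return getattr(cfg, "_name", None) or getattr(cfg, "arch", None)

hydra_like = argparse.Namespace(_name="wav2vec2")   # hydra-style config
legacy_like = argparse.Namespace(arch="wav2vec2")   # old --arch style

print(resolve_model_type(hydra_like))   # wav2vec2
print(resolve_model_type(legacy_like))  # wav2vec2
```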
Step 3: Build criterion
criterion = task.build_criterion(cfg.criterion)
The AudioPretrainingTask class does not define build_criterion itself, so the one being called is its parent FairseqTask's:
def build_criterion(self, cfg: DictConfig, from_checkpoint=False):
    """
    Build the :class:`~fairseq.criterions.FairseqCriterion` instance for
    this task.

    Args:
        cfg (omegaconf.DictConfig): configuration object

    Returns:
        a :class:`~fairseq.criterions.FairseqCriterion` instance
    """
    from fairseq import criterions

    return criterions.build_criterion(cfg, self, from_checkpoint=from_checkpoint)
As the snippet above shows, this calls build_criterion in fairseq/criterions/__init__.py:
(
    build_criterion_,
    register_criterion,
    CRITERION_REGISTRY,
    CRITERION_DATACLASS_REGISTRY,
) = registry.setup_registry(
    "--criterion", base_class=FairseqCriterion, default="cross_entropy"
)


def build_criterion(cfg: DictConfig, task, from_checkpoint=False):
    return build_criterion_(cfg, task, from_checkpoint=from_checkpoint)
Here build_criterion_ is the first output of the setup_registry function in fairseq/registry.py, and register_criterion is the second, used to register criterions into the REGISTRY. For wav2vec2 the yaml config sets criterion._name=wav2vec, which fairseq/criterions/wav2vec_criterion.py registers via @register_criterion("wav2vec", dataclass=Wav2VecCriterionConfig); build_criterion_ then returns an instance of the class Wav2vecCriterion.
Summary: task.build_criterion() builds an instance of the criterion class named by criterion._name; for wav2vec2 that is a Wav2vecCriterion instance.
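The setup_registry factory pattern can be sketched roughly like this (a heavily simplified stand-in for fairseq/registry.py, ignoring dataclass merging and base-class checks):

```python
def setup_registry(registry_name, default=None):
    """Return (builder, register decorator, registry, dataclass registry)."""
    REGISTRY = {}
    DATACLASS_REGISTRY = {}

    def register(name, dataclass=None):
        def decorator(cls):
            REGISTRY[name] = cls
            if dataclass is not None:
                DATACLASS_REGISTRY[name] = dataclass
            return cls
        return decorator

    def build(cfg, task):
        # pick the registered class by cfg["_name"], then instantiate it
        name = cfg.get("_name", default)
        return REGISTRY[name](task)

    return build, register, REGISTRY, DATACLASS_REGISTRY

build_criterion_, register_criterion, CRITERION_REGISTRY, _ = setup_registry(
    "--criterion", default="cross_entropy"
)

@register_criterion("wav2vec")
class Wav2vecCriterion:
    def __init__(self, task):
        self.task = task

crit = build_criterion_({"_name": "wav2vec"}, task=None)  # a Wav2vecCriterion
```

The same factory is reused in fairseq for other component families, which is why the four-tuple unpacking appears in criterions/__init__.py.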
Step 4: Build trainer
if cfg.common.model_parallel_size == 1:
    trainer = Trainer(cfg, task, model, criterion, quantizer)
else:
    trainer = MegatronTrainer(cfg, task, model, criterion)
Trainer implements synchronous distributed data-parallel training: each of several workers holds a full replica of the model, and gradients are accumulated across workers before every update. MegatronTrainer is a model-parallel trainer; the difference from Trainer is that Megatron-style model parallelism partitions the model itself and places the pieces on different GPUs for training. Trainer is the parent class of MegatronTrainer. The trainer is then passed into the train function; the code that trains one epoch is:
# train for one epoch
valid_losses, should_stop = train(cfg, trainer, task, epoch_itr)
The train function trains for one epoch and returns the validation losses. Here epoch_itr is produced by the get_train_iterator function in fairseq/trainer.py, which returns an EpochBatchIterator over the training set for the specified epoch (this involves load_dataset, get_batch_iterator, and so on). The training work inside train centers on:
log_output = trainer.train_step(samples)
and within that, chiefly on task.train_step, i.e.:
# forward and backward
loss, sample_size_i, logging_output = self.task.train_step(
    sample=sample,
    model=self.model,
    criterion=self.criterion,
    optimizer=self.optimizer,
    update_num=self.get_num_updates(),
    ignore_grad=is_dummy_batch,
    **extra_kwargs,
)
task.train_step corresponds to the train_step function in tasks/fairseq_task.py:
def train_step(
    self, sample, model, criterion, optimizer, update_num, ignore_grad=False
):
    model.train()
    model.set_num_updates(update_num)
    with torch.autograd.profiler.record_function("forward"):
        with torch.cuda.amp.autocast(enabled=(isinstance(optimizer, AMPOptimizer))):
            loss, sample_size, logging_output = criterion(model, sample)
    if ignore_grad:
        loss *= 0
    with torch.autograd.profiler.record_function("backward"):
        optimizer.backward(loss)
    return loss, sample_size, logging_output
As this shows, the three return values come from the criterion; in other words, the criterion is what computes the loss during training. See the forward function of the Wav2vecCriterion class for the details; most of the logging printed during training also originates there.
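The contract the criterion must satisfy — callable as criterion(model, sample) and returning (loss, sample_size, logging_output) — can be shown with a toy, framework-free stand-in (the real Wav2vecCriterion computes a contrastive loss over torch tensors; everything below is illustrative):

```python
class ToyCriterion:
    """Toy criterion obeying the (loss, sample_size, logging_output) contract."""

    def __call__(self, model, sample):
        preds = [model(x) for x in sample["net_input"]]
        targets = sample["target"]
        # summed squared error stands in for the real contrastive loss
        loss = sum((p - t) ** 2 for p, t in zip(preds, targets))
        sample_size = len(targets)
        logging_output = {"loss": loss, "sample_size": sample_size}
        return loss, sample_size, logging_output

def model(x):
    return 2 * x  # stand-in for the model's forward pass

sample = {"net_input": [1.0, 2.0], "target": [2.0, 4.0]}
loss, sample_size, logging_output = ToyCriterion()(model, sample)
# here loss == 0.0 and sample_size == 2, since the toy "model" fits exactly
```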