DiAD代码逐行理解train.py - pytorch_lightning.Trainer

一、训练器初始化函数 pytorch_lightning.Trainer

trainer = pl.Trainer(gpus=2, precision=16, callbacks=[logger,ckpt_callback_val_loss], accumulate_grad_batches=4, check_val_every_n_epoch=25)

PyTorch Lightning中,Trainer 类是用于管理训练循环的核心类,它提供了大量的配置选项来优化你的训练过程。下面是对你提供的 Trainer 配置参数的详细解释:

gpus=2: 这个参数指定了训练过程中将使用的GPU数量。在这个例子中,它被设置为2,意味着训练将在两个GPU上进行分布式训练(如果可用)。这可以显著加速训练过程,特别是对于大型模型和数据集。
precision=16: 这个参数设置了训练过程中的数值精度。precision=16 表示使用半精度浮点数(FP16)进行训练。与默认的32位浮点数(FP32)相比,FP16可以减少内存占用,有时还可以加速训练过程,但可能会引入数值稳定性问题。PyTorch Lightning和底层框架(如PyTorch)提供了自动混合精度(Automatic Mixed Precision, AMP)来优化这一过程,以最小化精度损失。
callbacks=[logger, ckpt_callback_val_loss]: 这个参数是一个回调函数列表,用于在训练的不同阶段执行自定义操作。在这个例子中,它包含了两个回调函数:logger 和 ckpt_callback_val_loss。logger 可能用于记录训练过程中的日志信息,如损失值、准确率等;而 ckpt_callback_val_loss 可能是一个自定义的回调函数,用于在验证损失达到某个条件时保存模型检查点。
accumulate_grad_batches=4: 这个参数指定了在执行一次参数更新之前,需要累积多少个批次(batch)的梯度。这对于小批量训练特别有用,因为它可以增加有效批量大小,从而有助于稳定训练过程并可能提高模型性能。在这个例子中,每4个批次的梯度将被累积起来,然后用于一次参数更新。
check_val_every_n_epoch=25: 这个参数指定了每多少个训练周期(epoch)后评估一次验证集。默认情况下,PyTorch Lightning在每个epoch结束时都会评估验证集,但通过设置这个参数,你可以减少评估的频率,以节省计算资源。在这个例子中,它被设置为25,意味着每25个epoch后才会评估一次验证集。
这个 Trainer 配置旨在利用两个GPU进行分布式训练,使用半精度浮点数来加速训练过程,并通过回调函数记录日志和保存检查点。同时,它通过累积梯度来增加有效批量大小,并减少验证集的评估频率以节省资源。这样的配置对于处理大规模数据集和复杂模型特别有用。

二、训练器调用函数 trainer.fit

trainer.fit() 方法是 PyTorch Lightning 框架中用于启动模型训练过程的核心方法。trainer.fit() 方法的基本用法是将模型、训练数据加载器和验证数据加载器作为参数传入,从而开始训练过程。

trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=test_dataloader)

参数说明:
**model:**这是你的 PyTorch Lightning 模型,它应该继承自 pl.LightningModule。
**train_dataloader/train_dataloaders:**训练数据加载器,用于在训练过程中加载数据。这通常是一个 PyTorch 的 DataLoader 对象。
**val_dataloader/val_dataloaders:**验证数据加载器,用于在训练过程中评估模型的性能。同样,这也是一个 PyTorch 的 DataLoader 对象。
额外参数和配置:
trainer.fit() 方法还可以接受其他关键字参数来配置训练过程,但这些参数通常是在创建 Trainer 实例时通过其构造函数设置的,而不是直接传递给 fit 方法。以下是一些常见的 Trainer 配置项:

max_epochs:训练的最大轮数。
gpus:指定使用的 GPU 数量。如果为 None,则使用 CPU。
logger:日志记录器,用于记录训练过程中的日志信息。
checkpoint_callback:检查点回调,用于在训练过程中保存模型状态。
accumulate_grad_batches:梯度累积的批次数。这允许在显存有限的情况下模拟较大的批量大小。
resume_from_checkpoint(已弃用):在旧版本中用于从检查点恢复训练,但在新版本中已被 ckpt_path 替代。
ckpt_path:指定从哪个检查点文件恢复训练。

  • 10
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值