Pytorch-Lightning--Trainer

Pytorch-Lightning中的训练器—Trainer

Trainer.__init__()

常用参数

参数名称含义默认值接受类型
callbacks添加回调函数或回调函数列表None(ModelCheckpoint默认值)Union[List[Callback], Callback, None]
enable_checkpointing是否使用callbacksTruebool
enable_progress_bar是否显示进度条Truebool
enable_model_summary是否打印模型摘要Truebool
precision指定训练精度32(full precision)Union[int, str]
default_root_dir模型保存和日志记录默认根路径None(os.getcwd())Optional[str]
logger设置日志记录器(支持多个),若没设置logger的save_dir,则使用default_root_dirTrue(默认日志记录)Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool]
max_epochs最多训练轮数(指定为**-1可以设置为无限次**)None(1000)Optional[int]
min_epochs最少训练轮数。当有Early Stop时使用None(1)Optional[int]
max_steps最大网络权重更新次数-1(禁用)Optional[int]
min_steps最少网络权重更新次数None(禁用)Optional[int]
weights_save_path权重保存路径(优先级高于default_root_dir),ModelCheckpoint未定义路径时将使用该路径None(default_root_dir)Optional[str]
log_every_n_steps更新n次网络权重后记录一次日志50int
limit_train_batches
limit_test_batches
limit_val_batches

limit_predict_batches
使用训练/测试/验证/预测数据的百分比.如果数据过多,或正在调试可以使用。1.0Union[int, float] (float = 比例, int = num_batches).
fast_dev_run如果设定为true,会只执行一个batch的train, val 和 test,然后结束。仅用于debugFalsebool
accumulate_grad_batches每k次batches累计一次梯度None(无梯度累计)Union[int, Dict[int, int], None]
check_val_every_n_epoch每n个train epoch执行一次验证1int
num_sanity_val_steps开始训练前加载n个验证数据进行测试,k=-1时加载所有验证数据2int

硬件加速相关选项

参数名称含义默认值接受类型
accelerator设置硬加类型autoUnion[str, Accelerator]
devices使用多少设备,若为-1则使用全部可用设备autoUnion[List[int], str, int]
num_nodes分布式环境下使用多少GPU1int

额外的解释

  • 这里max_steps/min_steps中的step就是指的是优化器的step(),优化器每step()一次就会更新一次网络权重
  • 梯度累加(Gradient Accumulation):受限于显存大小,一些训练任务只能使用较小的batch_size,但一般batch-size越大(一定范围内)模型收敛越稳定效果相对越好;梯度累加可以先累加多个batch的梯度再进行一次参数更新,相当于增大了batch_size

Trainer.fit()

参数详解

参数名称含义默认值
modelLightningModule实例
train_dataloaders训练数据加载器None
val_dataloaders验证数据加载器None
ckpt_pathckpt文件路径(从这里文件恢复训练)None
datamoduleLightningDataModule实例None

ckpt_path参数详解(从之前的模型恢复训练)

​ 使用该参数指定一个模型ckpt文件(需要保存整个模型,而不是仅仅保存模型权重),Trainer将从ckpt文件的下一个epoch继续训练。

示范
net = MyNet(...)
trainer = pl.Trainer(...)
# 假设模型保存在./ckpt中
trainer.fit(net, train_iter, val_iter, ckpt_path='./ckpt/myresult.ckpt')
使用注意
  • 请不要使用Trainer()中的resume_from_checkpoint参数,该参数未来将被丢弃,请使用Trainer.fit()的ckpt_path参数

Trainer.test()Trainer.validate()

参数详解

参数名称含义默认值
modelLightningModule实例None
verbose是否打印测试结果True
dataloaders测试数据加载器(可以使torch.utils.data.DataLoader)None
ckpt_pathckpt文件路径(从这里文件恢复训练)None
datamoduleLightningDataModule实例None
  • Returns:测试/验证期间相关度量值的字典列表(列表长度等于测试/验证数据加载器个数),比如validation/test_step(), validation/test_epoch_end(),中的回调钩子
  • ckpt_path:如果设置了该参数则会使用该ckpt文件中的权重,否则如果模型已经训练完毕则使用当前权重,其他情况如果配置了checkpoint callbacks则加载该checkpoint callbacks对应的最佳模型

Trainer.predict()

参数详解

参数名称含义默认值
modelLightningModule实例None
dataloaders数据加载器None
ckpt_pathckpt文件路径(从这里文件恢复训练)None
datamoduleLightningDataModule实例None
return_predictions是否返回结果,目前不支持设置None(True)
  • ckpt_path:使用该ckpt文件中的权重,如果为None如果模型已经训练完毕则使用当前权重,其他情况如果配置了checkpoint callbacks则加载该checkpoint callbacks对应的最佳模型

使用注意

  • preict()中会禁用日志功能

Trainer.tune()

功能解释

  • 对模型超参数进行调整

常用参数

参数名称含义默认值
modelLightningModule实例
train_dataloaders训练数据加载器None
val_dataloaders验证数据加载器None
datamoduleLightningDataModule实例None
scale_batch_size_kwargs传递给scale_batch_size()的参数None
lr_find_kwargs传递给lr_find()的参数None

使用注意

  • auto_lr_find标志当且仅当执行trainer.tune(model)代码时工作

其他注意点

  • .test()若非直接调用,不会运行。
  • .test()会自动load最优模型。
  • model.eval() and torch.no_grad() 在进行测试时会被自动调用。
  • 默认情况下,Trainer()运行于CPU上。

Trainer属性

### PyTorch-Lightning介绍 PyTorch-Lightning 是一种简化深度学习模型训练过程的库,旨在使研究人员能够专注于研究本身而不是工程细节[^1]。通过该工具包可以更高效地编写结构化清晰、易于维护的研究代码。 #### 特点概述 - **解耦业务逻辑**:将数据处理、模型定义以及训练循环分离出来; - **跨平台支持**:无论是单机多卡还是分布式环境都能轻松切换配置而无需修改大量底层实现; - **内置最佳实践**:提供了许多经过验证的最佳做法来加速实验迭代并提高性能; #### 安装方式 对于不同需求场景下有多种安装途径可供选择: - 使用 `pip` 工具直接从Python Package Index获取最新稳定版: ```bash pip install pytorch-lightning ``` - 如果偏好Anaconda/Miniconda生态系统,则可以通过Conda渠道获得相同效果: ```bash conda install pytorch-lightning -c conda-forge ``` 需要注意的是,在安装过程中要确保所选版本之间的兼容性问题,即保持PyTorch LightningPyTorch本身的版本匹配以避免潜在冲突[^3]。 --- 下面给出一段简单的例子展示如何利用此框架快速搭建一个基于MNIST手写数字识别任务的基础网络架构: ```python import torch from torch.nn import functional as F from torchvision.datasets import MNIST from torchvision.transforms import ToTensor, Normalize, Compose from torch.utils.data.dataloader import DataLoader import pytorch_lightning as pl class LitModel(pl.LightningModule): def __init__(self): super().__init__() self.l1 = torch.nn.Linear(28 * 28, 10) def forward(self, x): return torch.relu(self.l1(x.view(x.size(0), -1))) def training_step(self, batch, batch_nb): x, y = batch loss = F.cross_entropy(self(x), y) return {'loss': loss} def configure_optimizers(self): return torch.optim.Adam(self.parameters(), lr=0.02) if __name__ == '__main__': transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))]) dataset = MNIST('', train=True, download=True, transform=transform) loader = DataLoader(dataset, batch_size=32) model = LitModel() trainer = pl.Trainer(max_epochs=1) trainer.fit(model, loader) ```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值