Pytorch-Lightning中的训练器—Trainer
Trainer.__init__()
常用参数
参数名称 | 含义 | 默认值 | 接受类型 |
---|
callbacks | 添加回调函数或回调函数列表 | None (ModelCheckpoint 默认值) | Union[List[Callback], Callback, None] |
enable_checkpointing | 是否使用callbacks | True | bool |
enable_progress_bar | 是否显示进度条 | True | bool |
enable_model_summary | 是否打印模型摘要 | True | bool |
precision | 指定训练精度 | 32(full precision ) | Union[int, str] |
default_root_dir | 模型保存和日志记录默认根路径 | None (os.getcwd() ) | Optional[str] |
logger | 设置日志记录器(支持多个),若没设置logger的save_dir ,则使用default_root_dir | True (默认日志记录) | Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool] |
max_epochs | 最多训练轮数(指定为**-1可以设置为无限次**) | None (1000) | Optional[int] |
min_epochs | 最少训练轮数。当有Early Stop时使用 | None(1) | Optional[int] |
max_steps | 最大网络权重更新次数 | -1(禁用) | Optional[int] |
min_steps | 最少网络权重更新次数 | None (禁用) | Optional[int] |
weights_save_path | 权重保存路径(优先级高于default_root_dir ),ModelCheckpoint 未定义路径时将使用该路径 | None(default_root_dir ) | Optional[str] |
log_every_n_steps | 更新n次网络权重后记录一次日志 | 50 | int |
limit_train_batches limit_test_batches limit_val_batches limit_predict_batches | 使用训练/测试/验证/预测数据的百分比.如果数据过多,或正在调试可以使用。 | 1.0 | Union[int, float] (float = 比例, int = num_batches). |
fast_dev_run | 如果设定为true,会只执行一个batch的train, val 和 test,然后结束。仅用于debug | False | bool |
accumulate_grad_batches | 每k次batches累计一次梯度 | None (无梯度累计) | Union[int, Dict[int, int], None] |
check_val_every_n_epoch | 每n个train epoch执行一次验证 | 1 | int |
num_sanity_val_steps | 开始训练前加载n个验证数据进行测试,k=-1时加载所有验证数据 | 2 | int |
硬件加速相关选项
参数名称 | 含义 | 默认值 | 接受类型 |
---|
accelerator | 设置硬加类型 | auto | Union[str, Accelerator] |
devices | 使用多少设备,若为-1则使用全部可用设备 | auto | Union[List[int], str, int] |
num_nodes | 分布式环境下使用多少GPU | 1 | int |
额外的解释
- 这里
max_steps/min_steps
中的step
就是指的是优化器的step()
,优化器每step()
一次就会更新一次网络权重 - 梯度累加(Gradient Accumulation):受限于显存大小,一些训练任务只能使用较小的batch_size,但一般batch-size越大(一定范围内)模型收敛越稳定效果相对越好;梯度累加可以先累加多个batch的梯度再进行一次参数更新,相当于增大了batch_size。
Trainer.fit()
参数详解
ckpt_path参数详解(从之前的模型恢复训练)
使用该参数指定一个模型ckpt文件(需要保存整个模型,而不是仅仅保存模型权重),Trainer
将从ckpt文件的下一个epoch继续训练。
示范
net = MyNet(...)
trainer = pl.Trainer(...)
trainer.fit(net, train_iter, val_iter, ckpt_path='./ckpt/myresult.ckpt')
使用注意
- 请不要使用Trainer()中的resume_from_checkpoint参数,该参数未来将被丢弃,请使用Trainer.fit()的ckpt_path参数
Trainer.test()
和Trainer.validate()
参数详解
参数名称 | 含义 | 默认值 |
---|
model | LightningModule 实例 | None |
verbose | 是否打印测试结果 | True |
dataloaders | 测试数据加载器(可以使torch.utils.data.DataLoader ) | None |
ckpt_path | ckpt文件路径(从这里文件恢复训练) | None |
datamodule | LightningDataModule 实例 | None |
Returns:
测试/验证期间相关度量值的字典列表(列表长度等于测试/验证数据加载器个数),比如validation/test_step(), validation/test_epoch_end(),
中的回调钩子ckpt_path:
如果设置了该参数则会使用该ckpt文件中的权重,否则如果模型已经训练完毕则使用当前权重,其他情况如果配置了checkpoint callbacks
则加载该checkpoint callbacks
对应的最佳模型
Trainer.predict()
参数详解
ckpt_path:
使用该ckpt文件中的权重,如果为None
如果模型已经训练完毕则使用当前权重,其他情况如果配置了checkpoint callbacks
则加载该checkpoint callbacks
对应的最佳模型
使用注意
Trainer.tune()
功能解释
常用参数
参数名称 | 含义 | 默认值 |
---|
model | LightningModule 实例 | |
train_dataloaders | 训练数据加载器 | None |
val_dataloaders | 验证数据加载器 | None |
datamodule | LightningDataModule 实例 | None |
scale_batch_size_kwargs | 传递给scale_batch_size() 的参数 | None |
lr_find_kwargs | 传递给lr_find() 的参数 | None |
使用注意
auto_lr_find
标志当且仅当执行trainer.tune(model)
代码时工作
其他注意点
.test()
若非直接调用,不会运行。.test()
会自动load最优模型。model.eval()
and torch.no_grad()
在进行测试时会被自动调用。- 默认情况下,
Trainer()
运行于CPU上。
Trainer
属性