参考文档:LightningModule — PyTorch Lightning 1.7.5 documentation (pytorch-lightning.readthedocs.io)
pytorch_lightning中的训练器Trainer:
import pytorch_lightning as pl
class model(pl.LightningModule):
pass
model1=model(参数2)
trainer = pl.Trainer(参数1)
trainer.fit(model1)
1.参数1详解:
参数名称 | 意义 | 默认值 |
max-epochs | 最多训练轮数 | |
callbacks | 添加回调函数或回调函数列表 | |
gpus | 使用的gpu数量 | |
accumulate_grad_batches | 每k次batches累计一次梯度 | |
logger | 设置日志记录器(支持多个) | |
resume_from_checkpoint | ||
gradient_clip_val | ||
check_val_every_n_epoch | 每n个train epoch执行一次验证 | |
num_sanity_val_steps | 开始训练前加载n个验证数据进行测试, k=-1时加载所有验证数据 | |
log_every_n_steps | 更新n次网络权重后记录一次日志 | |
flush_logs_every_n_steps | ||
limit_train_batches | 使用训练数据的百分比.支持0到1的浮点数和整数,比如0.1代表每个epoch只跑十分之一的数据 | 支持0到1的浮点数和整数 |
limit_val_batches | 使用验证数据的百分比,10代表每个epoch只跑10个batches | 支持0到1的浮点数和整数 |
limit_test_batches | 使用测试数据的百分比. | 支持0到1的浮点数和整数 |
2.lightningmodule方法详解:
也就是定义model类是会定义哪些函数
3.trainer.fit参数详解
Trainer.fit(model, train_dataloaders=None,
val_dataloaders=None, datamodule=None,
ckpt_path=None)
其中model为实例化的pl.LightningModule