首先,我们需要了解如何使用Pytorch Lightning和Torchtext
Pytorch是我主要的深度学习框架。但是,有一部分我觉得可以改进。将Pytorch Lightning,Pytorch Ignite和fast.ai之间进行了比较。Ignite并非每个模型都具有标准接口,需要更多的代码行来训练模型,没有直接与Tensorboard集成,并且不像Lightning那样具有其他高性能计算。虽然fast.ai具有比其他两个更高的学习曲线,并且用例可能与Pytorch Lightning和Pytorch Ignite不同。
在本文中,将重点介绍使Pytorch Lightning提高生产率以及如何将Pytorch Lightning与Torchtext集成的一些功能。
为什么要使用Pytorch Lightning
减少样板。可以将训练定义为
from pytorch_lightning import Trainertrainer = Trainer( gpus=1, logger=[logger], max_epochs=5)trainer.fit(model)
进行常规训练。
- 没有更多的写循环。
- 无需将模型转换为GPU。
- 没有自定义打印功能,可避免损失。
在此屏幕截图中,我将logger变量定义为
from pytorch_lightning.loggers import TensorBoardLoggerlogger = TensorBoardLogger('tb_logs', name='my_model')
Pytorch Lightning将创建一个名为的日志目录,tb_logs并且您可以为Tensorboard引用该日志目录(如果您要从Jupyter笔记本电脑单独运行Tensorboard)。
tensorboard --logdir tb_logs/
组织代码
除了构造函数,forward您将能够定义更多的函数
- configure_optimiz