目录
4.1 Manual vs automatic optimization
4.1.1 自动优化 (Automatic optimization)
4.1.1 手动优化 (Manual optimization)
相关文章
【PyTorch Ligntning】如何将 PyTorch 组织为 Lightning
【PyTorch Lightning】1.0 正式发布:从 0 到 1
项目地址:https://github.com/PyTorchLightning/pytorch-lightning
一、简介
本指南将展示如何分两步将 PyTorch 代码组织成 Lightning。
使用 PyTorch Lightning 组织代码,可以使代码:
- 保留所有灵活性(这全是纯 PyTorch),但去除了大量样板(boilerplate)
- 将研究代码与工程解耦,更具可读性
- 更容易复现
- 通过自动化大多数训练循环和棘手的工程设计,减少了出错的可能性
- 可扩展到任何硬件而无需更改模型
二、安装 PyTorch Lightning
pip 安装:
pip install pytorch-lightning
或 conda 安装:
conda install pytorch-lightning -c conda-forge
或在 conda 虚拟环境下安装:
conda activate my_env
pip install pytorch-lightning
在新源文件中导入以下将用到的依赖:
import os
import torch
from torch import nn
import torch.nn.functional as F
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from torch.utils.data import random_split
三、定义 LightningModule
class LitAutoEncoder(pl.LightningModule):
def __init__(self):
super().__init__()
self.encoder = nn.Sequential(
nn.Linear(28*28, 64),
nn.ReLU(),
nn.Linear(64, 3)
)
self.decoder = nn.Sequential(
nn.Linear(3, 64),
nn.ReLU(),
nn.Linear(64, 28*28)
)
def forward(self, x):
# in lightning, forward defines the prediction/inference actions
embedding = self.encoder(x)
return embedding
def training_step(self, batch, batch_idx):
# training_step defined the train loop.
# It is independent of forward
x, y = batch
x = x.view(x.size(0), -1)
z = self.encoder(x)
x_hat = self.decoder(z)
loss = F.mse_loss(x_hat, x)
# Logging to TensorBoard by default
self.log('train_loss', loss)
return loss
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
return optimizer
3.1 SYSTEM VS MODEL
注意,LightningModule 定义了一个系统 (system) 而不仅仅是一个模型 (model):
关于系统 (system) 的例子还有:
- Autoencoder
- BERT
- DQN
- GAN
- Image classifier
- Seq2seq
- SimCLR
- VAE
在内部,LightningModule 仍只是一个 torch.nn.Module,它将所有研究代码分组到一个文件中以使其自成一体:
- The Train loop
- The Validation loop
- The Test loop
- The Model or system of Models
- The Optimizer
可以通过覆盖 Available Callback hooks 中找到的 20+ 个 hooks 中的任意一个,来自定义任何训练部分 (如反向传播):
class LitAutoEncoder(pl.LightningModule):
def backward(self, loss, optimizer, optimizer_idx):
loss.backward()
3.2 FORWARD vs TRAINING_STEP
在 Lightning 中,我们将训练与推理分开。training_step 定义了完整的训练循环。鼓励用户用 forward 定义推理行为。
例如,在这种情况下,可以定义自动编码器以充当嵌入提取器 (embedding extractor):
def forward(self, x):
embeddings = self.encoder(x)
return embeddings
当然,没有什么可以阻止你在 training_step 中使用 forward:
def training_step(self, batch, batch_idx):
...
z = self(x)
这确实取决于个人的应用程序,但仍建议将两个意图分开。
- 使用 forward 推理/预测
- 使用 training_step 训练
更多细节在 Lig