【PyTorch Ligntning】快速上手简明指南

目录

一、简介

二、安装 PyTorch Lightning

三、定义 LightningModule

3.1 SYSTEM VS MODEL 

3.2 FORWARD vs TRAINING_STEP

三、配置 Lightning Trainer

四、基本特性

4.1 Manual vs automatic optimization

4.1.1 自动优化 (Automatic optimization)

4.1.1 手动优化 (Manual optimization)

4.2 Predict or Deploy

4.2.1 选项一 —— 子模型 (Sub-models)

4.2.2 选项二 —— 前馈 (Forward)

4.2.2 选项三 —— 生产 (Production)

4.3 Using CPUs/GPUs/TPUs

4.4 Checkpoints

4.5 Data flow

4.6 Logging

4.7 Optional extensions

4.7.1 回调 (Callbacks)

4.7.2 LightningDataModules

4.8 Debugging

五、其他炫酷特性


相关文章

【PyTorch Lightning】简介

【PyTorch Ligntning】如何将 PyTorch 组织为 Lightning

【PyTorch Lightning】1.0 正式发布:从 0 到 1 


项目地址https://github.com/PyTorchLightning/pytorch-lightning


一、简介

本指南将展示如何分两步将 PyTorch 代码组织成 Lightning。

使用 PyTorch Lightning 组织代码,可以使代码:

  • 保留所有灵活性(这全是纯 PyTorch),但去除了大量样板(boilerplate)
  • 将研究代码与工程解耦,更具可读性
  • 更容易复现
  • 通过自动化大多数训练循环和棘手的工程设计,减少了出错的可能性
  • 可扩展到任何硬件而无需更改模型

二、安装 PyTorch Lightning

pip 安装:

pip install pytorch-lightning

或 conda 安装:

conda install pytorch-lightning -c conda-forge

或在 conda 虚拟环境下安装:

conda activate my_env
pip install pytorch-lightning

在新源文件中导入以下将用到的依赖:

import os
import torch
from torch import nn
import torch.nn.functional as F
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from torch.utils.data import random_split

三、定义 LightningModule

class LitAutoEncoder(pl.LightningModule):

    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(28*28, 64),
            nn.ReLU(),
            nn.Linear(64, 3)
        )
        self.decoder = nn.Sequential(
            nn.Linear(3, 64),
            nn.ReLU(),
            nn.Linear(64, 28*28)
        )

    def forward(self, x):
        # in lightning, forward defines the prediction/inference actions
        embedding = self.encoder(x)
        return embedding

    def training_step(self, batch, batch_idx):
        # training_step defined the train loop.
        # It is independent of forward
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        # Logging to TensorBoard by default
        self.log('train_loss', loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

3.1 SYSTEM VS MODEL 

注意,LightningModule 定义了一个系统 (system) 而不仅仅是一个模型 (model):

关于系统 (system) 的例子还有:

在内部,LightningModule 仍只是一个 torch.nn.Module,它将所有研究代码分组到一个文件中以使其自成一体:

  • The Train loop
  • The Validation loop
  • The Test loop
  • The Model or system of Models
  • The Optimizer

可以通过覆盖 Available Callback hooks 中找到的 20+ 个 hooks 中的任意一个,来自定义任何训练部分 (如反向传播):

class LitAutoEncoder(pl.LightningModule):

    def backward(self, loss, optimizer, optimizer_idx):
        loss.backward()

3.2 FORWARD vs TRAINING_STEP

在 Lightning 中,我们将训练与推理分开。training_step 定义了完整的训练循环。鼓励用户用 forward 定义推理行为。

例如,在这种情况下,可以定义自动编码器以充当嵌入提取器 (embedding extractor):

def forward(self, x):
    embeddings = self.encoder(x)
    return embeddings

当然,没有什么可以阻止你在 training_step 中使用 forward:

def training_step(self, batch, batch_idx):
    ...
    z = self(x)

这确实取决于个人的应用程序,但仍建议将两个意图分开。

  • 使用 forward 推理/预测
  • 使用 training_step 训练

更多细节在 Lig

  • 5
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值