pytorch-lighting

import pytorch-lighting as pl

在大多数研究项目中,通常可以归纳到以下关键部分:

  • 模型
  • 训练/验证/测试 数据
  • 优化器
  • 训练/验证/测试 计算

        以上包含的关键部分在 Lightning 中,是抽象为 LightningModule 类;而这个类与我们平时使用的 torch.nn.Module 是一样的(在原有代码中直接替换 Module 而不改其他代码也是可以的),但不同的是,Lightning 围绕 torch.nn.Module 做了很多功能性的补充,把上面4个关键部分都囊括了进来。这么做的意义在于:我们的关键部分都是围绕 我们的神经网络模型 来运行的,所以 Lightning 把这部分代码都集合在一个类里。所以我们接下来的介绍,都是围绕 LightningModule 类来展开。

一般的,训练开始之后执行的 默认顺序如下所示

准备工作:包括初始化 LightningModule,准备数据 和 配置优化器。

这部分代码 只执行一次

1. __init__()(初始化 LightningModule )
2. prepare_data()  (准备数据,包括下载数据、预处理等等)
3. configure_optimizers()(配置优化器)
  • 测试 “验证代码”。

提前来做的意义在于:不需要等待漫长的训练过程才发现验证代码有错。
这部分就是提前执行 “验证代码”,所以和下面的验证部分是一样的。

1. val_dataloader()
2. validation_step()
3. validation_epoch_end()
  • 开始加载dataloader,用来给训练加载数据
1. train_dataloader()
2. val_dataloader() (如果你定义了)
  • 下面部分就是循环训练了,_step() 的意思就是按batch来进行的部分;_epoch_end() 就是所有batch执行完后要进行的部分。
# 循环训练与验证
1. training_step()
2. validation_step()
3. validation_epoch_end()
  • 最后训练完了,就要进行测试,但测试部分需要手动调用 .test(),这是为了避免误操作。
# 测试(需要手动调用)
1. test_dataloader()
2. test_step()
3. test_epoch_end()

        在这里,我们很容易总结出,在训练部分,主要是三部分:_dataloader/_step/_epoch_end。Lightning把训练的三部分抽象成三个函数,而我们只需要“填鸭式”地补充这三部分,就可以完成模型训练部分代码的编写。

        为了让大家更清晰地了解这三部分的具体位置,下面用 PyTorch实现方式 来展现其位置。

for epoch in epochs:
    for batch in train_dataloader:
        # train_step
        # ....
        # train_step
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    
    for batch in val_dataloader:
        # validation_step
        # ....
        # validation_step
    
    # *_step_end
    # ....
    # *_step_end

使用Lightning的好处

        不需要写一大堆的 .cuda() 和 .to(device),Lightning会帮你自动处理。如果要新建一个tensor,可以使用type_as来使得新tensor处于相同的处理器上。

def training_step(self, batch, batch_idx):
    x, y = batch

    # 把z放在和x一样的处理器上
    z = sample_noise()
    z = z.type_as(x)
在这里,有个地方需要注意的是,不是所有的在LightningModule 的 tensor 都会被自动处理,而是只有从 Dataloader 里获取的 tensor 才会被自动处理,所以对于 transductive learning 的训练,最好自己写Dataloader的处理函数。
  • 工程代码参数化

        平时我们写模型训练的时候,这部分代码会不断重复,但又不得不做,不如说ealy stopping,精度的调整,显存内存之间的数据转移。这部分代码虽然不难,但减少这部分代码会使得 研究代码 更加清晰,整体也更加简洁。

        下面是简单的展示,表示使用 LightningModule 建立好模型后,如何进行训练。

model = LightningModuleClass()
trainer = pl.Trainer(gpus="0",  # 用来配置使用什么GPU
                     precision=32, # 用来配置使用什么精度,默认是32
                     max_epochs=200 # 迭代次数
                     )

trainer.fit(model)  # 开始训练
trainer.test()  # 训练完之后测试
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值