pytorch 交叉验证_SAST Weekly | 让Pytorch变得更简单

PyTorch Lightning 是一个组织 PyTorch 代码的库,旨在将科学代码与工程代码分离,简化训练过程。通过将加载数据、训练循环等工程细节交给 Trainer 处理,开发者可以更专注于模型实现。本文通过 MNIST 例子介绍了如何使用 LightningModule,展示其如何提升代码清晰度和减少重复工作,并提到 Lightning 支持 GPU 训练和各种可视化工具。
摘要由CSDN通过智能技术生成
5cc31290a45f57b53d204d67141d9770.png afc43084b8e677c67921d540bfa7f8d8.png 7e729daca824d23837b0139f5f387c5b.png

SAST Weekly 是由电子工程系学生科协推出的科技系列推送,内容涵盖信息领域技术科普、研究前沿热点介绍、科技新闻跟进探索等多个方面,帮助同学们增长姿势,开拓眼界,每周更新,欢迎关注!欢迎愿意分享知识的同学投稿至 eesast@mail.tsinghua.edu.cn , 期待你的作品!

e0d074588e64274df795f245c5ce4c4f.png a90d606035bc295e7f4dbe788307b967.png f28af1a453286ac7f84b67b28f42e71f.gif

相信已经有不少同学接触过Pytorch了(媒认课上还见到了九字班的同学),作为目前最流行的深度学习框架之一,Pytorch本身已经具备了入门简单,代码简洁等优点,它可以让你把精力专注在在实现自己的模型上,很适合学生、研究人员来使用。

尽管Pytorch中需要用户自己实现的功能已经非常少了,但是除了定义自己的模型之外,我们还需要在代码中加入导入、加载数据集,手写训练、交叉验证循环,保存、绘制loss曲线等等写起来重复度高并且还会让你的代码变得不那么优美的内容(主要是写起来麻烦)。那么,有没有一种——

f712c69a53f05afff7ca2a63b16ce17a.png

Pytorch Lightning可以帮助你把写Pytorch变得更加轻松愉快[小恐龙点赞] 

f28af1a453286ac7f84b67b28f42e71f.gif

官方文档是这样描述它的

Lightning is a way to organize your PyTorch code to decouple the science code from the engineering. It's more of a PyTorch style-guide than a framework.

简单来说,你只需要关注你的模型的实现,也就是Research Code,其他的Engineering code都交给Lightning的Trainer来实现

先用一张图来看看Pytorch Lightning能为你做什么

4d4d7e29d87db3e70e1363d61d74773a.gif

具体到代码,你可以省去图中蓝色的部分

7fdfe3b00406e818b5e93dfce55a0b51.png fa52f31ebf5269a98114c4aac0ea3d17.png

而原本的定义Pytorch模型的代码可以直接用于LightningModule。事实上,你只需要用另一种风格来重新组织你的Pytorch代码而不需要学习新的概念。

f28af1a453286ac7f84b67b28f42e71f.gif

接下来我们用一个MNIST手写数字识别的例子来体验下Lightning。

首先模型的父类要换成LightningModule

f2f1332b7d58f2541a0e36c0cec38733.png

__init__forward方法可以直接使用不需要修改。

c6907f8aa3b0180db5b0246b4219ca64.png

与Pytorch不同的是,你需要把加载数据以及创建DataLoader的代码写在模型中,需要定义如下几个方法

c4008e3b5049dbef335cb49f5e0554de.png

事实上我们只是把加载数据集的代码放到了prepare_data中,把创建dataloader的代码分别放到了train_dataloaderval_dataloader

接下来我们要把原本的training loop中最内层的代码写到training_stepvalidation_step

0cb87763bb675dba2f1a96641d59e449.png

training_stepvalidation_step中的返回值都是一个字典,其中training_step的返回值必须包含loss属性,log是你需要输出的数据的字典,而validation_step的返回值字典保存的是你想要在每次validation结束输出的数据,如下

e782bee8f0a4fdb6185cbb15dbdcc9c9.png

返回的字典中progress_bar也需要是一个字典,对应显示在进度条上的值(稍后展示),log同上作为输出的数据被保存下来

最后,我们还需要定义一个optimizer

2d4d2d98438b60acac2a9f5815bac3c0.png

这样一来一个LightningModule就完成了,而他的训练异常简单,你只需要

cd8ef6e4acff740497909664453e2f27.png

不需要手写training loop,也不需要自己加入进度条,一切都由Trainer帮你完成

2f1e933f416c523f0b96873991b75705.png

可以看到进度条右侧显示出了我们设置的validation的结果

如果你想使用GPU训练,也不必像原来一样一个个设置,只需要

48aaaee448a9c5866850d44540db26d6.png

至于我们在step中设置要输出的log,它们被按照tensorboard(可替换其他可视化工具)的格式保存在了./lightning_logs中,你可以直接使用tensorboard来查看我们保存的log的曲线

c6c9d1bcf70ee0f01645a93618d01705.png

如果你在模型中定义了test_dataloader,   test_steptest_epoch_end,你就可以直接进行test

483a75adc509719461639adc67a17409.png f28af1a453286ac7f84b67b28f42e71f.gif

以上简单地展示了Pytorch Lightning的使用,已经可以明显地看出他在pytorch上做的改进:你的代码条理更加清晰,你不需要自己完成诸如training loop的工程性强研究意义低的代码,你也不需要为可视化花心思。

除了改变了写代码的风格之外,其余没有与pytorch不同的地方,实际上你的代码中使用的依然是pytorch里的方法和函数,只是把他们交给了Lightning来组织、运行、输出

Lightning对RNN、GAN等也有着很好的支持,Trainer可设置的参数也十分丰富,在此不再作介绍,有兴趣的同学可以参考官方文档进行深入了解

参考资料:

    官方文档: 

      https://pytorch-lightning.readthedocs.io/en/latest/    

    From PyTorch to PyTorch Lightning — A gentle introduction: 

      https://towardsdatascience.com/from-pytorch-to-pytorch-lightning-a-gentle-introduction-b371b7caaf09

97f660819b70dc25340301b10eb99ac3.png

撰稿:李煜泽

审核:孙志尧

e87c918013a4f65e163d659df0ff8eed.png
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值