Pytorch中的checkPoint: torch.utils.checkpoint.checkpoint

本文档概述了PyTorch 1.9中的checkpoint功能,介绍了如何利用它进行内存高效计算,以及如何处理随机数状态以确保确定性输出。重点讲解了checkpoint的工作原理、存储逻辑和配置选项,适用于模型训练优化。
摘要由CSDN通过智能技术生成

torch.utils.checkpoint.checkpoint笔记,内容来源于官方手册
仅作笔记只用,不完整之处请查阅官方手册
https://pytorch.org/docs/stable/checkpoint.html

checkpoint是通过在backward期间为每个checkpoint段重新运行forward-pass segment来实现的。
这可能会导致像 RNG 状态这样的持久状态比没有checkpoint的情况更先进。默认情况下,checkpoint包括处理 RNG 状态的逻辑,以便与非checkpoint传递相比,使用 RNG 的checkpoint传递(例如通过 dropout)具有确定性输出。
根据checkpoint操作的运行时间,存储和恢复 RNG 状态的逻辑可能会导致一定的性能下降。

注:RNG状态指随机数状态

如果不需要确定性输出,则为checkpoint或 checkpoint_sequential 设置preserve_rng_state=False ,以在每个checkpoint期间省略存储和恢复RNG 状态。

换言之,1.9版本中,checkpoint能处理随机数状态了!

存储逻辑将当前设备和所有 cuda Tensor 参数的设备的 RNG 状态保存并恢复到 run_fn。但是,逻辑无法预测用户是否会将张量移动到 run_fn 本身内的新设备。因此,如果您在 run_fn 中将张量移动到一个新设备确定性输出永远无法保证。

尽量使用原生的 torch.utils.checkpoint.checkpoint

checkpoint的工作原理是用计算换取内存。与存储整个计算图的所有中间激活用于反向计算不同,checkpoint部分不保存中间激活,而是在反向传递中重新计算它们.它可以应用于模型的任何部分。

具体来说,在前向传递中,函数将以 torch.no_grad() 方式运行,即不存储中间激活。相反,前向传递保存输入元组和函数参数。在向后传递中,检索保存的输入和函数,并再次对函数计算前向传递,现在跟踪中间激活,然后使用这些激活值计算梯度。

函数的输出可以包含非 Tensor 值,并且仅对 Tensor 值执行梯度记录。请注意,如果输出包含由张量组成的嵌套结构(例如:自定义对象、列表、字典等),则这些嵌套在自定义结构中的张量将不会被视为 autograd 的一部分。

  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值