在Pytorch深度学习框架中,什么时候应该使用torch.no_grad()?

最近在看沐神课,复现代码时出现了一个错误:

Traceback (most recent call last):
  File "C:\Users\GCLuis\Desktop\dl\3.LinearNeuralNetworks\3-2.py", line 83, in <module>
    sgd([w, b], lr, real_batch_size)  # 使用参数的梯度更新参数
  File "C:\Users\GCLuis\Desktop\dl\3.LinearNeuralNetworks\3-2.py", line 67, in sgd
    param -= lr * param.grad / batch_size
RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.

原因在于在执行sgd优化函数的时候,没有在更新参数前执行:

with torch.no_grad():

所以来总结一下,究竟在什么时候应该执行先with torch.no_grad():

函数定义

在 PyTorch 中,当我们执行一些不需要梯度计算的操作时,我们可以通过将代码包裹在 torch.no_grad() 上下文管理器中来减少内存的消耗并提高代码执行效率。

具体来说,torch.no_grad() 是一个上下文管理器,它会在执行被包裹的代码块时关闭 PyTorch 张量的自动求导功能。这意味着,被包裹的代码块中的所有计算操作都不会被记录在 PyTorch 的计算图中,从而减少了计算图的大小,也避免了不必要的内存消耗。

使用 torch.no_grad() 上下文管理器的语法如下所示:

with torch.no_grad():
    # some code that doesn't require gradients

哪些情况会使用到该函数?

在以下情况下,通常建议使用 torch.no_grad()

评估模型:当你在评估模型(例如,计算验证集上的性能)时,不需要计算梯度。关闭梯度计算可以节省内存并提高计算速度。

model.eval()
with torch.no_grad():
    for inputs, targets in validation_loader:
        outputs = model(inputs)
        # 计算指标,如准确率、损失等

更新参数时:当你在优化算法中更新模型参数时,不需要在参数更新步骤中计算梯度。在更新参数时使用 torch.no_grad() 可以防止出现错误,并确保计算过程正确。

def sgd(params, lr, batch_size):
    with torch.no_grad():
        for param in params:
            param -= lr * param.grad / batch_size
            param.grad.zero_()

梯度不需要的计算:在不需要梯度计算的任何计算过程中,例如当你需要从模型输出中获取预测类别时,可以使用 torch.no_grad() 关闭梯度计算以节省内存和计算资源。

with torch.no_grad():
    inputs = torch.randn(1, 3, 224, 224)
    outputs = model(inputs)
    _, predicted_class = torch.max(outputs, 1)

总之,在你确定不需要计算梯度的情况下,使用 torch.no_grad() 可以节省计算资源并提高运行效率。

如果在这些情况下没有使用torch.no_grad() 会导致哪些错误?

  1. 额外的内存消耗:计算和存储梯度需要额外的内存。在不需要梯度的情况下仍然计算梯度会导致不必要的内存消耗。在内存有限的设备上,如GPU,这可能导致内存不足而无法执行计算。

  2. 降低计算速度:计算梯度会增加计算负担。如果在不需要梯度的情况下仍然计算梯度,会降低计算速度,从而增加模型评估和推理的时间。

  3. 可能的计算错误:在某些情况下,如在优化算法中更新参数时,如果不使用 torch.no_grad(),可能导致错误。例如,如果你在需要梯度的张量上执行原地操作,PyTorch会抛出 RuntimeError,因为这样的操作会破坏计算图和梯度计算。

虽然在某些情况下忘记使用 torch.no_grad() 可能不会立即导致错误,但为了确保计算效率和正确性,建议在不需要梯度计算的情况下使用 torch.no_grad()

  • 10
    点赞
  • 23
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
1.项目代码均经过功能验证ok,确保稳定可靠运行。欢迎下载体验!下载完使用问题请私信沟通。 2.主要针对各个计算机相关专业,包括计算机科学、信息安全、数据科学与大数据技术、人工智能、通信、物联网等领域的在校学生、专业教师、企业员工。 3.项目具有丰富的拓展空间,不仅可作为入门进阶,也可直接作为毕设、课程设计、大作业、初期项目立项演示等用途。 4.当然也鼓励大家基于此进行二次开发。在使用过程,如有问题或建议,请及时沟通。 5.期待你能在项目找到乐趣和灵感,也欢迎你的分享和反馈! 【资源说明】 基于Pytorch框架的TPLinker_plus文命名实体识别python源码+使用说明+模型+数据集.zip 还是和之前其它几种实体识别方式相同的代码模板,这里稍微做了一些修改,主要是在数据加载方面。之前都是预先处理好所有需要的数据保存好,由于tplinker需要更多内存,这里使用DataLoader的collate_fn对每一批的数据分别进行操作,可以大大减少内存的使用。模型主要是来自这里:[tplinker_plus](https://github.com/Tongjilibo/bert4torch/blob/master/examples/sequence_labeling/task_sequence_labeling_ner_tplinker_plus.py),需要额外了解的知识有:[基于Conditional Layer Normalization的条件文本生成 - 科学空间|Scientific Spaces](https://spaces.ac.cn/archives/7124)和[将“softmax+交叉熵”推广到多标签分类问题 - 科学空间|Scientific Spaces](https://www.spaces.ac.cn/archives/7359)。实现运行步骤如下: - 1、在raw_data下新建一个process.py将数据处理成mid_data下的格式。 - 2、修改部分参数运行main.py,以进行训练、验证、测试和预测。 模型及数据下载地址:链接:https://pan.baidu.com/s/1B-e-MV1lOMQj2ur5MADRww?pwd=he3e 提取码:he3e # 依赖 ``` pytorch==1.6.0 tensorboasX seqeval pytorch-crf==0.7.2 transformers==4.4.0 ``` # 运行 在16GB的显存下都只能以batch_size=2进行运行。。。 ```python python main.py \ --bert_dir="model_hub/chinese-bert-wwm-ext/" \ --data_dir="./data/cner/" \ --log_dir="./logs/" \ --output_dir="./checkpoints/" \ --num_tags=8 \ --seed=123 \ --gpu_ids="0" \ --max_seq_len=150 \ --lr=3e-5 \ --other_lr=3e-4 \ --train_batch_size=2 \ --train_epochs=1 \ --eval_batch_size=8 \ --max_grad_norm=1 \ --warmup_proportion=0.1 \ --adam_epsilon=1e-8 \ --weight_decay=0.01 \ --dropout_prob=0.3 \ ``` ### 结果 ```python precision:0.8806 recall:0.8999 micro_f1:0.8901 precision recall f1-score support TITLE 0.87 0.88 0.87 767 RACE 0.88 0.93 0.90 15 CONT 1.00 1.00 1.00 33 ORG 0.89 0.90 0.89 543 NAME 0.99 1.00 1.00 110 EDU 0.82 0.94 0.88 109 PRO 0.67 0.95 0.78 19

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

鬼才Luis

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值