PyTorch 梯度管理完全指南:torch.no_grad() 与 requires_grad 的使用解析

本文主要是学习B站五道口纳什博主[动手写 bert 系列] torch.no_grad() vs. param.requires_grad == False专栏第四个视频的知识归纳总结。原问题表述andup主源代码

在深度学习的实践中,尤其是在处理像 BERT 这样的大型预训练模型时,梯度管理是一个既关键又容易混淆的概念。你是否曾经在训练模型时遇到 GPU 内存不足(OOM)的问题?或者在使用预训练模型进行特征提取时,不确定应该如何正确配置?本文将带你深入理解 PyTorch 中两个核心的梯度管理工具:torch.no_grad() 和 param.requires_grad,并通过实际代码示例展示它们的不同组合如何影响内存使用和训练效果。

一、基本概念解析

1.1 torch.no_grad():运行时上下文管理器

比喻:想象你在观看一位顶级厨师做菜。torch.no_grad() 就像是只看不记笔记——你观察整个过程,但不记录每个步骤的细节。

技术本质

  • 是一个上下文管理器(用 with 语句包裹)

  • 在代码块执行期间临时禁用梯度计算

  • 不改变参数的 requires_grad 属性

  • 主要用于推理(inference)和特征提取

# 示例:推理时不计算梯度
with torch.no_grad():
    predictions = model(input_data)  # 不保存计算图,节省内存

1.2 param.requires_grad:参数级别的梯度控制

比喻:这像是告诉厨师是否可以改进菜谱。如果设置为 False,就是让厨师严格按照现有菜谱操作,不能创新。

技术本质

  • 是张量(Tensor)的一个布尔属性

  • 永久性设置(除非显式修改)

  • 直接影响优化器是否会更新该参数

  • 用于冻结/解冻模型的特定部分

# 示例:冻结模型的前几层
for param in model.layer1.parameters():
    param.requires_grad = False  # 这些参数不会更新

二、四种种组合分析

让我们通过一个 BERT 模型的例子,探索四种不同组合的效果。假设我们有一个 BERT 模型(约 1.1 亿参数)加上一个简单的分类头(约 3 百万参数)。

组合对比表

组合torch.no_gradrequires_grad可学习参数内存占用适合场景
a✅ 使用❌ False3M特征提取
b✅ 使用✅ True112M矛盾配置
c❌ 不用❌ False3M分析BERT输出
d❌ 不用✅ True112M极大完整微调

详细代码演示

import torch
from transformers import BertModel
from torch import nn

# 加载预训练BERT
bert = BertModel.from_pretrained('bert-base-uncased')

def count_trainable_params(model):
    """统计可训练参数数量"""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"默认BERT可训练参数: {count_trainable_params(bert):,}")
# 输出: 默认BERT可训练参数: 109,482,240
组合a:完美冻结配置
# 冻结BERT参数
for param in bert.parameters():
    param.requires_grad = False

class MyModelA(nn.Module):
    def __init__(self, bert):
        super().__init__()
        self.bert = bert
        self.classifier = nn.Linear(768, 2)  # 3M参数
        
    def forward(self, text):
        with torch.no_grad():  # 不记录计算图
            features = self.bert(text)[0]  # 仅特征提取
        output = self.classifier(features[:, 0, :])
        return output

model_a = MyModelA(bert)
print(f"模型a可训练参数: {count_trainable_params(model_a):,}")
# 输出: 模型a可训练参数: 3,000,000(仅分类头)
组合d:完整微调(最容易OOM)
class MyModelD(nn.Module):
    def __init__(self, bert):
        super().__init__()
        self.bert = bert  # 参数保持requires_grad=True
        self.classifier = nn.Linear(768, 2)
        
    def forward(self, text):
        # 无torch.no_grad(),保存完整计算图
        features = self.bert(text)[0]  # 这会保存每一层的激活值!
        output = self.classifier(features[:, 0, :])
        return output

model_d = MyModelD(bert)
print(f"模型d可训练参数: {count_trainable_params(model_d):,}")
# 输出: 模型d可训练参数: 112,482,240(全部参数)

三、为什么同时启用会导致OOM?

理解这个问题的关键在于区分参数存储激活值存储

3.1 参数 vs 激活值

  • 参数:模型需要学习的权重(BERT约1.1亿个)

  • 激活值:前向传播中每一层的中间计算结果

3.2 内存占用分析

# 组合b(有no_grad)的内存占用
# 只存储:输入 + 最终输出 ≈ 较小

# 组合d(无no_grad)的内存占用
# 需要存储:输入 + 每一层的激活值 + 梯度 + 最终输出
# BERT有12层transformer,每层都要保存激活值用于反向传播

计算示例
假设 batch_size=32,序列长度=128,隐藏维度=768:

单层激活值大小 ≈ 32 × 128 × 768 × 4字节 ≈ 12.5MB
12层激活值 ≈ 12 × 12.5MB ≈ 150MB
梯度存储 ≈ 参数数量 × 4字节 ≈ 440MB
其他中间变量 ≈ 100MB
总计 ≈ 690MB

这只是一个保守估计,实际可能更大,很容易超过常见GPU的显存(如11GB的RTX 2080 Ti)。

四、实际应用建议

4.1 场景一:仅特征提取(推荐)

# 冻结参数 + 使用no_grad = 最优解
model.eval()  # 设置为评估模式
with torch.no_grad():
    features = bert(input_ids)
# 后续处理...

4.2 场景二:完整微调(需要足够显存)

# 确保有足够GPU内存
model.train()  # 设置为训练模式
# 不使用no_grad,让所有参数可训练
outputs = bert(input_ids)
loss = compute_loss(outputs)
loss.backward()  # 需要完整的计算图
optimizer.step()

4.3 场景三:内存有限时的微调技巧

技巧1:梯度累积

# 模拟大batch_size,但内存占用小
accumulation_steps = 4
optimizer.zero_grad()

for i, batch in enumerate(data_loader):
    outputs = model(batch)
    loss = loss_fn(outputs)
    loss = loss / accumulation_steps  # 缩放损失
    loss.backward()  # 累积梯度
    
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()  # 更新参数
        optimizer.zero_grad()  # 清空梯度

技巧2:混合精度训练

from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()

for batch in data_loader:
    optimizer.zero_grad()
    
    with autocast():  # 自动使用半精度
        outputs = model(batch)
        loss = loss_fn(outputs)
    
    scaler.scale(loss).backward()  # 缩放损失
    scaler.step(optimizer)  # 更新参数
    scaler.update()  # 更新缩放器

技巧3:仅微调部分层

# 只解冻最后几层
layers_to_unfreeze = ['layer.10', 'layer.11', 'pooler']

for name, param in bert.named_parameters():
    # 只有特定层的参数需要梯度
    if any(layer in name for layer in layers_to_unfreeze):
        param.requires_grad = True
    else:
        param.requires_grad = False

五、常见误区与解答

Q1:为什么组合b(有no_grad但requires_grad=True)也能运行?

A:因为torch.no_grad()阻止了计算图的创建,所以即使参数标记为可训练,实际上也没有梯度信息用于更新。这是一个矛盾但技术上可行的配置。

Q2:我应该总是同时使用两者吗?

A:不一定,取决于你的目标:

  • 训练时:通常不用torch.no_grad(),但可能设置部分参数requires_grad=False

  • 推理时:两者都用最安全、最节省内存

Q3:model.eval()和torch.no_grad()有什么区别?

A

  • model.eval():改变模型行为(如关闭Dropout、BatchNorm使用移动统计量)

  • torch.no_grad():改变梯度计算行为(不保存计算图)

  • 通常推理时两者一起使用

# 正确推理做法
model.eval()  # 改变模型行为
with torch.no_grad():  # 禁用梯度计算
    output = model(input_data)

Q4:如何知道我的模型是否被正确冻结?

def check_frozen_status(model):
    for name, param in model.named_parameters():
        print(f"{name}: requires_grad={param.requires_grad}")

六、总结

掌握torch.no_grad()param.requires_grad的区别是高效使用PyTorch的关键:

  1. torch.no_grad()控制的是计算图的创建,影响运行时内存

  2. param.requires_grad控制的是参数的可训练性,影响模型更新

  3. 两者配合使用可以实现灵活的内存管理和训练策略

记住这个简单规则:

  • 要训练:不用torch.no_grad(),设置requires_grad=True

  • 不训练:使用torch.no_grad(),设置requires_grad=False

通过合理使用这两个工具,你可以在有限的硬件资源下有效地利用大型预训练模型,无论是进行特征提取、微调还是其他下游任务。希望这篇博客能帮助你更好地理解PyTorch中的梯度管理。在实际应用中,根据你的具体需求和硬件条件选择合适的配置,才能最大化模型的性能和效率。

“路虽远,行则将至;事虽难,做则必成。”我们一起在学习的路上继续前行! 💪

评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值