深入理解 PyTorch 中的自动微分机制与 `.detach()` 用法全解析

作者:Accelemate
发布时间:2025年6月26日
本文摘要:本文将从零开始,系统性地讲解 PyTorch 中的计算图、反向传播、with torch.no_grad().detach() 等核心机制,结合实践场景如可视化中间层特征图、GAN 模型中对生成器的冻结操作等内容,帮助你在实际开发中灵活、正确地使用自动微分特性。


一、自动微分基础概念

1.1 什么是自动微分(Autograd)?

PyTorch 的自动微分机制(Autograd)是其最核心的特性之一。

简而言之:

你只需要写出前向传播的过程(即如何从输入计算出输出),PyTorch 就会自动帮你记录每一步操作,构建出一个“计算图”,并且在调用 .backward() 时,沿着这个图,自动地反向传播误差,计算所有需要梯度的参数的梯度。

1.2 张量的 requires_grad 属性

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
  • 这告诉 PyTorch:我要追踪这个张量上的所有操作,未来可能要用来求导。
  • 默认是 False,只有当 requires_grad=True 时,计算图才会追踪这个张量。

1.3 从 .backward() 到参数更新

在模型训练过程中,每次前向传播得到 loss(损失值)后,调用:

loss.backward()  # 反向传播,计算所有参数的梯度
optimizer.step()  # 更新模型参数
optimizer.zero_grad()  # 梯度清零,为下次迭代做准备

.backward() 是我们所说的“反向传播”核心函数,它的入口是 Loss。

你可能会问:我们不是要优化权重吗?为啥对 Loss 做反向传播?

✅ 答案是:

  • 因为 PyTorch 构建的是从输入到 Loss 的整条计算路径,调用 .backward() 就等于从 Loss 开始,自动向后走回网络的各层参数,使用链式法则一步步计算每个参数对 Loss 的梯度。

这就是自动微分的强大之处。


二、计算图的生成与管理

2.1 什么是计算图?

  • 它不是张量,不是图片,而是一个动态图结构
  • 每当你对带 requires_grad=True 的张量进行运算时,PyTorch 就会动态地记录这些操作,并构成一张有向无环图(DAG),这个图就叫做计算图
x = torch.tensor([2.0], requires_grad=True)
y = x ** 2 + 3 * x  # 此时构建计算图
  • 此时,PyTorch 记录了 x → x**2 → +3x → y 的整个过程。
  • 这图中每个节点是操作(如加法、乘法),边是数据流。

2.2 什么时候构建计算图?

✅ 不是创建张量时,而是在执行运算时。

x = torch.tensor([2.0], requires_grad=True)  # 无计算图,仅是张量
z = x ** 2 + 1  # 此刻构建计算图

2.3 .grad_fn 与叶子节点

  • 所有由运算生成的张量,都会拥有 .grad_fn 属性,标志其由哪种操作生成。
  • 而你自己用 tensor(...).detach() 创建的张量没有 .grad_fn,这些是叶子节点(计算图的起点)

三、深入理解 .backward()

3.1 基本使用

y = (x ** 2).sum()
y.backward()  # 自动计算 x 的梯度
print(x.grad)  # tensor([2.0, 4.0, 6.0])

3.2 .backward() 的限制

  • 要求输出张量是标量(0维)
  • 如果是向量,需要显式传入梯度:
y.backward(gradient=torch.tensor([1.0, 1.0, 1.0]))

3.3 多次调用 .backward() 的陷阱

loss.backward(retain_graph=True)
  • 默认 .backward() 会释放计算图;多次反向需要设定 retain_graph=True

四、with torch.no_grad():推理与节省资源

4.1 推理阶段的常规做法

model.eval()
with torch.no_grad():
    output = model(input)

4.2 为什么使用 no_grad()

即使你不写 .backward(),PyTorch 仍会构建计算图,记录中间操作,浪费内存。

✅ 使用 no_grad() 可以:

  • 禁用计算图构建;
  • 加快运行速度;
  • 降低显存占用;
  • 防止误用 .grad

所以即便不训练,也应养成用 no_grad() 的好习惯。


五、.detach():从计算图中摘除变量

5.1 基本用法

x = torch.tensor([2.0], requires_grad=True)
y = x * 3
y_detached = y.detach()
  • y_detachedy 数据相同;
  • 但它不再属于计算图y_detached.grad_fn = None

5.2 是深拷贝吗?

❌ 不是。是“浅拷贝 + 剪断计算图”。

  • .detach() 返回的张量与原张量共享内存,但不再追踪其历史。

5.3 典型使用场景

(1)中间层输出特征图可视化
features = model.conv1(input_image)  # 中间层输出
features = features.detach()         # 切断反向传播路径

为什么要可视化?看模型是否提取了我们期望的特征,比如边缘、纹理等:

import matplotlib.pyplot as plt
for i in range(6):
    plt.imshow(features[0, i].cpu(), cmap='gray')
    plt.title(f'Feature Map {i}')
    plt.show()
(2)冻结生成器(如 GAN)
fake_imgs = generator(z).detach()
d_real = discriminator(real_imgs)
d_fake = discriminator(fake_imgs)
  • 防止在训练判别器时 G 的参数也被更新。
(3)防止梯度传播影响损失
x = x.detach() + torch.randn_like(x) * noise_level
  • 假如你想给张量加点噪声,但不想这部分噪声参与反向传播。

六、GAN 场景中的 .detach() 冻结操作

6.1 什么是 GAN?

生成对抗网络(GAN) = 生成器 + 判别器:

  • G:从噪声中生成图像,希望骗过 D。
  • D:判别图像是真是假。

6.2 为什么要 detach() 生成器?

# 1. 训练 D
fake_imgs = generator(z).detach()
loss_D = loss_fn(discriminator(real_imgs), 1) + \
         loss_fn(discriminator(fake_imgs), 0)
loss_D.backward()

# 2. 训练 G
fake_imgs = generator(z)
loss_G = loss_fn(discriminator(fake_imgs), 1)
loss_G.backward()
  • 训练 D 时,G 的输出不应参与反向传播(否则 G 也更新了)。
  • .detach() 手动断开反向路径。

✅ 拆成两步训练,才能实现 G 和 D 的“对抗式交替训练”。


七、几个常见思维误区解答

7.1 为什么 .backward() 的目标是 Loss?

因为 Loss 是模型输出与目标之间差距的标量度量。
从 Loss 出发反推,才能知道“哪里错了”,以修正前面的参数。

7.2 计算图是不是在创建张量那一刻就生成了?

不是。计算图在你使用 requires_grad=True 的张量进行运算时才开始构建。

7.3 推理时既然不写 .backward(),为何还用 no_grad()

虽然你不写 .backward(),但不加 no_grad(),PyTorch 仍然默默构建计算图,浪费内存。
no_grad() 是告诉系统:“我只要 forward,不要 graph 和 grad。”

7.4 .grad_fn 是不是就是 .grad

不是。

  • .grad_fn 是这个张量是由哪个操作生成的(即图中是哪一个函数节点);
  • .grad 是在反向传播时被填充的梯度值。

八、总结对照表(建议收藏)

概念/方法含义用法应用场景
requires_grad=True开启梯度追踪定义张量模型参数
.backward()启动反向传播loss.backward()从 loss 回溯计算各层梯度
torch.no_grad()关闭计算图追踪with torch.no_grad():推理、冻结部分模块
.detach()手动剪断计算图y = x.detach()可视化中间层、GAN 冻结 G
.grad_fn当前张量由哪个操作生成y.grad_fn理解计算图结构
.grad张量的梯度值x.grad反向传播后查看参数梯度

九、结语

理解 PyTorch 的自动微分机制不仅仅是学会调用 .backward(),更重要的是理解它背后是如何动态构建计算图、如何节省资源、如何控制信息流动的。尤其是在复杂任务如 GAN、NLP 微调中,合理使用 .detach()no_grad() 是你进阶为深度学习开发者的标志之一。

希望这篇文章帮你从多个维度、多个场景完全吃透这些关键机制。如果觉得有帮助,欢迎点赞、收藏。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值