作者:Accelemate
发布时间:2025年6月26日
本文摘要:本文将从零开始,系统性地讲解 PyTorch 中的计算图、反向传播、with torch.no_grad()
、.detach()
等核心机制,结合实践场景如可视化中间层特征图、GAN 模型中对生成器的冻结操作等内容,帮助你在实际开发中灵活、正确地使用自动微分特性。
一、自动微分基础概念
1.1 什么是自动微分(Autograd)?
PyTorch 的自动微分机制(Autograd)是其最核心的特性之一。
简而言之:
你只需要写出前向传播的过程(即如何从输入计算出输出),PyTorch 就会自动帮你记录每一步操作,构建出一个“计算图”,并且在调用
.backward()
时,沿着这个图,自动地反向传播误差,计算所有需要梯度的参数的梯度。
1.2 张量的 requires_grad
属性
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
- 这告诉 PyTorch:我要追踪这个张量上的所有操作,未来可能要用来求导。
- 默认是 False,只有当
requires_grad=True
时,计算图才会追踪这个张量。
1.3 从 .backward()
到参数更新
在模型训练过程中,每次前向传播得到 loss(损失值)后,调用:
loss.backward() # 反向传播,计算所有参数的梯度
optimizer.step() # 更新模型参数
optimizer.zero_grad() # 梯度清零,为下次迭代做准备
.backward()
是我们所说的“反向传播”核心函数,它的入口是 Loss。
你可能会问:我们不是要优化权重吗?为啥对 Loss 做反向传播?
✅ 答案是:
- 因为 PyTorch 构建的是从输入到 Loss 的整条计算路径,调用
.backward()
就等于从 Loss 开始,自动向后走回网络的各层参数,使用链式法则一步步计算每个参数对 Loss 的梯度。
这就是自动微分的强大之处。
二、计算图的生成与管理
2.1 什么是计算图?
- 它不是张量,不是图片,而是一个动态图结构。
- 每当你对带
requires_grad=True
的张量进行运算时,PyTorch 就会动态地记录这些操作,并构成一张有向无环图(DAG),这个图就叫做计算图。
x = torch.tensor([2.0], requires_grad=True)
y = x ** 2 + 3 * x # 此时构建计算图
- 此时,PyTorch 记录了
x → x**2 → +3x → y
的整个过程。 - 这图中每个节点是操作(如加法、乘法),边是数据流。
2.2 什么时候构建计算图?
✅ 不是创建张量时,而是在执行运算时。
x = torch.tensor([2.0], requires_grad=True) # 无计算图,仅是张量
z = x ** 2 + 1 # 此刻构建计算图
2.3 .grad_fn
与叶子节点
- 所有由运算生成的张量,都会拥有
.grad_fn
属性,标志其由哪种操作生成。 - 而你自己用
tensor(...)
或.detach()
创建的张量没有.grad_fn
,这些是叶子节点(计算图的起点)。
三、深入理解 .backward()
3.1 基本使用
y = (x ** 2).sum()
y.backward() # 自动计算 x 的梯度
print(x.grad) # tensor([2.0, 4.0, 6.0])
3.2 .backward()
的限制
- 要求输出张量是标量(0维)。
- 如果是向量,需要显式传入梯度:
y.backward(gradient=torch.tensor([1.0, 1.0, 1.0]))
3.3 多次调用 .backward()
的陷阱
loss.backward(retain_graph=True)
- 默认
.backward()
会释放计算图;多次反向需要设定retain_graph=True
。
四、with torch.no_grad()
:推理与节省资源
4.1 推理阶段的常规做法
model.eval()
with torch.no_grad():
output = model(input)
4.2 为什么使用 no_grad()
?
即使你不写 .backward()
,PyTorch 仍会构建计算图,记录中间操作,浪费内存。
✅ 使用 no_grad()
可以:
- 禁用计算图构建;
- 加快运行速度;
- 降低显存占用;
- 防止误用
.grad
。
所以即便不训练,也应养成用 no_grad()
的好习惯。
五、.detach()
:从计算图中摘除变量
5.1 基本用法
x = torch.tensor([2.0], requires_grad=True)
y = x * 3
y_detached = y.detach()
y_detached
与y
数据相同;- 但它不再属于计算图,
y_detached.grad_fn = None
。
5.2 是深拷贝吗?
❌ 不是。是“浅拷贝 + 剪断计算图”。
.detach()
返回的张量与原张量共享内存,但不再追踪其历史。
5.3 典型使用场景
(1)中间层输出特征图可视化
features = model.conv1(input_image) # 中间层输出
features = features.detach() # 切断反向传播路径
为什么要可视化?看模型是否提取了我们期望的特征,比如边缘、纹理等:
import matplotlib.pyplot as plt
for i in range(6):
plt.imshow(features[0, i].cpu(), cmap='gray')
plt.title(f'Feature Map {i}')
plt.show()
(2)冻结生成器(如 GAN)
fake_imgs = generator(z).detach()
d_real = discriminator(real_imgs)
d_fake = discriminator(fake_imgs)
- 防止在训练判别器时 G 的参数也被更新。
(3)防止梯度传播影响损失
x = x.detach() + torch.randn_like(x) * noise_level
- 假如你想给张量加点噪声,但不想这部分噪声参与反向传播。
六、GAN 场景中的 .detach()
冻结操作
6.1 什么是 GAN?
生成对抗网络(GAN) = 生成器 + 判别器:
- G:从噪声中生成图像,希望骗过 D。
- D:判别图像是真是假。
6.2 为什么要 detach() 生成器?
# 1. 训练 D
fake_imgs = generator(z).detach()
loss_D = loss_fn(discriminator(real_imgs), 1) + \
loss_fn(discriminator(fake_imgs), 0)
loss_D.backward()
# 2. 训练 G
fake_imgs = generator(z)
loss_G = loss_fn(discriminator(fake_imgs), 1)
loss_G.backward()
- 训练 D 时,G 的输出不应参与反向传播(否则 G 也更新了)。
- 用
.detach()
手动断开反向路径。
✅ 拆成两步训练,才能实现 G 和 D 的“对抗式交替训练”。
七、几个常见思维误区解答
7.1 为什么 .backward()
的目标是 Loss?
因为 Loss 是模型输出与目标之间差距的标量度量。
从 Loss 出发反推,才能知道“哪里错了”,以修正前面的参数。
7.2 计算图是不是在创建张量那一刻就生成了?
不是。计算图在你使用 requires_grad=True
的张量进行运算时才开始构建。
7.3 推理时既然不写 .backward()
,为何还用 no_grad()
?
虽然你不写 .backward()
,但不加 no_grad()
,PyTorch 仍然默默构建计算图,浪费内存。
no_grad()
是告诉系统:“我只要 forward,不要 graph 和 grad。”
7.4 .grad_fn
是不是就是 .grad
?
不是。
.grad_fn
是这个张量是由哪个操作生成的(即图中是哪一个函数节点);.grad
是在反向传播时被填充的梯度值。
八、总结对照表(建议收藏)
概念/方法 | 含义 | 用法 | 应用场景 |
---|---|---|---|
requires_grad=True | 开启梯度追踪 | 定义张量 | 模型参数 |
.backward() | 启动反向传播 | loss.backward() | 从 loss 回溯计算各层梯度 |
torch.no_grad() | 关闭计算图追踪 | with torch.no_grad(): | 推理、冻结部分模块 |
.detach() | 手动剪断计算图 | y = x.detach() | 可视化中间层、GAN 冻结 G |
.grad_fn | 当前张量由哪个操作生成 | y.grad_fn | 理解计算图结构 |
.grad | 张量的梯度值 | x.grad | 反向传播后查看参数梯度 |
九、结语
理解 PyTorch 的自动微分机制不仅仅是学会调用 .backward()
,更重要的是理解它背后是如何动态构建计算图、如何节省资源、如何控制信息流动的。尤其是在复杂任务如 GAN、NLP 微调中,合理使用 .detach()
和 no_grad()
是你进阶为深度学习开发者的标志之一。
希望这篇文章帮你从多个维度、多个场景完全吃透这些关键机制。如果觉得有帮助,欢迎点赞、收藏。