写在前面:这篇博客主要是对 pytorch 的自动求导机制进行学习,主要内容来自知乎:PyTorch 的 Autograd
1. 计算图
假设我们有一个复杂的神经网络模型,我们把它想象成一个错综复杂的管道结构,不同的管道之间通过节点连接起来,我们有一个注水口,一个出水口。我们在入口注入数据的之后,数据就沿着设定好的管道路线缓缓流动到出水口,这时候我们就完成了一次正向传播。
计算图通常包含两种元素,一个是 tensor,另一个是 Function。
Function 指的是在计算图中某个节点(node)所进行的运算,比如加减乘除卷积等等之类的,Function 内部有 forward()
和 backward()
两个方法,分别应用于正向、反向传播。
import torch
a = torch.tensor(2.0, requires_grad=True)
b = a.exp()
print(b)
# tensor(7.3891, grad_fn=<ExpBackward>)
ExpBackward
:为反向传播做一些准备,为反向计算图添加 Function 节点。在上边这个例子中,变量 b 在反向传播中所需要进行的操作是 。
2. 具体实例
假如我们需要计算这么一个模型:
y
=
(
x
w
1
+
w
2
)
(
x
w
1
∗
w
3
)
y=(xw_1 + w_2)(xw_1 * w_3)
y=(xw1+w2)(xw1∗w3)
loss值是
loss = mean(l4)
那么,写成代码为:
l1 = input x w1
l2 = l1 + w2
l3 = l1 x w3
l4 = l2 x l3
loss = mean(l4)
input = torch.ones([2, 2], requires_grad=False)
w1 = torch.tensor(2.0, requires_grad=True)
w2 = torch.tensor(3.0, requires_grad=True)
w3 = torch.tensor(4.0, requires_grad=True)
l1 = input * w1
l2 = l1 + w2
l3 = l1 * w3
l4 = l2 * l3
loss = l4.mean()
print(w1.data, w1.grad, w1.grad_fn)
# tensor(2.) None None
print(l1.data, l1.grad, l1.grad_fn)
# tensor([[2., 2.],
# [2., 2.]]) None <MulBackward0 object at 0x000001EBE79E6AC8>
print(loss.data, loss.grad, loss.grad_fn)
# tensor(40.) None <MeanBackward0 object at 0x000001EBE79D8208>
正向传播的结果基本符合我们的预期。我们可以看到,变量 l1 的 grad_fn 储存着乘法操作符 ,用于在反向传播中指导导数的计算。而 w1 是用户自己定义的,不是通过计算得来的,所以其 grad_fn 为空;同时因为还没有进行反向传播,grad 的值也为空。
中间过程:
input = [1.0, 1.0, 1.0, 1.0]
w1 = [2.0, 2.0, 2.0, 2.0]
w2 = [3.0, 3.0, 3.0, 3.0]
w3 = [4.0, 4.0, 4.0, 4.0]
l1 = input x w1 = [2.0, 2.0, 2.0, 2.0]
l2 = l1 + w2 = [5.0, 5.0, 5.0, 5.0]
l3 = l1 x w3 = [8.0, 8.0, 8.0, 8.0]
l4 = l2 x l3 = [40.0, 40.0, 40.0, 40.0]
loss = mean(l4) = 40.0
接下来我们继续运行代码,并检查一下结果和自己算的是否一致:
loss.backward()
print(w1.grad, w2.grad, w3.grad)
# tensor(28.) tensor(8.) tensor(10.)
print(l1.grad, l2.grad, l3.grad, l4.grad, loss.grad)
# None None None None None
3. 叶子张量
对于任意一个张量来说,我们可以用 tensor.is_leaf 来判断它是否是叶子张量(leaf tensor)。在反向传播过程中,只有 is_leaf=True 的时候,需要求导的张量的导数结果才会被最后保留下来。
4. inplace 操作
我们如果在某种情况下需要重新对叶子变量赋值该怎么办呢?有办法!
# 方法一
a = torch.tensor([10., 5., 2., 3.], requires_grad=True)
print(a, a.is_leaf, id(a))
# tensor([10., 5., 2., 3.], requires_grad=True) True 2501274822696
a.data.fill_(10.)
# 或者 a.detach().fill_(10.)
print(a, a.is_leaf, id(a))
# tensor([10., 10., 10., 10.], requires_grad=True) True 2501274822696
loss = (a*a).mean()
loss.backward()
print(a.grad)
# tensor([5., 5., 5., 5.])
# 方法二
a = torch.tensor([10., 5., 2., 3.], requires_grad=True)
print(a, a.is_leaf)
# tensor([10., 5., 2., 3.], requires_grad=True) True
with torch.no_grad():
a[:] = 10.
print(a, a.is_leaf)
# tensor([10., 10., 10., 10.], requires_grad=True) True
loss = (a*a).mean()
loss.backward()
print(a.grad)
# tensor([5., 5., 5., 5.])