AUTOGRAD: AUTOMATIC DIFFERENTIATION
The autograd package provides automatic differentiation for all operations on Tensors.
import torch
Tensors
- torch.Tensor
- Training:
  To track operations on a tensor: set .requires_grad = True
  To compute gradients after the forward pass: call .backward(); gradients are accumulated into .grad
  To stop tracking: .detach()
  - useful when evaluating a model
  To prevent tracking inside a block of code: with torch.no_grad():
- The Function class
  Tensor and Function are interconnected and together build an acyclic graph that encodes the full history of the computation. A Tensor's .grad_fn attribute references the Function that created it (for Tensors created directly by the user, .grad_fn is None).
- Computing derivatives
  .backward(gradient): no argument is needed when the Tensor is a scalar; otherwise pass a gradient tensor of matching shape.
x = torch.ones(2, 2, requires_grad=True)  # track every operation on x
print(x)
tensor([[1., 1.],
[1., 1.]], requires_grad=True)
y = x + 2
print(y)
print(y.grad_fn)
tensor([[3., 3.],
[3., 3.]], grad_fn=<AddBackward0>)
<AddBackward0 object at 0x000001EF8A3D7160>
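By contrast, x was created directly by the user, so no Function produced it and its .grad_fn is None (a quick check of the point made above):
print(x.grad_fn)
None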
z = y * y * 3
out = z.mean()
print(z, out)
tensor([[27., 27.],
[27., 27.]], grad_fn=<MulBackward0>) tensor(27., grad_fn=<MeanBackward0>)
.requires_grad_() changes an existing Tensor's .requires_grad flag in-place (a Tensor's requires_grad defaults to False when it is created).
a = torch.randn(2, 2)
a = (a * 3) / (a - 1)
print(a.requires_grad)
a.requires_grad_(True)
print(a.requires_grad)
b = (a * a).sum()
print(b.grad_fn)
False
True
<SumBackward0 object at 0x000001EF8A3D7F98>
Gradients
# out is a scalar, so backward() needs no argument
out.backward()  # equivalent to out.backward(torch.tensor(1.))
The gradient \frac{\partial out}{\partial x}:
\begin{aligned}
o &= \frac{1}{4}\sum_i z_i \\
z_i &= 3 (x_i + 2)^2 \\
z_i \big|_{x_i = 1} &= 27 \\
\frac{\partial o}{\partial x_i} &= \frac{3}{2} (x_i + 2) \\
\left. \frac{\partial o}{\partial x_i} \right|_{x_i = 1} &= 4.5
\end{aligned}
# gradients d(out)/dx
print(x.grad)
tensor([[4.5000, 4.5000],
[4.5000, 4.5000]])
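Gradients are accumulated into .grad across calls to .backward(); the following minimal sketch (not part of the original example) repeats the same forward computation twice to show this:
x = torch.ones(2, 2, requires_grad=True)
for _ in range(2):
    out = (3 * (x + 2) ** 2).mean()  # same scalar o as above
    out.backward()                   # each call adds 4.5 to every entry of x.grad
print(x.grad)                        # tensor([[9., 9.], [9., 9.]])
x.grad.zero_()                       # reset the accumulated gradient in place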
torch.autograd is an engine for computing vector-Jacobian products.
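Concretely, for y = f(x) with Jacobian J, and v taken to be the gradient of some scalar l with respect to y, the product J^T v is the gradient of l with respect to x (the standard chain-rule identity behind backward(v)):
J = \begin{pmatrix}
\frac{\partial y_1}{\partial x_1} & \cdots & \frac{\partial y_1}{\partial x_n} \\
\vdots & \ddots & \vdots \\
\frac{\partial y_m}{\partial x_1} & \cdots & \frac{\partial y_m}{\partial x_n}
\end{pmatrix}, \qquad
J^\top v = \begin{pmatrix} \frac{\partial l}{\partial x_1} \\ \vdots \\ \frac{\partial l}{\partial x_n} \end{pmatrix}
\quad \text{for } v = \left( \frac{\partial l}{\partial y_1}, \ldots, \frac{\partial l}{\partial y_m} \right)^\top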
x = torch.randn(3, requires_grad=True)
y = x * 2
while y.data.norm() < 1000:
    y = y * 2
print(y)
tensor([-1494.1448, -286.1320, -200.1645], grad_fn=<MulBackward0>)
v = torch.tensor([0.1, 1.0, 0.0001], dtype=torch.float)
y.backward(v)  # y is not a scalar, so pass the vector v to get the vector-Jacobian product
print(x.grad)
tensor([1.0240e+02, 1.0240e+03, 1.0240e-01])
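In the run above the loop doubled y until the overall factor was 2^{10} = 1024, so y = 1024 x and the Jacobian is simply 1024 I. Hence
\frac{\partial l}{\partial x} = J^\top v = 1024 \, v = (102.4,\ 1024,\ 0.1024)^\top,
which matches the printed x.grad.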
Stopping autograd tracking
print(x.requires_grad)
print((x ** 2).requires_grad)
with torch.no_grad():
    print((x ** 2).requires_grad)
True
True
False
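The .detach() call listed at the top can be demonstrated in the same way; it returns a new Tensor with the same values but detached from the computation graph (a minimal sketch, not part of the original example):
print(x.requires_grad)   # True
y = x.detach()
print(y.requires_grad)   # False
print(x.eq(y).all())     # tensor(True): same values, no gradient tracking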