Autograd
其实反向传播是一个比较底层的东西,如果有自动求导,那就会方便用户很多。简单的线性回归手动实现比较容易,如果复杂的网络,确实会比较费时费力,还容易出错。Pytorch就提供这样一套自动求导引擎,可以根据输入和前向传播过程自动构建计算图,执行反向传播。
1. Variable
Variable 的数据结构是这样的:
data是保存tensor的,grad是保存data对应的梯度,是Variable格式的,grad_fn指向一个Function,记录variable的操作历史。
早期Variable的创建是需要tensor,类似这样
a = V(t.ones(3, 4), requires_grad = True)
目前Pytorch的版本已经可以直接这样:
a = t.ones(3, 4).requires_grad_(True)
不区分tensor和Variable,Tensors/Variables 合并,弃用 volatile 标志,原来若True,在这之后的图都不会求导。
2. 计算图
计算图就是autograd的底层
用户自己创建的叫叶子节点,叶子节点的grad_fn为None,若需要求导,具有AccumulateGrad标识
import torch as t
x = t.ones(1)
b = t.rand(1).requires_grad_(True)
w = t.rand(1).requires_grad_(True)
y = w * x
z = y + b
print(x.requires_grad, b.requires_grad, w.requires_grad, y.requires_grad)
print(x.is_leaf, b.is_leaf, w.is_leaf, y.is_leaf, z.is_leaf)
print(z.grad_fn)
print(z.grad_fn.next_functions) # tuple是 y和b的grad_fn
print(z.grad_fn.next_functions[0][0] == y.grad_fn)
variable.backward(grad_variable=None, retain_graph=None, creat_graph=None)
grad_variable:size和variable一样
retain_graph:反向传播中需要缓存的一些中间结果(非叶子节点计算完就会被清空)
creat_graph:对反向传播过程再次构建计算图
对 grad_variable的解释:
所以反向传播backward的grad_variable可以看成是链式求导的中间结果,如果是标量可以省略,默认是1
2. 用auto实现线性回归
import torch as t
import matplotlib.pyplot as plt
# 设置随机数种子
t.manual_seed(1000)
def get_fake_data(batch_size = 8):
# 产生随机数据, y = 2*x + 3, 加上一些噪音
x = t.rand(batch_size, 1) * 20
y = 2 * x + (1 + t.randn(batch_size, 1)) * 3
return x, y
# 初始化w,b
w = t.rand(1, 1).requires_grad_(True)
b = t.zeros(1, 1).requires_grad_(True)
lr = 0.001 # 学习率
for step in range(20000):
x, y = get_fake_data()
# forward: 计算loss
y_pred = x.mm(w) + b.expand_as(y)
loss = 0.5 * (y_pred - y) ** 2 # 均方误差
loss = loss.sum()
# backward: 自动计算梯度
loss.backward()
# 更新
w.data.sub_(lr * w.grad.data)
b.data.sub_(lr * b.grad.data)
# 梯度清零
w.grad.data.zero_()
b.grad.data.zero_()
if step % 5000 == 0:
# 显示
x = t.arange(0, 20).view(-1, 1).float()
y = x.mm(w.data) + b.data.expand_as(x)
plt.plot(x.numpy(), y.numpy()) # predicted
x2, y2 = get_fake_data(20)
plt.scatter(x2.numpy(), y2.numpy()) # true data
plt.xlim(0, 20)
plt.ylim(0, 41)
plt.show()
plt.pause(0.5)
print(w.squeeze(), b.squeeze())
这里的不同就是不需要手机计算梯度,可以autograd,记住每次反向传播之前需要把梯度清零。