If you have used PyTorch, backward needs no introduction: it is the function that runs backpropagation to compute gradients. In the example below, the gradients must be computed by backpropagation before the optimizer's step function can update the model parameters.
Example:
>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()
[1] torch.Tensor.backward
In torch/tensor.py you can see that class Tensor(torch._C._TensorBase) defines a backward method, which is why we can call tensor.backward() to run backpropagation.
def backward(self, gradient=None, retain_graph=None, create_graph=False):
r"""Computes the gradient of current tensor w.r.t. graph leaves.The graph is differentiated using the chain rule. If the tensor isnon-scalar (i.e. its data has more than one element) and requiresgradient, the function additionally requires specifying ``gradient``.It should be a tensor of matching type and location, that containsthe gradient of the differentiated function w.r.t. ``self``.This function accumulates gradients in the leaves - you might need tozero them before calling it.Arguments:gradient (Tensor or None): Gradient w.r.t. thetensor. If it is a tensor, it will be automatically convertedto a Tensor that does not require grad unless ``create_graph`` is True.None values can be specified for scalar Tensors or ones thatdon't require grad. If a None value would be acceptable thenthis argument is optional.retain_graph (bool, optional): If ``False``, the graph used to computethe grads will be freed. Note that in nearly all cases settingthis option to True is not needed and often can be worked aroundin a much more efficient way. Defaults to the value of``create_graph``.create_graph (bool, optional): If ``True``, graph of the derivative willbe constructed, allowing to compute higher order derivativeproducts. Defaults to ``False``."""
torch.autograd.backward(self, gradient, retain_graph, create_graph)
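As the docstring notes, a non-scalar tensor needs an explicit gradient argument. Here is a minimal sketch (my own toy example, not taken from the PyTorch sources) of what that looks like:

import torch

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = x * 2                           # non-scalar output: y.backward() alone would raise an error
v = torch.tensor([0.1, 1.0, 10.0])  # the "vector" to multiply with the Jacobian
y.backward(gradient=v)              # computes v^T * J, where J = dy/dx
print(x.grad)                       # tensor([ 0.2000,  2.0000, 20.0000])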
Of the parameters, create_graph, when set to True, builds a graph of the derivative itself, which makes it possible to compute higher-order derivatives. The retain_graph parameter can be ignored in most cases, since it is almost never needed; it controls whether the graph used to compute the gradients is kept after the backward pass. The implementation of this method is trivial: it simply calls torch.autograd.backward, which we look at next.
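Before moving on, here is a small sketch (my own example, not from the PyTorch docs) of what create_graph enables, namely differentiating the gradient itself:

import torch

x = torch.tensor(2.0, requires_grad=True)
y = x ** 3
# create_graph=True builds a graph of the derivative itself ...
(dy_dx,) = torch.autograd.grad(y, x, create_graph=True)   # 3 * x**2 = 12 at x = 2
# ... so the gradient can be differentiated once more (second derivative)
(d2y_dx2,) = torch.autograd.grad(dy_dx, x)                 # 6 * x = 12 at x = 2
print(dy_dx.item(), d2y_dx2.item())                        # 12.0 12.0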
[2] torch.autograd.backward
The function torch.autograd.backward is defined in torch/autograd/__init__.py. Using the chain rule and the Jacobian-vector product, it can compute gradients conveniently. The concrete code is below.
# ...
from .variable import Variable
# ...
def _make_grads(outputs, grads):
    new_grads = []
    for out, grad in zip(outputs, grads):
        if isinstance(grad, torch.Tensor):
            # an explicit grad tensor must match the shape of its output
            if not out.shape == grad.shape:
                # raise RuntimeError ...
            new_grads.append(grad)
        elif grad is None:
            if out.requires_grad:
                # an implicit gradient can only be created for scalar outputs
                if out.numel() != 1:
                    # raise RuntimeError ...
                new_grads.append(torch.ones_like(out))
            else:
                new_grads.append(None)
        else:
            # raise TypeError ...
    return tuple(new_grads)
def backward(tensors, grad_tensors=None, retain_graph=None, create_graph=False, grad_variables=None):
r"""Computes the sum of gradients of given tensors w.r.t. graph leaves.The graph is differentiated using the chain rule. If any of ``tensors``are non-scalar (i.e. their data has more than one element) and requiregradient, then the Jacobian-vector product would be computed, in thiscase the function additionally requires specifying ``grad_tensors``.It should be a sequence of matching length, that contains the "vector"in the Jacobian-vector product, usually the gradient of the differentiatedfunction w.r.t. corresponding tensors (``None`` is an acceptable value forall tensors that don't need gradient tensors).This function accumulates gradients in the leaves - you might need to zerothem before calling it."""
    if grad_variables is not None:
        warnings.warn("'grad_variables' is deprecated. Use 'grad_tensors' instead.")
        if grad_tensors is None:
            grad_tensors = grad_variables
        else:
            raise RuntimeError("'grad_tensors' and 'grad_variables' (deprecated) "
                               "arguments both passed to backward(). Please only "
                               "use 'grad_tensors'.")

    # normalize both arguments to sequences of the same length
    tensors = (tensors,) if isinstance(tensors, torch.Tensor) else tuple(tensors)

    if grad_tensors is None:
        grad_tensors = [None] * len(tensors)
    elif isinstance(grad_tensors, torch.Tensor):
        grad_tensors = [grad_tensors]
    else:
        grad_tensors = list(grad_tensors)

    grad_tensors = _make_grads(tensors, grad_tensors)
    if retain_graph is None:
        retain_graph = create_graph

    # hand off to the C++ autograd engine
    Variable._execution_engine.run_backward(
        tensors, grad_tensors, retain_graph, create_graph,
        allow_unreachable=True)  # allow_unreachable flag
# ...
if not torch._C._autograd_init():
raise RuntimeError("autograd initialization failed")
The grad_variables parameter is a leftover from older versions and has been deprecated; grad_tensors is what is used now. Passing it still works, though: the code emits a deprecation warning and assigns the value of grad_variables to grad_tensors (passing both arguments at the same time raises a RuntimeError).
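Finally, both docstrings point out that backward accumulates gradients in the leaves rather than overwriting them, which is exactly why the opening example calls optimizer.zero_grad(). A quick sketch (my own toy example):

import torch

w = torch.tensor([1.0, 2.0], requires_grad=True)

(w ** 2).sum().backward()
print(w.grad)            # tensor([2., 4.])

(w ** 2).sum().backward()
print(w.grad)            # tensor([4., 8.])  -- accumulated, not overwritten

w.grad.zero_()           # roughly what optimizer.zero_grad() does for its parameters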