torch.autograd提供实现任意标量值功能的自动区分的类和功能。 它需要对现有代码进行最小的更改,只需将所有张量包装在Variable对象中。
Variable API与常规Tensor API几乎相同(除了几个需要对梯度计算的输入进行重写的in-place方法外)。在大多数情况下,Tensors可以安全地替换为Variable,代码将保持工作正常。所以,在这里只介绍有别于Tensors的一些操作。
Variable的in-place(原地操作):在autograd中支持就地操作是一件很困难的事情,大多数情况下都不鼓励使用它们。 Autograd的积极缓冲区释放和重用使in-place变得非常高效,极少数场合就地操作能够实质上地降低内存使用量。 除非您在内存压力很大的情况下运行,否则您可能永远不需要使用它们。
in-place准确性检查:所有的Variable应用并跟踪in-place操作,如果在实现过程中发现一个变量在一个函数中保存为后向,并在in-place前向操作时被修改,那么一但反向传播开始,会产生一个错误。这可以确保如果您使用就地操作并且没有看到任何错误,则可以确定计算出的梯度是正确的。
torch.autograd.Variable用来包裹张量并记录应用的操作。
Variable可以看作是对Tensor对象周围的一个薄包装,也包含了和张量相关的梯度,以及对创建它的函数的引用。 此引用允许对创建数据的整个操作链进行回溯。需要BP的网络都是通过Variable来计算的。如果Variable是由用户创建的,则其grad_fn将为None,我们将这些对象称为叶子Variable。
由于自动求导仅支持标量值函数微分,因此grad大小始终与数据大小匹配。 此外,grad通常仅分配给叶子Variable,否则将始终为零。
class Variable(_C._VariableBase):
"""
Attributes:
data: 任意类型的封装好的张量。
grad: 保存与data类型和位置相匹配的梯度,此属性难以分配并且不能重新分配。
requires_grad: 标记变量是否已经由一个需要调用到此变量的子图创建的bool值。只能在叶子变量上进行修改。
volatile: 标记变量是否能在推理模式下应用(如不保存历史记录)的bool值。只能在叶变量上更改。
is_leaf: 标记变量是否是图叶子(如由用户创建的变量)的bool值.
grad_fn: Gradient function graph trace.
Parameters:
data (any tensor class): 要包装的张量.
requires_grad (bool): bool型的标记值. **Keyword only.**
volatile (bool): bool型的标记值. **Keyword only.**
"""
def backward(self, gradient=None, retain_graph=None, create_graph=None, retain_variables=None):
"""计算关于当前图叶子变量的梯度,图使用链式法则导致分化
如果Variable是一个标量(例如它包含一个单元素数据),你无需对backward()指定任何参数
如果变量不是标量(包含多个元素数据的矢量)且需要梯度,函数需要额外的梯度;
需要指定一个和tensor的形状匹配的grad_output参数(y在指定方向投影对x的导数);
可以是一个类型和位置相匹配且包含与自身相关的不同函数梯度的张量。
函数在叶子上累积梯度,调用前需要对该叶子进行清零。
Arguments:
grad_variables (Tensor, Variable or None):
变量的梯度,如果是一个张量,除非“create_graph”是True,否则会自动转换成volatile型的变量。
可以为标量变量或不需要grad的值指定None值。如果None值可接受,则此参数可选。
retain_graph (bool, optional): 如果为False,用来计算梯度的图将被释放。
在几乎所有情况下,将此选项设置为True不是必需的,通常可以以更有效的方式解决。
默认值为create_graph的值。
create_graph (bool, optional): 为True时,会构造一个导数的图,用来计算出更高阶导数结果。
默认为False,除非``gradient``是一个volatile变量。
"""
torch.autograd.backward(self, gradient, retain_graph, create_graph, retain_variables)
def register_hook(self, hook):
"""Registers a backward hook.
每当与variable相关的梯度被计算时调用hook,hook的申明:hook(grad)->Variable or None
不能对hook的参数进行修改,但可以选择性地返回一个新的梯度以用在`grad`的相应位置。
函数返回一个handle,其``handle.remove()``方法用于将hook从模块中移除。
Example:
>>> v = Variable(torch.Tensor([0, 0, 0]), requires_grad=True)
>>> h = v.register_hook(lambda grad: grad * 2) # double the gradient
>>> v.backward(torch.Tensor([1, 1, 1]))
>>> v.grad.data
2
2
2
[torch.FloatTensor of size 3]
>>> h.remove() # removes the hook
"""
if self.volatile:
raise RuntimeError("cannot register a hook on a volatile variable")
if not self.requires_grad:
raise RuntimeError("cannot register a hook on a variable that "
"doesn't require gradient")
if self._backward_hooks is None:
self._backward_hooks = OrderedDict()
if self.grad_fn is not None:
self.grad_fn._register_hook_dict(self)
handle = hooks.RemovableHandle(self._backward_hooks)
self._backward_hooks[handle.id] = hook
return handle
def reinforce(self, reward):
"""Registers a reward obtained as a result of a stochastic process.
区分随机节点需要为他们提供reward值。如果图表中包含任何的随机操作,都应该在其输出上调用此函数,否则会出现错误。
Parameters:
reward(Tensor): 带有每个元素奖赏的张量,必须与Variable数据的设备位置和形状相匹配。
"""
if not isinstance(self.grad_fn, StochasticFunction):
raise RuntimeError("reinforce() can be only called on outputs "
"of stochastic functions")
self.grad_fn._reinforce(reward)
def detach(self):
"""返回一个从当前图分离出来的心变量。
结果不需要梯度,如果输入是volatile,则输出也是volatile。
.. 注意::
返回变量使用与原始变量相同的数据张量,并且可以看到其中任何一个的就地修改,并且可能会触发正确性检查中的错误。
"""
result = NoGrad()(self) # this is needed, because it merges version counters
result._grad_fn = None
return result
def detach_(self):
"""从创建它的图中分离出变量并作为该图的一个叶子"""
self._grad_fn = None
self.requires_grad = False
def retain_grad(self):
"""Enables .grad attribute for non-leaf Variables."""
if self.grad_fn is None: # no-op for leaves
return
if not self.requires_grad:
raise RuntimeError("can't retain_grad on Variable that has requires_grad=False")
if hasattr(self, 'retains_grad'):
return
weak_self = weakref.ref(self)
def retain_grad_hook(grad):
var = weak_self()
if var is None:
return
if var._grad is None:
var._grad = grad.clone()
else:
var._grad = var._grad + grad
self.register_hook(retain_grad_hook)
self.retains_grad = True