It's automatic
我们知道,深度学习最核心的其中一个步骤,就是求导:根据函数(linear + activation function)求weights相对于loss的导数(还是loss相对于weights的导数?)。然后根据得出的导数,相应的修改weights,让loss最小化。
各大深度学习框架Tensorflow,Keras,PyTorch都自带有自动求导功能,不需要我们手动算。
在初步学习PyTorch的时候,看到PyTorch的自动求导过程时,感觉非常的别扭和不直观。我下面举个例子,大家自己感受一下。
>>> import torch
>>>
>>> a = torch.tensor(2.0, requires_grad=True)
>>> b = torch.tensor(3.0, requires_grad=True)
>>> c = a + b
>>> d = torch.tensor(4.0, requires_grad=True)
>>> e = c * d
>>>
>>> e.backward() # 执行求导
>>> a.grad # a.grad 即导数 d(e)/d(a) 的值
tensor(4.)
这里让人感觉别扭的是,调用 e.backward()
执行求导,为什么会更新 a
对象的状态grad
?对于习惯了OOP的人来说,这是非常不直观的。因为,在OOP里面,你要改变一个对象的状态,一般的做法是,引用这个对象本身,给它的property显示的赋值(比如 user.age = 18
),或者是调用这个对象的方法(user.setAge(18)
),让它状态得以改变。
而这里的做法是,调用了一个跟它(a
)本身看起来没什么关系的对象(e
)的方法,结果改变了它的状态。
每次写代码写到这个地方的时候,我都觉得心里一惊。因此,就一直想一探究竟,看看这内部的关联究竟是怎么样的。
根据上面的代码,我们知道的是,e
的结果,是由c
和d
运算得到的,而c
,又是根据a
和b
相加得到的。现在,执行e
的方法,最终改变了a
的状态。因此,我们可以猜测e
内部可能有某个东西,引用着c
,然后呢,c
内部又有些东西,引用着a
。因此,在运行e
的backward()
方法时,通过这些引用,先是改变c
,在根据c
内部的引用,最终改变了a
。如果我们的猜测没错的话,那么这些引用关系到底是什么呢?在代码里是怎么提现的呢?
想要知道其中原理,最先想到的办法,自然是去看源代码。
遗憾的是,backward()
的实现主要是在C/Cpp层间做的,在Python层面做的事情很少,基本上就是对参数做了一下处理,然后调用native层面的实现。如下:
def backward(tensors, grad_tensors=None, retain_graph=None, create_graph=False, grad_variables=None):
r"""Computes the sum of gradients of given tensors w.r.t. graph leaves.
...more comment
"""
if grad_variables is not None:
warnings.warn("'grad_variables' is deprecated. Use 'grad_tensors' instead.")
if grad_tensors is None:
grad_tensors = grad_variables
else:
raise RuntimeError("'grad_tensors' and 'grad_variables' (deprecated) "
"arguments both passed to backward(). Please only "
"use 'grad_tensors'.")
tensors = (tensors,) if isinstance(tensors, torch.Tensor) else tuple(tensors)
if grad_tensors is None:
grad_tensors = [None] * len(tensors)
elif isinstance(grad_tensors, torch.Tensor):
grad_tensors = [grad_tensors]
else:
grad_tensors = list(grad_tensors)
grad_tensors = _make_grads(tensors, grad_tensors)
if retain_graph is None:
retain_graph = create_graph
Variable._execution_engine.run_backward(
tensors, grad_tensors, retain_graph, create_graph,
allow_unreachable=True) # allow_unreachable flag
说到Cpp。。。
看来只能通过一顿自行的探索操作,来了解这个执行过程了。
我们先看看e
里面有什么。
由于e
是一个Tensor
变量,我们自然想到去看Tensor
这个类的代码,看看里面有哪些成员变量。不幸的是,Python语言声明成员变量的方式跟Java这些静态语言不一样,他们是用到的时候直接用self.xxx
随时声明的。不像Java这样,在某一个地方统一声明并做初始化。
当然,我们可以用正则表达式 self\.\w+\s+=
搜索所有类似于 self.xxx =
的地方,于是你会找到一些data
, requires_grad
, _backward_hooks
, ret