python求导_PyTorch的自动求导功能(Autograd)原理解析

本文探讨了PyTorch的自动求导(Autograd)原理,通过实例解释了如何通过调用`.backward()`更新Tensor的状态。文章揭示了在Python和C/C++层之间的工作流程,以及在`.backward()`执行时如何通过Function对象串联Tensor,从而实现求导。作者通过探索源代码、使用内置方法以及Google搜索,逐步解析了这个过程,并以一张关系图总结了不同Variable和Function之间的关系。
摘要由CSDN通过智能技术生成

fd586d4d900befc5ea37f1a675398692.png

It's automatic

我们知道,深度学习最核心的其中一个步骤,就是求导:根据函数(linear + activation function)求weights相对于loss的导数(还是loss相对于weights的导数?)。然后根据得出的导数,相应的修改weights,让loss最小化。

各大深度学习框架Tensorflow,Keras,PyTorch都自带有自动求导功能,不需要我们手动算。

在初步学习PyTorch的时候,看到PyTorch的自动求导过程时,感觉非常的别扭和不直观。我下面举个例子,大家自己感受一下。

>>> import torch
>>>
>>> a = torch.tensor(2.0, requires_grad=True)
>>> b = torch.tensor(3.0, requires_grad=True)
>>> c = a + b
>>> d = torch.tensor(4.0, requires_grad=True)
>>> e = c * d
>>>
>>> e.backward() # 执行求导
>>> a.grad # a.grad 即导数 d(e)/d(a) 的值
tensor(4.)

这里让人感觉别扭的是,调用 e.backward()执行求导,为什么会更新 a 对象的状态grad?对于习惯了OOP的人来说,这是非常不直观的。因为,在OOP里面,你要改变一个对象的状态,一般的做法是,引用这个对象本身,给它的property显示的赋值(比如 user.age = 18),或者是调用这个对象的方法(user.setAge(18)),让它状态得以改变。

而这里的做法是,调用了一个跟它(a)本身看起来没什么关系的对象(e)的方法,结果改变了它的状态。

每次写代码写到这个地方的时候,我都觉得心里一惊。因此,就一直想一探究竟,看看这内部的关联究竟是怎么样的。

根据上面的代码,我们知道的是,e的结果,是由cd运算得到的,而c,又是根据ab相加得到的。现在,执行e的方法,最终改变了a的状态。因此,我们可以猜测e内部可能有某个东西,引用着c,然后呢,c内部又有些东西,引用着a。因此,在运行ebackward()方法时,通过这些引用,先是改变c,在根据c内部的引用,最终改变了a。如果我们的猜测没错的话,那么这些引用关系到底是什么呢?在代码里是怎么提现的呢?

想要知道其中原理,最先想到的办法,自然是去看源代码。

遗憾的是,backward()的实现主要是在C/Cpp层间做的,在Python层面做的事情很少,基本上就是对参数做了一下处理,然后调用native层面的实现。如下:

def backward(tensors, grad_tensors=None, retain_graph=None, create_graph=False, grad_variables=None):
r"""Computes the sum of gradients of given tensors w.r.t. graph leaves.
...more comment
"""
if grad_variables is not None:
warnings.warn("'grad_variables' is deprecated. Use 'grad_tensors' instead.")
if grad_tensors is None:
grad_tensors = grad_variables
else:
raise RuntimeError("'grad_tensors' and 'grad_variables' (deprecated) "
"arguments both passed to backward(). Please only "
"use 'grad_tensors'.")

tensors = (tensors,) if isinstance(tensors, torch.Tensor) else tuple(tensors)

if grad_tensors is None:
grad_tensors = [None] * len(tensors)
elif isinstance(grad_tensors, torch.Tensor):
grad_tensors = [grad_tensors]
else:
grad_tensors = list(grad_tensors)

grad_tensors = _make_grads(tensors, grad_tensors)
if retain_graph is None:
retain_graph = create_graph

Variable._execution_engine.run_backward(
tensors, grad_tensors, retain_graph, create_graph,
allow_unreachable=True) # allow_unreachable flag

说到Cpp。。。

cc157850e3f73fbed13840fab1a86b2d.png

看来只能通过一顿自行的探索操作,来了解这个执行过程了。

我们先看看e里面有什么。

由于e是一个Tensor变量,我们自然想到去看Tensor这个类的代码,看看里面有哪些成员变量。不幸的是,Python语言声明成员变量的方式跟Java这些静态语言不一样,他们是用到的时候直接用self.xxx随时声明的。不像Java这样,在某一个地方统一声明并做初始化。

当然,我们可以用正则表达式 self\.\w+\s+= 搜索所有类似于 self.xxx = 的地方,于是你会找到一些datarequires_grad_backward_hooksret

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值