反向传播: 揭秘神经网络的学习机制

为了对万能近似函数的每个参数求偏导,我们需要使用反向传播作为方法。

所谓反向传播,与之相对的就是正向传播。神经网络执行是从前到后的,这是正向传播,而为神经网络的各个节点求导,则需要从最后一个输出节点向前推导,因为顺序是从后往前的,所以成为反向传播。

我们先看一个例子。

神经网络图

09327463bc6a2cbbbe54d9958119c7ea.png

我们有如上的神经网络,一共有 1~7 个节点,其中 1 号是输入节点,7 号是输出节点,而 2~6 号就是中间的隐藏层中的所谓隐藏节点。

在训练时,除了 1 号输入节点的值来自于 Training data 外,其余 2~7 号节点的值都来自于计算,计算公式都是上图所示的这个。

举例来说,对于 2、4 号节点来说,具体的计算公式如下:

8f78d7a1fc13475a7e5291687b2aef86.png

首先 x1 代表节点 1 的值,x2 就代表节点 2 的值。然后我们可以发现,所有节点计算用到的参数,都存储在当前节点上,而上一个节点仅提供其节点值。

接下来我们看看如何计算各节点参数针对计算结果的偏导数。

计算输出节点参数的偏导数

对于复杂的神经网络,计算导数肯定是比较复杂的,所以为了便于理解,我们先拿最好理解的输出节点 7 举例。

be018e93714863d1cc9467808cf7d304.png

为了方便计算,我们假设节点 4、5、6、7 的值分别为 4、5、6、7。假设训练资料的期望输出是 10,但我们拿到的实际输出却是 7。

接下来我们要计算的是参数 b c w4to7 w5to7 w6to7 针对 loss function 的偏导。

乍一看,没有明显的头绪,我们先从 loss 函数定义开始思考。

假设 loss function 采用均方差的逻辑,那么其函数定义如下:

1d6ec7ef648646405c5e26231617cf49.png

  • x 表示输出节点实际输出的值。

  • t 表示 training,即训练资料的预期输出值。

因为 t 作为训练资料的预期输出,无论如何训练值都不变,因此作为函数的一个已知值。而神经网络的输出 x 才是变量。

那么 loss 函数针对 x 的求导公式如下:

aea0856815f378b3ecde71bd1c1a2d23.png

在本次例子中,x=7t=10,所以 dloss/dx 的值为 -6也就是说,我们可以认为,根据实际输出值与训练资料的预期输出值,dloss/dx 可以很容易计算出来,成为一个已知的值。

然后我们看 x 是什么?x 就是节点 7 的输出值,它来自于函数:1670c2093e916a2af9e10d149367a99e.png

我们可以根据 f(x) 对该函数内任何变量求偏导。那么根据 链式法则,我们就可以逐个击破了。首先是 dloss/db:

e41c274f6c8985b032ad8573669d8c71.png

由于 dloss/dx=-6,我们继续观察 dx/db 怎么计算。

首先我们把这个复合的 f(x) 函数改写成更原子的函数系列:

cca44efef6514bc74199312d387361d8.png

080762760191dc1234555bfb2b74a28a.png

34c184e6583787f64405522669ffa55d.png

那么根据链式法则:

9643f26cbbeb234c0dd5cfb43a2ea2ce.png

其中 dx/dg=c,而根据 simgoid 函数特性,dg/dh=g(x) * (1-g(x)),而 dh/db=1,那么三者相乘后,结果是:

dx/db = c * g(x) * (1-g(x))

注意,为了让神经网络启动,此时所有参数都会有一个随机出来的值,总之它们是有值的。假设 w4to7=1 w5to7=1 w6to7=1 c=1 b=1,而 x4=4 x5=5 x6=6,再加上 sigmoid 函数在之前我们就说过了计算公式,这样整个公式所有的值都明确了,虽然复杂了点儿,但一定能计算出结果。至此我们得到了 dr/db 的值。

最后 dloss/db = dloss/dx * dx/db,后面两个表达式的值都算了出来,至此得到了 dloss/db 的值,我们求出了输出节点 7 的参数 b 针对 loss 函数的偏导。

接下来,还有几个参数,但有了参数 b 的偏导过程,剩下的都大同小异了:

8df8c2799618e78b881519aa2dbf0e88.png

根据展开公式:

61991f794842b490ea177cad6771ce9e.png

易得 dx/dc=g(b),而 g(b) 同理所有参数都可计算,因此 dloss/dc 也求出来了。

接下来 w4to7 w5to7 w6to7 计算方式都类似,我们以 w4to7 举例:

654a7fdb2342ea048d4f12a45f5ab4f1.png

根据展开公式:

57315e79731f2a8a430b945952f32071.png

7401756e3d5a6dd85ea8393de276efc5.png

1c3a9255ee7d748c484a52171062ba95.png


其中 dx/dg=c,而根据 simgoid 函数特性,dg/dh=g(x7) * (1-g(x7)),而 dh/dw4to7=x4,那么三者相乘后,结果是:

dx/dw4to7 = c * g(x7) * (1-g(x7)) * x4

其实与求参数 b 的偏导相比,这里只不过结果多乘了 x4 的值,因此从工程角度来看,前置计算 c * g(x7) * (1-g(x7)) 可以缓存下来,方便后续计算复用。

其他节点的参数偏导

在输出节点的偏导计算中,我们抓住了 loss function 定义,根据其公式的偏导,以及函数表达式结合链式法则,总算得到了各参数的偏导,但再往神经网络深处反向传播时,应该怎么办呢?

其实还是用链式法则。以节点 4 举例:

9e61023161281b160433e4266f7b449f.png

要求节点 4 的参数 b 针对 loss function 的导数 dloss/db4,我们可以写成:2d905fef0f82580483b220db82370c12.png

其中 dx4/db4 计算逻辑与上一节一样:

dx4/db4 = c4 * g(x4) * (1-g(x4))

对该节点其他参数的求导同理。所以最重要的是 dloss/dr4 怎么计算,如果我们求出了它,那么最终节点的 loss function 对节点 4 任意参数的偏导我们就都可以推导出来了。

还是根据链式法则:

b9de4dc740410755c42178eadf917084.png

其中 dloss/dx7 在前面已经算出来了等于 -6,我们只要关注 dx7/dx4 怎么算。

由于 x7 指的是节点 7 实际输出的值,x4 指的是节点 4 实际输出的值,而 x7 的计算公式是:

c42feeb3257598ae8796de4676d0b647.png

可以看到,dx7/dx4 就是对上述函数对 x4 求偏导! 因为节点 7 是由节点 4 计算得到的,它们必须紧挨着才能用公式求导,否则就要使用链式法则继续在链条上推导。

根据拆分公式:

8c5aeaf7fff5b8101a5222028250799c.png

0aca5a946f60bc38a43dec267b79dd9f.png

5fa633d572fe600aaf283018a4a1dffd.png


可以看到,只有最后一步根据 x4 求导得到的是 w4to7,因此:

dx7/dx4= c * g(x7) * (1-g(x7)) * w4to7

节点 4 的问题解决了,我们再来看节点 2。节点 2 的代表性在于,它不直接挨着输出节点,我们要理清楚它是怎么由输出节点反向推导出来的:

c3be8f253bb2cf761a32652e9e831508.png

可以看到,节点 2 由节点 4、5、6 共同反向推导而成,因此我们使用链式法则:

2e0a4988ee26124b07874229624cde14.png

以节点 4 举例,dx4/dx2 的处理方法我们刚刚说过,计算公式最后乘以的数字改成 w2to4 即可,而 dloss/dx4 我们刚刚计算过。因此对于任意节点,应用链式法则都可以求出最终 loss function 对该节点任意参数的导数。

值得注意的是,当某节点反向传播由多个节点回溯而来时,链式法则采取加法关系。

另外,从工程角度来看,对于任意一个节点 i,在计算反向传播的过程中,有几个值会反复用到:

  1. 节点与节点间链式法则时,需要多次用到 dloss/dxi

  2. dx(i+j)/dxi 虽然仅用到一次,但对于节点 ici * g(xi) * (1-g(xi)) 在计算自身参数偏导,或者在节点与节点间链式法则时会频繁用到。

对于每个节点来说,这两个值都需要缓存下来,避免重复计算。

总结

我们总结一下,求任意节点参数对 loss 的偏导的过程。

总的来看,一共只有三种计算手法,这三种手法使用的场景和套路掌握清楚后,你就会牢牢掌握反向传播,再也不容易忘了。

  • 场景 1:输出节点的值 x 对 loss function 的偏导。

  • 场景 2:非输出节点的值 x 对 loss function 的偏导。

  • 场景 3:已知某个节点值 x 对 loss function 的偏导后,求该节点任意参数对 loss function 的偏导。

场景 1 和 场景 2 是为了拿到任意节点值 x,对 loss function 的偏导。有了这个作为原材料后,结合场景 3,使用链式法则,就可以求出对任意节点的参数对 loss function 的偏导。

为了加深印象,我们一个一个说明,这次我们把理解拔高一个层次。

场景 1:输出节点的值 x 对 loss function 的偏导。

因为 loss function 定义的就是 loss(x),所以该偏导值完全取决于 loss function 的设计,比如采用均方差的话:

d287062188d7be6f9fdba55bde71db9e.png

51604507c7abd46439b83bb1643e5dac.png

没有其他任何因素影响了,它已经是 loss 计算的最后一步。

场景 2:非输出节点的值 x 对 loss function 的偏导。

利用场景 1 得到的结果,用链式法则一步步跳到目标节点。比如节点关系为 1 -> 2 -> 3,节点 3 是最后一个节点,此时我们已知的值是:

d766fa49a17757e3c09080cfda12bec3.png

利用它作为跳板,比如求节点 1 的值 x1 对 loss function 的偏导:

f114b9a091d032cf4339968ee6c06d88.png

所以重点是我们要知道如何求 dxi/dxj 的值,且已知节点 i 的下一个节点是 j

此时我们要理解的是,神经网络中,节点 j 的值 xj 是由前面节点计算出来的,就包括了节点 i 贡献的 xi,所以根据计算公式就可以求出 dxi/dxj 的结果。

场景 3:已知某个节点值 x 对 loss function 的偏导后,求该节点任意参数对 loss function 的偏导。

此时目光锁定在这个节点内,要求的是 loss function 对该节点任意参数 z 的偏导:

2522d4281dec96b3ec2ec635c0a7a2aa.png

依然使用链式法则,经历过场景 1、2,我们已经计算出任意节点值对 loss function 的偏导,因此用它作为跳板:

73214078ffb7de97f368e3c22eb11abc.png

而参数 z 与该节点输出值 x 的关系,就藏在该节点的计算公式上,所以我们又回到了节点计算公式,计算该参数的偏导即可。

下一篇文章,我们将实践这些理论,写一个完整的神经网络实现,并让它运行起来!

  • 13
    点赞
  • 18
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值