CS231n笔记三:神经网络之反向传播

本文介绍了计算图的概念,它用于表示复杂函数的计算流程,特别是在损失函数的计算中。反向传播是利用链式法则递归计算梯度的过程,用于更新模型参数。文章通过实例详细阐述了反向传播的步骤,并讨论了加法、乘法等门在反向传播中的角色。此外,还提及了多维向量和矩阵的梯度计算,并提出了前向传播和反向传播的API设计思路。最后,强调了梯度计算在优化参数矩阵中的重要性。
摘要由CSDN通过智能技术生成

一、计算图 Computational graphs

在这个计算图中,节点表示一种计算操作。
下图就是损失函数L的计算图。即输入x和W,首先将x和W进行矩阵乘,得到分数s;随后计算hinge loss和正则化R(W),之后将二者相加得到最终的L。
在这里插入图片描述
应用这种计算图的框架,我们可以求出任意复杂函数的解析梯度。


二、反向传播 Back Propagation

  1. 定义
    反向传播是链式法则的递归调用。具体的计算过程如下:
    在这里插入图片描述
    1、从最后开始往前算,即首先计算df/df,显然为1;
    2、之后计算df/dz,因为df/dz = q,而q=x+y=3,故df/dz = 3;
    3、同理,可以计算得df/dq = -4;
    4、之后我们希望得到df/dy,但是f和y没有直接的关系,这就需要用到链式法则。即df/dy = (df/dq) * (dq/dy),而dq/dy = 1,故可得df/dy = -4。同理可得df/dx = -4;
    从这里可以看出,对于同一节点的输入输出的微分,我们可以直接求导得到它们的关系;但对于不同节点的值,无法直接得到关系,因而需要用到反向传播算法和链式法则来找到它们之间的关联。
    每次计算我们只要关心local gradient(通过下游输入的值和local关系公式计算得到)和上游输入的gradient值即可,随后将二者相乘得到该节点为下游节点输出的最终gradient值
    一个更加抽象的实例是这样的,其中红色的线代表反向传播的过程:
    在这里插入图片描述
    注意,其实通过正向的链式法则可以直接得到dL/dx = dL/dz * dz/dx的表达式,但这里我们需要的不是表达式,我们只关心这些梯度最终的值是多少,不需要弄那么复杂的表达式。
    下面是一个更复杂的例子:
    在这里插入图片描述

  2. 一些有趣的local gradient
    1、加法门:gradient distributor,local gradient对两个输入来说都是1,上游输入的gradient直接就是该节点输出给下游的gradient;
    2、max门:gradient router,对大的那个输入的local gradient是1,对小的输入的local gradient是0,相当于将上游输入的gradient进行了一个路由选择;
    3、乘法门:gradient switcher,对输入1的local gradient是输入2的值,对输入2的local gradient是输入1的值,相当于进行了交换。
    4、gradients add at branches:在反向传播过程中,当涉及多个上游节点的gradient汇聚到同一个下游节点时,需要将所有上游节点的gradients相加最为这个local节点的总上游梯度。这可以解释为:因为在前向传播计算函数值时,这个local节点的取值会影响与之相连的之后所有节点,因而在反向传播的过程中,之后所有节点的gradients也都会反过来影响这个local节点的gradient取值。

所以以上的整个反向传播就是在讲怎么计算梯度,以用于找到斜率最大的方向来优化参数矩阵W的取值。

  1. 多维反向传播
    前面函数的输入输出都是单个数值,那如果输入输出变成多维向量或者矩阵呢?
    下面是一个例子:
    这里是W梯度的计算(视频里的这个W梯度矩阵写反了):在这里插入图片描述
    这里是x梯度的计算:在这里插入图片描述
    以上例子的计算过程如下:
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述

每个变量的梯度的维度应当和这个变量的维度保持一致

  1. 模块化设计:前向传播和反向传播API
    可以把back propagation的过程看成一个forward() API和一个backward() API。
    1、在forward()中:我们计算所有操作的结果并保存所有中间变量,以便在之后计算梯度时使用;
    2、在backward()中:需要应用链式法则,从最后一个节点开始反向计算,最终得到函数对于所有输入变量的梯度。

以下为一个代码示例,分为forward()和backward()两个函数接口:
在这里插入图片描述
以乘法门为例,在forward中需要计算乘法结果z,同时还需要保存self.x和self.y,以便在后续计算梯度时使用,最终返回计算结果z。在backward中,输入为dL/dz(用dz表示),期望输出dL/dx和dL/dy(分别用dx和dy表示),应用链式法则可以通过已知的dz和之前保存的self.x和self.y来计算dx和dy。
在这里插入图片描述

  • 1
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值