聊聊关于矩阵反向传播的梯度计算

文章详细介绍了前向传播过程,通过一个简单的网络示例展示了如何计算输出。接着,解释了反向传播的基本概念,特别是乘法节点的梯度计算规则。在矩阵反向传播部分,重点讨论了如何计算权重矩阵W和输入特征X的梯度,强调了保持维度一致性和矩阵乘法的运用。最后,总结了计算梯度时的一般方法,即根据传递的信号维度和矩阵乘法规则进行计算。
摘要由CSDN通过智能技术生成

目录

1. 前向传播

2. 反向传播

3. 矩阵反向传播

4. 总结


1. 前向传播

建立如图所示的简单网络

W 是权重矩阵,初始赋值为 2*2 的矩阵

X 是输入特征,初始赋值为 2*1 的矩阵

这样通过矩阵乘法 , Y = WX ,应该得到一个 2*1 的输出矩阵

最后定义loss 为二范数的平方,即 out = 0.22^2 + 0.26^2 = 0.116

 

代码演示为:

torch.norm 是计算矩阵范数的函数

 

2. 反向传播

反向传播的计算根据链式法则,这里不作数学上的推导。在计算图当中,只需要记住以下常用的即可:(注:需要注意的是,传递的值是反向传递过来的,还是正向传播输入的)

  • 加法节点:上游传回来的值直接传递到下游
  • 乘法节点:上游传回来的值,乘上输入信号的翻转值
  • Max 门:上游传来的值,只传递给输入信号的最大者,其余为0
  • ReLU : 如果输入信号大于 0,则上游直接传递;否则,为0

本章,只需要知道乘法节点计算图传递的规则即可

3. 矩阵反向传播

先将结果进行展示:

 

 


 

首先,Y 的梯度很容易计算,Y = [0.22 0.26](转置)

因为这里out 是二范数的平方,因此out = x1^2 + x2^2 ,对Y进行偏导的话,就是2倍的关系

 


 

对W和X进行计算的话,因为这里是乘法节点(W*X),因此这里需要将输入信号反转

例如求取W反向传播,应该是上游传递过来的和X的矩阵乘法 

这里只要记住反向传播的维度要和输入保持一致就行了

也就是说,目标是得到一个2*2大小的W反向传播,已经知道上游传过来的是一个2*1大小的矩阵,而将输入信号翻转的X是一个2*1大小的。那么根据矩阵乘法,只能是上游传递过来的 * X的转置

 

 


同样的道理,对X计算反向传播

目标是得到一个2*1大小的X反向传播,已经知道上游传过来的是一个2*1大小的矩阵,而将输入翻转的W是一个2*2大小的。那么根据矩阵乘法,只能是W的转置 * 上游传递过来的值

 

4. 总结

本章采用的是 W * X = Y 的方式计算。因为资料或者书籍上面有时候矩阵乘法的顺序会不一样,有的还会加上转置等等。其实这些都是为了满足矩阵乘法规则

为了不会混乱,可以这样记忆。可以不用考虑乘法的顺序或者有无转置

A * B = C 的矩阵乘法

计算谁的时候,就用反向传递的值替换掉谁,然后将另一个元素转置。顺序不变

例如:

W_{2*2} * X_{2*1} = Y_{2*1}

计算W梯度的时候,用反向传递的值替换掉W,变成 y * X(这里y是反向传递的值,本章y = 2 Y)

然后另一个元素转置,变成y * X(转置)

计算X梯度的时候,用反向传递的值替换掉X,变成 W * y(这里y是反向传递的值,本章y = 2 Y)

然后另一个元素转置,变成W(转置)* y

或者根据上游传递的信号的维度,和 输入信号翻转的维度进行矩阵计算,也可以得到正确的计算

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Ai 医学图像分割

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值