目录
1. 前向传播
建立如图所示的简单网络
W 是权重矩阵,初始赋值为 2*2 的矩阵
X 是输入特征,初始赋值为 2*1 的矩阵
这样通过矩阵乘法 , Y = WX ,应该得到一个 2*1 的输出矩阵
最后定义loss 为二范数的平方,即 out = 0.22^2 + 0.26^2 = 0.116
代码演示为:
torch.norm 是计算矩阵范数的函数
2. 反向传播
反向传播的计算根据链式法则,这里不作数学上的推导。在计算图当中,只需要记住以下常用的即可:(注:需要注意的是,传递的值是反向传递过来的,还是正向传播输入的)
- 加法节点:上游传回来的值直接传递到下游
- 乘法节点:上游传回来的值,乘上输入信号的翻转值
- Max 门:上游传来的值,只传递给输入信号的最大者,其余为0
- ReLU : 如果输入信号大于 0,则上游直接传递;否则,为0
本章,只需要知道乘法节点计算图传递的规则即可
3. 矩阵反向传播
先将结果进行展示:
首先,Y 的梯度很容易计算,Y = [0.22 0.26](转置)
因为这里out 是二范数的平方,因此out = x1^2 + x2^2 ,对Y进行偏导的话,就是2倍的关系
对W和X进行计算的话,因为这里是乘法节点(W*X),因此这里需要将输入信号反转
例如求取W反向传播,应该是上游传递过来的和X的矩阵乘法
这里只要记住反向传播的维度要和输入保持一致就行了
也就是说,目标是得到一个2*2大小的W反向传播,已经知道上游传过来的是一个2*1大小的矩阵,而将输入信号翻转的X是一个2*1大小的。那么根据矩阵乘法,只能是上游传递过来的 * X的转置
同样的道理,对X计算反向传播
目标是得到一个2*1大小的X反向传播,已经知道上游传过来的是一个2*1大小的矩阵,而将输入翻转的W是一个2*2大小的。那么根据矩阵乘法,只能是W的转置 * 上游传递过来的值
4. 总结
本章采用的是 W * X = Y 的方式计算。因为资料或者书籍上面有时候矩阵乘法的顺序会不一样,有的还会加上转置等等。其实这些都是为了满足矩阵乘法规则
为了不会混乱,可以这样记忆。可以不用考虑乘法的顺序或者有无转置
A * B = C 的矩阵乘法
计算谁的时候,就用反向传递的值替换掉谁,然后将另一个元素转置。顺序不变
例如:
计算W梯度的时候,用反向传递的值替换掉W,变成 y * X(这里y是反向传递的值,本章y = 2 Y)
然后另一个元素转置,变成y * X(转置)
计算X梯度的时候,用反向传递的值替换掉X,变成 W * y(这里y是反向传递的值,本章y = 2 Y)
然后另一个元素转置,变成W(转置)* y
或者根据上游传递的信号的维度,和 输入信号翻转的维度进行矩阵计算,也可以得到正确的计算