通过计算图求梯度下降中各偏导的推导
Author: nex3z
2017-08-30
在 Neural Networks and Deep Learning 课程的 Logistic Regression Gradient Descent 一节以逻辑回归为例,介绍了使用计算图(Computation Graph)求梯度下降中各偏导的方法,但没有给出具体的推导过程。
例子中模型为:
\begin{equation}
z = w^Tx + b \tag{1}
\end{equation}
预测为:
\begin{equation}
\hat y = a = \sigma(z) \tag{2}
\end{equation}
其中 $\sigma (z)$ 为 Sigmoid 函数:
\begin{equation}
\sigma(z) = \frac{1}{1 + e^{-z}} \tag{3}
\end{equation}
损失函数为:
\begin{equation}
L(a, y) = -(ylog(a) + (1 – y)log(1 – a)) \tag{4}
\end{equation}
假设只有两个特征 $x_{1}$、$x_{2}$,则:
\begin{equation}
w^T =
\begin{bmatrix}
w_{1} \ w_{2} \tag{5}
\end{bmatrix}
\end{equation}
运算图如图1所示:
图 1
反向计算各偏导的过程如下:
首先求得 $\frac{\partial L}{\partial a}$ 如下:
\begin{equation}
\frac{\partial L}{\partial a} = – \frac{y}{a} + \frac{1 – y}{1 – a} \tag{6}
\end{equation}
然后可以由链式法则求得 $\frac{\partial L}{\partial z}$ 如下:
\begin{equation}
\frac{\partial L}{\partial z} = \frac{\partial L}{\partial a} \cdot \frac{da}{dz} \tag{7}
\end{equation}
其中,$a = \sigma(z)$ 是 Sigmoid 函数,有:
\begin{equation}
\frac{d\sigma(z)}{dz} = \sigma(z)(1 – \sigma(z)) \tag{8}
\end{equation}
将式 (6)、(8) 带入式 (7),得:
\begin{equation}
\frac{\partial L}{\partial z} = (- \frac{y}{a} + \frac{1 – y}{1 – a}) \cdot a(1 – a) \
= -y(1 – a) + a(1 – y) \
= -y + a \tag{9}
\end{equation}
最后求得 $\frac{\partial L}{\partial w_{1}}$、$\frac{\partial L}{\partial w_{2}}$ 和 $\frac{\partial L}{\partial b}$ 如下:
\begin{equation}
\frac{\partial L}{\partial w_{1}} = \frac{\partial L}{\partial z} \cdot \frac{\partial z}{\partial w_{1}} = \frac{\partial L}{\partial z} \cdot x_{1} \tag{10}
\end{equation}
\begin{equation}
\frac{\partial L}{\partial w_{2}} = \frac{\partial L}{\partial z} \cdot \frac{\partial z}{\partial w_{2}} = \frac{\partial L}{\partial z} \cdot x_{2} \tag{11}
\end{equation}
\begin{equation}
\frac{\partial L}{\partial b} = \frac{\partial L}{\partial z} \cdot \frac{\partial z}{\partial b} = \frac{\partial L}{\partial z} \tag{12}
\end{equation}
这里 $\frac{\partial L}{\partial z}$ 不再展开。实际应用中,在由式 (9) 求得 $\frac{\partial L}{\partial z}$ 的值之后,就可以直接带入式 (10)、(11)、(12) 进行计算。