Multivariate Chain Rule
We have previously seen an example of how to use the chain rule to calculate the total derivative of a function $f$ with respect to a variable $t$ through the intermediate variables $x$, $y$ and $z$, as follows.
$$
\begin{aligned}
f&=\sin(x)e^{yz^2}\\
x&=t-1\\
y&=t^2\\
z&=\frac{1}{t}\\
\frac{df}{dt}&=\frac{\partial f}{\partial x}\cdot\frac{dx}{dt}+\frac{\partial f}{\partial y}\cdot\frac{dy}{dt}+\frac{\partial f}{\partial z}\cdot\frac{dz}{dt}\\
&=\cos(x)e^{yz^2}\cdot(1)+z^2\sin(x)e^{yz^2}\cdot(2t)+2yz\sin(x)e^{yz^2}\cdot\left(-\frac{1}{t^2}\right)\\
&=e^{yz^2}\left[\cos(x)+2tz^2\sin(x)-\frac{2yz}{t^2}\sin(x)\right]
\end{aligned}
$$
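As a quick numerical sanity check (my own addition, not part of the original example), the snippet below compares the chain-rule expression above with a central finite-difference approximation at an arbitrary point such as $t=2$:

```python
import math

def f(x, y, z):
    return math.sin(x) * math.exp(y * z**2)

def chain_rule_dfdt(t):
    # Intermediate variables and the chain-rule terms from the worked example
    x, y, z = t - 1, t**2, 1 / t
    dfdx = math.cos(x) * math.exp(y * z**2)
    dfdy = z**2 * math.sin(x) * math.exp(y * z**2)
    dfdz = 2 * y * z * math.sin(x) * math.exp(y * z**2)
    return dfdx * 1 + dfdy * (2 * t) + dfdz * (-1 / t**2)

def finite_difference_dfdt(t, h=1e-6):
    # Substitute x(t), y(t), z(t) directly and difference in t
    g = lambda s: f(s - 1, s**2, 1 / s)
    return (g(t + h) - g(t - h)) / (2 * h)

print(chain_rule_dfdt(2.0))         # analytic chain-rule value
print(finite_difference_dfdt(2.0))  # should agree to several decimal places
```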
Now let’s formalize the definition of the chain rule in multivariate calculus. We start with a function $f$ with $n$ input variables $x_1, x_2, \cdots, x_n$.
$$f(x_1, x_2,\cdots,x_n)=f(\mathbf{x})$$
To simplify our expression, we will denote the function by $f(\mathbf{x})$, with the bold letter $\mathbf{x}$ representing the vector of $n$ input variables.
If each of the input variables in $\mathbf{x}$ is itself a function of another variable $t$, our function $f$ can be re-expressed as $f(\mathbf{x}(t))$.
$$f(\mathbf{x}(t))=f(x_1(t),x_2(t),\cdots,x_n(t))$$
To compute the derivative of $f$ with respect to $t$, we can write the result as a dot product of two $n$-dimensional vectors.
$$
\begin{aligned}
\frac{df}{dt}&=\frac{\partial f}{\partial x_1}\cdot\frac{dx_1}{dt}+\frac{\partial f}{\partial x_2}\cdot\frac{dx_2}{dt}+\cdots+\frac{\partial f}{\partial x_n}\cdot\frac{dx_n}{dt}\\
&=\begin{pmatrix}\frac{\partial f}{\partial x_1}&\frac{\partial f}{\partial x_2}&\cdots&\frac{\partial f}{\partial x_n}\end{pmatrix}\cdot\begin{pmatrix}\frac{dx_1}{dt}\\\frac{dx_2}{dt}\\\vdots\\\frac{dx_n}{dt}\end{pmatrix}
\end{aligned}
$$
The first vector is just the collection of partial derivatives of $f$ with respect to all its input variables. Does that look familiar? Yes, that is our Jacobian (here a row vector). So our differentiation equation can be simplified to
$$\frac{df}{dt}=(J_f)\cdot\frac{d\mathbf{x}}{dt}$$
We use $\frac{d\mathbf{x}}{dt}$ to compactly represent the vector of derivatives $\frac{dx_1}{dt}, \frac{dx_2}{dt}, \cdots, \frac{dx_n}{dt}$.
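To make the dot-product form concrete, here is a minimal NumPy sketch with a made-up function $f(\mathbf{x}) = x_1 x_2 + x_3^2$ and a made-up $\mathbf{x}(t)$ (neither comes from the text above); it evaluates $\frac{df}{dt}=(J_f)\cdot\frac{d\mathbf{x}}{dt}$ numerically:

```python
import numpy as np

# Hypothetical example: f(x) = x1*x2 + x3**2, with x(t) = (t, t**2, sin(t))
def jacobian_f(x):
    x1, x2, x3 = x
    return np.array([x2, x1, 2 * x3])         # partials of f w.r.t. x1, x2, x3

def dx_dt(t):
    return np.array([1.0, 2 * t, np.cos(t)])  # derivatives of x1, x2, x3 w.r.t. t

t = 1.5
x = np.array([t, t**2, np.sin(t)])
df_dt = jacobian_f(x) @ dx_dt(t)              # the dot product J_f . dx/dt
print(df_dt)
```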
This is the multivariate chain rule for one intermediate layer. What if we have more than one intermediate layer? Let’s add one more layer between $\mathbf{x}$ and $t$ and denote the function as $f(\mathbf{x}(\mathbf{u}(t)))$. Now $f$ is a function of a vector of $n$ variables $\mathbf{x}$. Each variable of $\mathbf{x}$ is a function of a vector of $m$ variables $\mathbf{u}$. Finally, each variable of $\mathbf{u}$ is a function of the variable $t$. To differentiate $f$ with respect to $t$, we need
$$
\begin{aligned}
\frac{df}{dt}&=\begin{pmatrix}\frac{\partial f}{\partial x_1}&\frac{\partial f}{\partial x_2}&\cdots&\frac{\partial f}{\partial x_n}\end{pmatrix}\cdot\begin{pmatrix}\frac{\partial x_1}{\partial u_1}&\frac{\partial x_1}{\partial u_2}&\cdots&\frac{\partial x_1}{\partial u_m}\\\frac{\partial x_2}{\partial u_1}&\frac{\partial x_2}{\partial u_2}&\cdots&\frac{\partial x_2}{\partial u_m}\\\vdots&\vdots&\ddots&\vdots\\\frac{\partial x_n}{\partial u_1}&\frac{\partial x_n}{\partial u_2}&\cdots&\frac{\partial x_n}{\partial u_m}\end{pmatrix}\cdot\begin{pmatrix}\frac{du_1}{dt}\\\frac{du_2}{dt}\\\vdots\\\frac{du_m}{dt}\end{pmatrix}\\
&=(J_f)\cdot(J_\mathbf{x})\cdot\frac{d\mathbf{u}}{dt}
\end{aligned}
$$
Again, we have simplified our expression using two Jacobian terms. Note that the middle Jacobian ($J_\mathbf{x}$) is a matrix instead of a vector.
With that, we can generalize the multivariate chain rule for $n$ intermediate layers as
$$\frac{df}{dt}=(J_f)\cdot(J_{\mathbf{x}^{(1)}})\cdot(J_{\mathbf{x}^{(2)}})\cdots(J_{\mathbf{x}^{(n-1)}})\cdot\frac{d\mathbf{x}^{(n)}}{dt}$$
Here the superscript $(i)$ denotes the $i^{th}$ intermediate layer.
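In code, the generalized rule is just a chain of matrix products. The sketch below uses random Jacobians with hypothetical shapes (my own illustration) simply to show how the dimensions line up:

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical sizes: f depends on 4 variables in x^(1), layer 2 has 3 variables,
# layer 3 has 2 variables, and x^(3) depends on the single scalar t.
J_f    = rng.standard_normal((1, 4))   # 1 x 4 Jacobian (row vector) of f w.r.t. x^(1)
J_x1   = rng.standard_normal((4, 3))   # 4 x 3 Jacobian of x^(1) w.r.t. x^(2)
J_x2   = rng.standard_normal((3, 2))   # 3 x 2 Jacobian of x^(2) w.r.t. x^(3)
dx3_dt = rng.standard_normal((2, 1))   # 2 x 1 vector of derivatives of x^(3) w.r.t. t

df_dt = J_f @ J_x1 @ J_x2 @ dx3_dt     # 1 x 1 result: the scalar df/dt
print(df_dt.item())
```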
This is a powerful rule. As long as we can clearly define the variable expressions from one layer to the next, we can always differentiate the function $f$ with respect to the input variables at any layer. We will see this rule used a lot in practice.
Neural Networks
The multivariate chain rule is an essential component of neural networks. Each layer of neurons in the network is controlled by a set of parameters. These parameters are optimized by differentiating the cost function with respect to each of them. Moreover, the differentiation result is propagated from one layer to the next, just like how our chain rule works.
Let’s use a simple neural network setup to explain the notation we are going to use.
We have circles representing nodes and lines representing the connections between nodes in a network. The above network can be written mathematically as
$$
\begin{aligned}
a_0^{(1)}&=\sigma(w_{00}a_0^{(0)}+w_{10}a_1^{(0)}+w_{20}a_2^{(0)}+b_0)\\
a_1^{(1)}&=\sigma(w_{01}a_0^{(0)}+w_{11}a_1^{(0)}+w_{21}a_2^{(0)}+b_1)
\end{aligned}
$$
Here we use $a$ to represent the value of a node (it can be an input or an output node). The superscript of $a$ indicates which layer the node is in, and the subscript indicates which node of that layer it is. $w$ and $b$ are the weight and bias, respectively, applied to the input nodes. Finally, $\sigma$ is an activation function that transforms the weighted sum of the input nodes into the value of the output node.
We can further simplify our network expression with vector notation.
$$\mathbf{a}^{(1)}=\sigma(\mathbf{w}\cdot\mathbf{a}^{(0)}+\mathbf{b})$$
Now $\mathbf{a}^{(0)}$ is the vector of input nodes at layer 0. $\mathbf{a}^{(1)}$ is the vector of output nodes at layer 1. $\mathbf{w}$ is the weight matrix relating each input node to each output node. $\mathbf{b}$ is the bias vector for the output nodes.
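A minimal forward pass for this single layer, matching the three-input, two-output example above, might look as follows (a sketch assuming a sigmoid activation; the text does not fix a specific $\sigma$, and the numbers are arbitrary):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

a0 = np.array([0.5, -1.2, 0.3])      # three input nodes at layer 0
w  = np.array([[0.1, -0.4, 0.7],     # 2 x 3 weight matrix: row j holds the weights
               [0.3,  0.8, -0.2]])   # feeding output node j
b  = np.array([0.05, -0.1])          # one bias per output node

a1 = sigmoid(w @ a0 + b)             # a^(1) = sigma(w . a^(0) + b)
print(a1)                            # two output nodes at layer 1
```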
If we extend the same idea to any layer of a neural network, we get the following relation between input layer $L-1$ and output layer $L$.
$$\mathbf{a}^{(L)}=\sigma(\mathbf{w}^{(L)}\cdot\mathbf{a}^{(L-1)}+\mathbf{b}^{(L)})$$
You can imagine stacking up these intermediate layers of nodes between the input and output layers to form a complex neural network. Each intermediate layer is simultaneously the output of the previous layer and the input to the next layer. In solving such a neural network, we are interested in finding the right parameters $\mathbf{w}$ and $\mathbf{b}$ at each layer that successfully map the model input to the final output.
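Stacking layers is then just repeating the same relation layer by layer, for example (a sketch with made-up layer sizes and random parameters, only to show the repeated application of the formula above):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

rng = np.random.default_rng(1)
layer_sizes = [3, 4, 4, 2]               # input layer, two hidden layers, output layer

# One (w, b) pair per layer transition, drawn randomly for illustration
params = [(rng.standard_normal((n_out, n_in)), rng.standard_normal(n_out))
          for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:])]

a = rng.standard_normal(layer_sizes[0])  # a^(0): the model input
for w, b in params:
    a = sigmoid(w @ a + b)               # a^(L) = sigma(w^(L) . a^(L-1) + b^(L))
print(a)                                 # the final network output
```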
In order to do that, we need to know how much our current network output deviates from the actual output. This deviation is measured by a cost function.
$$C=\sum_{i}(a_i^{(L)}-y_i)^2$$
We take the squared difference between the model output and the actual output and sum over all output nodes. The goal of our neural network becomes solving for the parameters $\mathbf{w}$ and $\mathbf{b}$ at each layer such that $C$ is minimized. This requires us to differentiate $C$ with respect to every variable in $\mathbf{w}$ and $\mathbf{b}$. Since $\mathbf{w}$ and $\mathbf{b}$ are connected to $C$ via layers of nodes, we can apply the multivariate chain rule here.
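In code, the cost for a single training example is simply the following (the output and target values here are hypothetical):

```python
import numpy as np

a_L = np.array([0.72, 0.31])    # network output at the final layer (made-up values)
y   = np.array([1.0, 0.0])      # the actual (target) output

C = np.sum((a_L - y) ** 2)      # C = sum_i (a_i^(L) - y_i)^2
print(C)
```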
Let’s start simple with a one-layer network with a single input node and a single output node.
$$a^{(1)}=\sigma(wa^{(0)}+b)$$
We will add one auxiliary variable $z$ to represent the output value before activation. The network model is then connected to the cost function as:
$$
\begin{aligned}
z^{(1)}&=wa^{(0)}+b\\
a^{(1)}&=\sigma(z^{(1)})\\
C&=(a^{(1)}-y)^2
\end{aligned}
$$
We can then find the derivatives of $C$ with respect to $w$ and $b$ as
$$
\begin{aligned}
\frac{dC}{dw}&=\frac{\partial C}{\partial a^{(1)}}\cdot\frac{\partial a^{(1)}}{\partial z^{(1)}}\cdot\frac{dz^{(1)}}{dw}=2(a^{(1)}-y)\cdot\sigma'(z^{(1)})\cdot(a^{(0)})\\
\frac{dC}{db}&=\frac{\partial C}{\partial a^{(1)}}\cdot\frac{\partial a^{(1)}}{\partial z^{(1)}}\cdot\frac{dz^{(1)}}{db}=2(a^{(1)}-y)\cdot\sigma'(z^{(1)})\cdot(1)
\end{aligned}
$$

Here $\sigma'$ denotes the derivative of the activation function evaluated at $z^{(1)}$.
Each of these derivatives evaluates to a scalar, which tells us how far and in which direction to move $w$ and $b$ from their current values. We then recalculate the derivatives and update the parameters $w$ and $b$ iteratively until their derivatives are (close to) zero. That is where the cost is minimized.
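Putting the two derivative expressions above to work, here is a sketch of that iterative update using plain gradient descent with a sigmoid activation (the activation, the learning rate, and the training example are my own assumptions; the text does not prescribe a specific update rule):

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def sigmoid_prime(z):                    # sigma'(z) for the sigmoid activation
    s = sigmoid(z)
    return s * (1.0 - s)

a0, y = 1.5, 0.8                         # one training example: input and target (made up)
w, b = 0.2, 0.0                          # initial parameter guesses
lr = 0.5                                 # learning rate (assumed)

for step in range(1000):
    z1 = w * a0 + b                      # z^(1) = w a^(0) + b
    a1 = sigmoid(z1)                     # a^(1) = sigma(z^(1))
    dC_dw = 2 * (a1 - y) * sigmoid_prime(z1) * a0
    dC_db = 2 * (a1 - y) * sigmoid_prime(z1) * 1
    w -= lr * dC_dw                      # step against the gradient
    b -= lr * dC_db

print(w, b, (sigmoid(w * a0 + b) - y) ** 2)   # the cost should now be close to zero
```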
When we add more nodes and more layers to the network, it gets more complicated to differentiate $C$ with respect to each entry of $\mathbf{w}$ and $\mathbf{b}$. Nonetheless, we still apply the multivariate chain rule to iteratively differentiate and evaluate each layer of the network until all the parameters are optimized. This is a very laborious process, so I will not demonstrate the details here. Fortunately, most machine learning frameworks come equipped with automatic differentiation, so you do not need to spell out the expressions yourself. What really matters is the principle behind all the calculation. It can help you better design your network structure and/or choose the parameters.
We have completed our study of the multivariate chain rule. It is a very powerful tool when we are dealing with complex optimization problems such as neural networks. We can traverse from one set of variables to the next and calculate their influence on the final cost function with chained partial derivatives. I hope you have enjoyed reading so far. We will dive into more interesting topics in multivariate calculus in the coming articles.
(Inspired by the Mathematics for Machine Learning lecture series from Imperial College London)