The affine/linear Transformation Explained, with Gradient Derivations for Fully Connected Layer Backpropagation

Copyright notice: all documents are licensed under the Creative Commons License, all code is licensed under the MIT License. https://blog.csdn.net/oBrightLamp/article/details/84333111

Abstract

The affine layer, also known as the linear layer, is commonly used as the fully connected layer in neural networks.
This article gives two definitions of the affine layer and the corresponding backpropagation gradients.

Related

For the companion code, see the article:

Python and PyTorch Implementations of the affine/linear Transformation and Fully Connected Layer Backpropagation

Series index:
https://blog.csdn.net/oBrightLamp/article/details/85067981

Main text

1. A First Definition of Affine

Consider an input vector x. Let the weight of the affine layer be a k-dimensional vector w and the bias a scalar b. Then:
$$
x = (x_1, x_2, x_3, \cdots, x_k)\\
\;\\
w = (w_1, w_2, w_3, \cdots, w_k)\\
\;\\
\mathrm{affine}(x, w, b) = \sum_{i=1}^{k} x_i w_i + b
$$

Let X be an m-row, k-column matrix and keep the scalar bias b; then one affine transform is:
$$
a^T = \mathrm{affine}(X, w, b) = Xw^T + b\\
\;\\
a^T =
\begin{pmatrix}
x_{11} & x_{12} & x_{13} & \cdots & x_{1k}\\
x_{21} & x_{22} & x_{23} & \cdots & x_{2k}\\
x_{31} & x_{32} & x_{33} & \cdots & x_{3k}\\
\vdots & \vdots & \vdots & \ddots & \vdots\\
x_{m1} & x_{m2} & x_{m3} & \cdots & x_{mk}
\end{pmatrix}
\begin{pmatrix}
w_1\\ w_2\\ w_3\\ \vdots\\ w_k
\end{pmatrix}
+ b\\
\;\\
a = (a_1, a_2, a_3, \cdots, a_m)
$$

More generally, if W is an n-row, k-column matrix and the bias is a vector b, then n affine transforms are:
$$
W_{n\times k} =
\begin{pmatrix}
w_{11} & w_{12} & w_{13} & \cdots & w_{1k}\\
w_{21} & w_{22} & w_{23} & \cdots & w_{2k}\\
w_{31} & w_{32} & w_{33} & \cdots & w_{3k}\\
\vdots & \vdots & \vdots & \ddots & \vdots\\
w_{n1} & w_{n2} & w_{n3} & \cdots & w_{nk}
\end{pmatrix}\\
\;\\
b_{1\times n} = (b_1, b_2, b_3, \cdots, b_n)\\
\;\\
A_{m\times n} = \mathrm{affine}(X, W, b) = X_{m\times k}W_{n\times k}^T + b_{1\times n}
$$

Using summation notation, the elements of A are:
$$
a_{ij} = \sum_{t=1}^{k} x_{it}\, w_{jt} + b_j
$$

Expanding one entry as an example:
$$
a_{23} = \sum_{t=1}^{k} x_{2t}\, w_{3t} + b_3 = x_{21}w_{31} + x_{22}w_{32} + x_{23}w_{33} + \cdots + x_{2k}w_{3k} + b_3
$$
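As a concrete illustration, this first definition can be sketched in a few lines of NumPy. This is a minimal sketch, not the companion article's code; the function name `affine_forward` and the toy numbers are ours:

```python
import numpy as np

def affine_forward(X, W, b):
    """First definition: A = X @ W.T + b, with X (m, k), W (n, k), b (n,)."""
    return X @ W.T + b

# Toy example with m = 2 samples, k = 3 features, n = 2 output units.
X = np.array([[1.0, 2.0, 3.0],
              [4.0, 5.0, 6.0]])
W = np.array([[1.0, 0.0, 1.0],
              [0.0, 1.0, 0.0]])
b = np.array([0.5, -0.5])

A = affine_forward(X, W, b)
# a_11 = 1*1 + 2*0 + 3*1 + 0.5 = 4.5, matching a_ij = sum_t x_it w_jt + b_j
```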

2. Definition of the Gradient

The gradient in three-dimensional XYZ space is defined as:
$$
\nabla e_{(3)} = \frac{\partial e}{\partial x}\,i + \frac{\partial e}{\partial y}\,j + \frac{\partial e}{\partial z}\,k
$$

Here $i, j, k$ are three pairwise perpendicular unit vectors; in other words, $i, j, k$ form an orthonormal set. (Such a set is, in particular, linearly independent.)

Generalizing to a t-dimensional vector space $V$: if the t unit vectors $I_1, I_2, I_3, \cdots, I_t$ are pairwise orthogonal, i.e. form an orthonormal basis of $V$, then the gradient in $V$ can be defined as:
$$
\nabla e_{(V)} = \frac{\partial e}{\partial x_1}I_1 + \frac{\partial e}{\partial x_2}I_2 + \frac{\partial e}{\partial x_3}I_3 + \cdots + \frac{\partial e}{\partial x_t}I_t
$$

The definition of the gradient can be found in any calculus textbook; orthogonality and linear independence are defined in any linear algebra textbook.

3. Gradient Derivations for Backpropagation

Suppose the matrix X passes through the affine layer to produce the matrix A, which is then propagated forward to an error value (a scalar e). We want the gradient of e with respect to X:
$$
A_{m\times n} = X_{m\times k}W_{n\times k}^T + b_{1\times n}\\
\;\\
e = \mathrm{forward}(A)
$$

3.1 Gradient of the loss e with respect to the matrix A

First: when we compute a gradient, what exactly are we looking for?
Answer: the direction of steepest change of the loss e; stepping against the gradient decreases e fastest.

For example, the gradient matrix of e with respect to A:
$$
\frac{de}{dA} =
\begin{pmatrix}
\partial e/\partial a_{11} & \partial e/\partial a_{12} & \partial e/\partial a_{13} & \cdots & \partial e/\partial a_{1n}\\
\partial e/\partial a_{21} & \partial e/\partial a_{22} & \partial e/\partial a_{23} & \cdots & \partial e/\partial a_{2n}\\
\partial e/\partial a_{31} & \partial e/\partial a_{32} & \partial e/\partial a_{33} & \cdots & \partial e/\partial a_{3n}\\
\vdots & \vdots & \vdots & \ddots & \vdots\\
\partial e/\partial a_{m1} & \partial e/\partial a_{m2} & \partial e/\partial a_{m3} & \cdots & \partial e/\partial a_{mn}
\end{pmatrix}
$$

For brevity, write:
$$
\frac{\partial e}{\partial a_{ij}} = a_{ij}'\\
\;\\
\nabla e_{(A)} = \frac{de}{dA} =
\begin{pmatrix}
a_{11}' & a_{12}' & a_{13}' & \cdots & a_{1n}'\\
a_{21}' & a_{22}' & a_{23}' & \cdots & a_{2n}'\\
a_{31}' & a_{32}' & a_{33}' & \cdots & a_{3n}'\\
\vdots & \vdots & \vdots & \ddots & \vdots\\
a_{m1}' & a_{m2}' & a_{m3}' & \cdots & a_{mn}'
\end{pmatrix}
$$

All of the $a_{ij}'$ are known: they have already been computed for us by the forward stages further along the network and handed back during backpropagation.
If every element of A is adjusted in proportion to this matrix (with a negative step), e decreases as fast as possible.
Strictly speaking, the gradient is defined not as a matrix but as a vector:
$$
\nabla e_{(A)} = (a_{11}', a_{12}', \cdots, a_{21}', a_{22}', \cdots, a_{m1}', a_{m2}', \cdots, a_{mn}')
$$

This notation is equivalent to the matrix form above.
Using matrix-derivative notation to compute the gradient gives the directional-derivative components, i.e. the coefficients of the unit vectors; this differs from ordinary matrix calculus.

3.2 Gradients of the elements of A with respect to X

$$
A_{m\times n} = X_{m\times k}W_{n\times k}^T + b_{1\times n}
$$

By the row-times-column definition of matrix multiplication, multiplying $X$ by the $j$-th column of $W^T$ collapses a dimension and yields a column vector, which becomes the $j$-th column of A:
$$
W_j = (w_{j1}, w_{j2}, w_{j3}, \cdots, w_{jk})\\
\;\\
XW_j^T =
\begin{pmatrix}
a_{1j}\\ a_{2j}\\ a_{3j}\\ \vdots\\ a_{mj}
\end{pmatrix}
= A_{:,j}
$$

The $:,j$ notation above, borrowed from numpy, selects all rows of column $j$, giving a column vector.
The gradient of an arbitrary element of A with respect to X:
$$
\frac{da_{ij}}{dX} =
\begin{pmatrix}
\partial a_{ij}/\partial x_{11} & \partial a_{ij}/\partial x_{12} & \partial a_{ij}/\partial x_{13} & \cdots & \partial a_{ij}/\partial x_{1k}\\
\partial a_{ij}/\partial x_{21} & \partial a_{ij}/\partial x_{22} & \partial a_{ij}/\partial x_{23} & \cdots & \partial a_{ij}/\partial x_{2k}\\
\partial a_{ij}/\partial x_{31} & \partial a_{ij}/\partial x_{32} & \partial a_{ij}/\partial x_{33} & \cdots & \partial a_{ij}/\partial x_{3k}\\
\vdots & \vdots & \vdots & \ddots & \vdots\\
\partial a_{ij}/\partial x_{m1} & \partial a_{ij}/\partial x_{m2} & \partial a_{ij}/\partial x_{m3} & \cdots & \partial a_{ij}/\partial x_{mk}
\end{pmatrix}
$$

For brevity, write:
$$
\frac{\partial a_{ij}}{\partial x_{pq}} = x_{ij|pq}'\\
\;\\
\nabla a_{ij(X)} = \frac{da_{ij}}{dX} =
\begin{pmatrix}
x_{ij|11}' & x_{ij|12}' & x_{ij|13}' & \cdots & x_{ij|1k}'\\
x_{ij|21}' & x_{ij|22}' & x_{ij|23}' & \cdots & x_{ij|2k}'\\
x_{ij|31}' & x_{ij|32}' & x_{ij|33}' & \cdots & x_{ij|3k}'\\
\vdots & \vdots & \vdots & \ddots & \vdots\\
x_{ij|m1}' & x_{ij|m2}' & x_{ij|m3}' & \cdots & x_{ij|mk}'
\end{pmatrix}
$$

3.3 Backpropagation with respect to X

From the element-wise definition:
$$
a_{ij} = \sum_{t=1}^{k} x_{it}\, w_{jt} + b_j\\
\;\\
a_{ij} = x_{i1}w_{j1} + x_{i2}w_{j2} + \cdots + x_{iq}w_{jq} + \cdots + x_{ik}w_{jk} + b_j\\
\;\\
x_{ij|pq}' = \frac{\partial a_{ij}}{\partial x_{pq}} =
\begin{cases}
w_{jq}, & p = i\\
0, & p \neq i
\end{cases}
$$

By the chain rule for composite functions from calculus:
$$
\frac{\partial e}{\partial x_{pq}} = \sum_{i=1}^{m}\sum_{j=1}^{n}\frac{\partial e}{\partial a_{ij}}\frac{\partial a_{ij}}{\partial x_{pq}} = \sum_{i=1}^{m}\sum_{j=1}^{n} a_{ij}'\, x_{ij|pq}'
$$

Dropping the zero terms:
$$
\frac{\partial e}{\partial x_{pq}} = \sum_{j=1}^{n} a_{pj}'\, w_{jq}\\
\;\\
\frac{de}{dX} =
\begin{pmatrix}
\sum_{j=1}^{n} a_{1j}'w_{j1} & \sum_{j=1}^{n} a_{1j}'w_{j2} & \sum_{j=1}^{n} a_{1j}'w_{j3} & \cdots & \sum_{j=1}^{n} a_{1j}'w_{jk}\\
\sum_{j=1}^{n} a_{2j}'w_{j1} & \sum_{j=1}^{n} a_{2j}'w_{j2} & \sum_{j=1}^{n} a_{2j}'w_{j3} & \cdots & \sum_{j=1}^{n} a_{2j}'w_{jk}\\
\sum_{j=1}^{n} a_{3j}'w_{j1} & \sum_{j=1}^{n} a_{3j}'w_{j2} & \sum_{j=1}^{n} a_{3j}'w_{j3} & \cdots & \sum_{j=1}^{n} a_{3j}'w_{jk}\\
\vdots & \vdots & \vdots & \ddots & \vdots\\
\sum_{j=1}^{n} a_{mj}'w_{j1} & \sum_{j=1}^{n} a_{mj}'w_{j2} & \sum_{j=1}^{n} a_{mj}'w_{j3} & \cdots & \sum_{j=1}^{n} a_{mj}'w_{jk}
\end{pmatrix}
$$

This result is exactly a matrix product, and factors into:
$$
\frac{de}{dX} =
\begin{pmatrix}
a_{11}' & a_{12}' & a_{13}' & \cdots & a_{1n}'\\
a_{21}' & a_{22}' & a_{23}' & \cdots & a_{2n}'\\
a_{31}' & a_{32}' & a_{33}' & \cdots & a_{3n}'\\
\vdots & \vdots & \vdots & \ddots & \vdots\\
a_{m1}' & a_{m2}' & a_{m3}' & \cdots & a_{mn}'
\end{pmatrix}
\begin{pmatrix}
w_{11} & w_{12} & w_{13} & \cdots & w_{1k}\\
w_{21} & w_{22} & w_{23} & \cdots & w_{2k}\\
w_{31} & w_{32} & w_{33} & \cdots & w_{3k}\\
\vdots & \vdots & \vdots & \ddots & \vdots\\
w_{n1} & w_{n2} & w_{n3} & \cdots & w_{nk}
\end{pmatrix}
$$

So the gradient matrix of the loss e with respect to X is:
$$
\frac{de}{dX} = \nabla e_{(A)}\, W
$$

The matrix $\nabla e_{(A)}$ was already obtained above.
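This formula can be checked numerically. The sketch below is our own illustration, not the companion article's code; it assumes a toy downstream loss $e = \sum A^2$, so that $\nabla e_{(A)} = 2A$ is known in closed form, and compares the analytic gradient with central finite differences:

```python
import numpy as np

rng = np.random.default_rng(0)
m, k, n = 4, 3, 2
X = rng.standard_normal((m, k))
W = rng.standard_normal((n, k))
b = rng.standard_normal(n)

def loss(X):
    A = X @ W.T + b          # first definition of affine
    return (A ** 2).sum()    # toy downstream loss, so de/dA = 2A

grad_A = 2 * (X @ W.T + b)   # upstream gradient de/dA
grad_X = grad_A @ W          # the formula de/dX = grad_A @ W

# Central-difference approximation of de/dX, element by element.
eps = 1e-6
num = np.zeros_like(X)
for i in range(m):
    for j in range(k):
        Xp, Xm = X.copy(), X.copy()
        Xp[i, j] += eps
        Xm[i, j] -= eps
        num[i, j] = (loss(Xp) - loss(Xm)) / (2 * eps)
```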

3.4 Backpropagation with respect to W

Solving as in the previous subsection:
$$
a_{ij} = \sum_{t=1}^{k} x_{it}\, w_{jt} + b_j\\
\;\\
a_{ij} = x_{i1}w_{j1} + x_{i2}w_{j2} + \cdots + x_{iq}w_{jq} + \cdots + x_{ik}w_{jk} + b_j\\
\;\\
w_{ij|pq}' = \frac{\partial a_{ij}}{\partial w_{pq}} =
\begin{cases}
x_{iq}, & p = j\\
0, & p \neq j
\end{cases}\\
\;\\
\frac{\partial e}{\partial w_{pq}} = \sum_{i=1}^{m}\sum_{j=1}^{n}\frac{\partial e}{\partial a_{ij}}\frac{\partial a_{ij}}{\partial w_{pq}} = \sum_{i=1}^{m}\sum_{j=1}^{n} a_{ij}'\, w_{ij|pq}'\\
\;\\
\frac{\partial e}{\partial w_{pq}} = \sum_{i=1}^{m} a_{ip}'\, x_{iq}\\
\;\\
\frac{de}{dW} =
\begin{pmatrix}
\sum_{i=1}^{m} a_{i1}'x_{i1} & \sum_{i=1}^{m} a_{i1}'x_{i2} & \sum_{i=1}^{m} a_{i1}'x_{i3} & \cdots & \sum_{i=1}^{m} a_{i1}'x_{ik}\\
\sum_{i=1}^{m} a_{i2}'x_{i1} & \sum_{i=1}^{m} a_{i2}'x_{i2} & \sum_{i=1}^{m} a_{i2}'x_{i3} & \cdots & \sum_{i=1}^{m} a_{i2}'x_{ik}\\
\sum_{i=1}^{m} a_{i3}'x_{i1} & \sum_{i=1}^{m} a_{i3}'x_{i2} & \sum_{i=1}^{m} a_{i3}'x_{i3} & \cdots & \sum_{i=1}^{m} a_{i3}'x_{ik}\\
\vdots & \vdots & \vdots & \ddots & \vdots\\
\sum_{i=1}^{m} a_{in}'x_{i1} & \sum_{i=1}^{m} a_{in}'x_{i2} & \sum_{i=1}^{m} a_{in}'x_{i3} & \cdots & \sum_{i=1}^{m} a_{in}'x_{ik}
\end{pmatrix}
$$

Again this result is exactly a matrix product, and factors into:
$$
\frac{de}{dW} =
\begin{pmatrix}
a_{11}' & a_{21}' & a_{31}' & \cdots & a_{m1}'\\
a_{12}' & a_{22}' & a_{32}' & \cdots & a_{m2}'\\
a_{13}' & a_{23}' & a_{33}' & \cdots & a_{m3}'\\
\vdots & \vdots & \vdots & \ddots & \vdots\\
a_{1n}' & a_{2n}' & a_{3n}' & \cdots & a_{mn}'
\end{pmatrix}
\begin{pmatrix}
x_{11} & x_{12} & x_{13} & \cdots & x_{1k}\\
x_{21} & x_{22} & x_{23} & \cdots & x_{2k}\\
x_{31} & x_{32} & x_{33} & \cdots & x_{3k}\\
\vdots & \vdots & \vdots & \ddots & \vdots\\
x_{m1} & x_{m2} & x_{m3} & \cdots & x_{mk}
\end{pmatrix}
$$

So the gradient matrix of the loss e with respect to W is:
$$
\frac{de}{dW} = \nabla e_{(A)}^T\, X
$$

The matrix $\nabla e_{(A)}$ was already obtained above.
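The weight gradient admits the same kind of numerical check. As before, this is our own sketch: it assumes the toy loss $e = \sum A^2$ so the upstream gradient $\nabla e_{(A)} = 2A$ is known exactly:

```python
import numpy as np

rng = np.random.default_rng(1)
m, k, n = 4, 3, 2
X = rng.standard_normal((m, k))
W = rng.standard_normal((n, k))
b = rng.standard_normal(n)

def loss(W):
    return ((X @ W.T + b) ** 2).sum()   # toy loss, so de/dA = 2A

grad_A = 2 * (X @ W.T + b)
grad_W = grad_A.T @ X                   # the formula de/dW = grad_A.T @ X

# Central-difference approximation of de/dW, element by element.
eps = 1e-6
num = np.zeros_like(W)
for p in range(n):
    for q in range(k):
        Wp, Wm = W.copy(), W.copy()
        Wp[p, q] += eps
        Wm[p, q] -= eps
        num[p, q] = (loss(Wp) - loss(Wm)) / (2 * eps)
```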

3.5 Gradient of e with respect to b

Solving as before:
$$
a_{ij} = \sum_{t=1}^{k} x_{it}\, w_{jt} + b_j\\
\;\\
b_{ij|q}' = \frac{\partial a_{ij}}{\partial b_q} =
\begin{cases}
1, & q = j\\
0, & q \neq j
\end{cases}\\
\;\\
\frac{\partial e}{\partial b_q} = \sum_{i=1}^{m}\sum_{j=1}^{n}\frac{\partial e}{\partial a_{ij}}\frac{\partial a_{ij}}{\partial b_q} = \sum_{i=1}^{m}\sum_{j=1}^{n} a_{ij}'\, b_{ij|q}'\\
\;\\
\frac{\partial e}{\partial b_q} = \sum_{i=1}^{m} a_{iq}' \cdot 1\\
\;\\
\frac{de}{db} = \left(\sum_{i=1}^{m} a_{i1}',\ \sum_{i=1}^{m} a_{i2}',\ \sum_{i=1}^{m} a_{i3}',\ \cdots,\ \sum_{i=1}^{m} a_{in}'\right)
$$

So the gradient of the loss e with respect to b is:
$$
\frac{de}{db} = \mathrm{sum}(\nabla e_{(A)},\; \mathrm{axis}=0)
$$

The matrix $\nabla e_{(A)}$ was already obtained above. Here $axis=0$, borrowed from numpy notation, means summing over the first dimension (down the rows).
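The bias gradient, a column sum of $\nabla e_{(A)}$, can be verified the same way. Again a sketch of ours under the toy-loss assumption $e = \sum A^2$:

```python
import numpy as np

rng = np.random.default_rng(2)
m, k, n = 4, 3, 2
X = rng.standard_normal((m, k))
W = rng.standard_normal((n, k))
b = rng.standard_normal(n)

def loss(b):
    return ((X @ W.T + b) ** 2).sum()   # toy loss, so de/dA = 2A

grad_A = 2 * (X @ W.T + b)
grad_b = grad_A.sum(axis=0)             # de/db = sum(grad_A, axis=0)

# Central-difference approximation of de/db, element by element.
eps = 1e-6
num = np.zeros_like(b)
for q in range(n):
    bp, bm = b.copy(), b.copy()
    bp[q] += eps
    bm[q] -= eps
    num[q] = (loss(bp) - loss(bm)) / (2 * eps)
```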

4. A Second Definition of Affine

In the definition above, the matrix W is transposed to $W^T$ before entering the affine computation.

Many popular textbooks instead define affine with W used directly:
$$
A_{m\times n} = \mathrm{affine}(X, W, b) = X_{m\times k}W_{k\times n} + b_{1\times n}\\
\;\\
a_{ij} = \sum_{t=1}^{k} x_{it}\, w_{tj} + b_j
$$
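Once the weight is transposed, the two conventions describe the same map. A quick sketch (function names `affine_v1` and `affine_v2` are ours):

```python
import numpy as np

def affine_v1(X, W, b):
    """First definition: W has shape (n, k) and is transposed."""
    return X @ W.T + b

def affine_v2(X, W, b):
    """Second definition: W has shape (k, n) and is used directly."""
    return X @ W + b

rng = np.random.default_rng(3)
X = rng.standard_normal((4, 3))
W2 = rng.standard_normal((3, 2))   # (k, n) weight for the second definition
b = rng.standard_normal(2)

# Same map under W1 = W2.T, so the outputs coincide.
A1 = affine_v1(X, W2.T, b)
A2 = affine_v2(X, W2, b)
```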

4.1 Backpropagation with respect to X

$$
a_{ij} = x_{i1}w_{1j} + x_{i2}w_{2j} + \cdots + x_{iq}w_{qj} + \cdots + x_{ik}w_{kj} + b_j\\
\;\\
x_{ij|pq}' = \frac{\partial a_{ij}}{\partial x_{pq}} =
\begin{cases}
w_{qj}, & p = i\\
0, & p \neq i
\end{cases}
$$

$$
\frac{\partial e}{\partial x_{pq}} = \sum_{i=1}^{m}\sum_{j=1}^{n}\frac{\partial e}{\partial a_{ij}}\frac{\partial a_{ij}}{\partial x_{pq}} = \sum_{i=1}^{m}\sum_{j=1}^{n} a_{ij}'\, x_{ij|pq}'
$$

$$
\frac{\partial e}{\partial x_{pq}} = \sum_{j=1}^{n} a_{pj}'\, w_{qj}\\
\;\\
\frac{de}{dX} =
\begin{pmatrix}
\sum_{j=1}^{n} a_{1j}'w_{1j} & \sum_{j=1}^{n} a_{1j}'w_{2j} & \sum_{j=1}^{n} a_{1j}'w_{3j} & \cdots & \sum_{j=1}^{n} a_{1j}'w_{kj}\\
\sum_{j=1}^{n} a_{2j}'w_{1j} & \sum_{j=1}^{n} a_{2j}'w_{2j} & \sum_{j=1}^{n} a_{2j}'w_{3j} & \cdots & \sum_{j=1}^{n} a_{2j}'w_{kj}\\
\sum_{j=1}^{n} a_{3j}'w_{1j} & \sum_{j=1}^{n} a_{3j}'w_{2j} & \sum_{j=1}^{n} a_{3j}'w_{3j} & \cdots & \sum_{j=1}^{n} a_{3j}'w_{kj}\\
\vdots & \vdots & \vdots & \ddots & \vdots\\
\sum_{j=1}^{n} a_{mj}'w_{1j} & \sum_{j=1}^{n} a_{mj}'w_{2j} & \sum_{j=1}^{n} a_{mj}'w_{3j} & \cdots & \sum_{j=1}^{n} a_{mj}'w_{kj}
\end{pmatrix}
$$

$$
\frac{de}{dX} =
\begin{pmatrix}
a_{11}' & a_{12}' & a_{13}' & \cdots & a_{1n}'\\
a_{21}' & a_{22}' & a_{23}' & \cdots & a_{2n}'\\
a_{31}' & a_{32}' & a_{33}' & \cdots & a_{3n}'\\
\vdots & \vdots & \vdots & \ddots & \vdots\\
a_{m1}' & a_{m2}' & a_{m3}' & \cdots & a_{mn}'
\end{pmatrix}
\begin{pmatrix}
w_{11} & w_{21} & w_{31} & \cdots & w_{k1}\\
w_{12} & w_{22} & w_{32} & \cdots & w_{k2}\\
w_{13} & w_{23} & w_{33} & \cdots & w_{k3}\\
\vdots & \vdots & \vdots & \ddots & \vdots\\
w_{1n} & w_{2n} & w_{3n} & \cdots & w_{kn}
\end{pmatrix}
$$

$$
\frac{de}{dX} = \nabla e_{(A)}\, W^T
$$

4.2 Backpropagation with respect to W

$$
a_{ij} = x_{i1}w_{1j} + x_{i2}w_{2j} + \cdots + x_{ip}w_{pj} + \cdots + x_{ik}w_{kj} + b_j\\
\;\\
w_{ij|pq}' = \frac{\partial a_{ij}}{\partial w_{pq}} =
\begin{cases}
x_{ip}, & q = j\\
0, & q \neq j
\end{cases}\\
\;\\
\frac{\partial e}{\partial w_{pq}} = \sum_{i=1}^{m}\sum_{j=1}^{n}\frac{\partial e}{\partial a_{ij}}\frac{\partial a_{ij}}{\partial w_{pq}} = \sum_{i=1}^{m}\sum_{j=1}^{n} a_{ij}'\, w_{ij|pq}'
$$
$$
\frac{\partial e}{\partial w_{pq}} = \sum_{i=1}^{m} a_{iq}'\, x_{ip}\\
\;\\
\frac{de}{dW} =
\begin{pmatrix}
\sum_{i=1}^{m} a_{i1}'x_{i1} & \sum_{i=1}^{m} a_{i2}'x_{i1} & \sum_{i=1}^{m} a_{i3}'x_{i1} & \cdots & \sum_{i=1}^{m} a_{in}'x_{i1}\\
\sum_{i=1}^{m} a_{i1}'x_{i2} & \sum_{i=1}^{m} a_{i2}'x_{i2} & \sum_{i=1}^{m} a_{i3}'x_{i2} & \cdots & \sum_{i=1}^{m} a_{in}'x_{i2}\\
\sum_{i=1}^{m} a_{i1}'x_{i3} & \sum_{i=1}^{m} a_{i2}'x_{i3} & \sum_{i=1}^{m} a_{i3}'x_{i3} & \cdots & \sum_{i=1}^{m} a_{in}'x_{i3}\\
\vdots & \vdots & \vdots & \ddots & \vdots\\
\sum_{i=1}^{m} a_{i1}'x_{ik} & \sum_{i=1}^{m} a_{i2}'x_{ik} & \sum_{i=1}^{m} a_{i3}'x_{ik} & \cdots & \sum_{i=1}^{m} a_{in}'x_{ik}
\end{pmatrix}
$$

$$
\frac{de}{dW} =
\begin{pmatrix}
x_{11} & x_{21} & x_{31} & \cdots & x_{m1}\\
x_{12} & x_{22} & x_{32} & \cdots & x_{m2}\\
x_{13} & x_{23} & x_{33} & \cdots & x_{m3}\\
\vdots & \vdots & \vdots & \ddots & \vdots\\
x_{1k} & x_{2k} & x_{3k} & \cdots & x_{mk}
\end{pmatrix}
\begin{pmatrix}
a_{11}' & a_{12}' & a_{13}' & \cdots & a_{1n}'\\
a_{21}' & a_{22}' & a_{23}' & \cdots & a_{2n}'\\
a_{31}' & a_{32}' & a_{33}' & \cdots & a_{3n}'\\
\vdots & \vdots & \vdots & \ddots & \vdots\\
a_{m1}' & a_{m2}' & a_{m3}' & \cdots & a_{mn}'
\end{pmatrix}
$$

$$
\frac{de}{dW} = X^T\, \nabla e_{(A)}
$$

4.3 Gradient of e with respect to b

As above:
$$
\frac{de}{db} = \mathrm{sum}(\nabla e_{(A)},\; \mathrm{axis}=0)
$$
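The gradients under the second definition are consistent with those of the first: with $W_1 = W_2^T$ the two parameterizations are the same map, so the gradients must agree (transposed, in the case of W). A sketch of ours verifying this with an arbitrary upstream gradient:

```python
import numpy as np

rng = np.random.default_rng(4)
m, k, n = 4, 3, 2
X = rng.standard_normal((m, k))
W2 = rng.standard_normal((k, n))        # second-definition weight: A = X @ W2 + b
b = rng.standard_normal(n)
grad_A = rng.standard_normal((m, n))    # arbitrary upstream gradient de/dA

# Backward formulas from section 4:
grad_X = grad_A @ W2.T                  # de/dX = grad_A @ W.T
grad_W2 = X.T @ grad_A                  # de/dW = X.T @ grad_A
grad_b = grad_A.sum(axis=0)             # de/db = sum(grad_A, axis=0)

# First-definition weight describing the same map.
W1 = W2.T

# Section 3.3 gives de/dX = grad_A @ W1; section 3.4 gives de/dW1 = grad_A.T @ X,
# which is the transpose of grad_W2.
grad_X_v1 = grad_A @ W1
grad_W1 = grad_A.T @ X
```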
End of article.