# The Affine/Linear Transformation in Detail and the Gradient Derivation for Backpropagation Through a Fully Connected Layer

## Abstract

The affine layer, also called the linear transformation layer, is commonly used as the fully connected layer in neural network architectures.

## Related

A comparison of Python and PyTorch implementations of the affine/linear transformation function and of backpropagation through a fully connected layer

## 1. One Definition of Affine

$x = (x_1,x_2,x_3,\cdots,x_k)\\ \;\\ w = (w_1, w_2,w_3,\cdots,w_k)\\ \;\\ affine(x,w,b) = xw^T+b = \sum_{i=1}^{k} x_iw_i+b$

$a^T=affine(X,w,b) = Xw^T + b\\\;\\ a^T= \begin{pmatrix} x_{11}&x_{12} &x_{13}&\cdots&x_{1k}\\ x_{21}&x_{22}&x_{23}&\cdots&x_{2k}\\ x_{31}&x_{32}&x_{33}&\cdots&x_{3k}\\ \vdots&\vdots&\vdots&\ddots&\vdots\\ x_{m1}&x_{m2}&x_{m3}&\cdots&x_{mk} \end{pmatrix} \begin{pmatrix} w_1\\ w_2\\ w_3\\ \vdots\\ w_k \end{pmatrix} +b\\ \;\\ a= (a_1,a_2,a_3,\cdots,a_m)$

$W_{n\times k} =\begin{pmatrix} w_{11}&w_{12} &w_{13}&\cdots&w_{1k}\\ w_{21}&w_{22}&w_{23}&\cdots&w_{2k}\\ w_{31}&w_{32}&w_{33}&\cdots&w_{3k}\\ \vdots&\vdots&\vdots&\ddots&\vdots\\ w_{n1}&w_{n2}&w_{n3}&\cdots&w_{nk} \end{pmatrix}\\ \;\\ b_{1 \times n} = (b_1,b_2,b_3,\cdots,b_n)\\\;\\ A_{m\times n} = affine(X,W,b) = X_{m\times k}W^T_{n\times k} + b_{1 \times n}$

$a_{ij} =\sum_{t=1}^{k} x_{it} \cdot w_{jt} + b_j$

$a_{23} =\sum_{t=1}^{k} x_{2t} \cdot w_{3t} + b_3= x_{21}w_{31}+x_{22}w_{32}+x_{23}w_{33}+\cdots+x_{2k}w_{3k}+ b_3$
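
The matrix form and the element-wise form above can be compared numerically. The following is a minimal NumPy sketch, assuming this first definition with $W$ stored as an $n \times k$ matrix; the function name `affine`, the sizes and the random inputs are illustrative choices, not taken from the article.

```python
import numpy as np

def affine(X, W, b):
    """Definition 1: X is (m, k), W is (n, k), b is (n,). Returns A of shape (m, n)."""
    return X @ W.T + b  # b is broadcast across the m rows

rng = np.random.default_rng(0)
m, k, n = 4, 5, 3                 # illustrative sizes (assumption)
X = rng.normal(size=(m, k))
W = rng.normal(size=(n, k))
b = rng.normal(size=n)

A = affine(X, W, b)

# Element-wise check of a_23 (1-based), which is A[1, 2] with 0-based indexing.
a_23 = sum(X[1, t] * W[2, t] for t in range(k)) + b[2]

print(A.shape)                    # (4, 3)
print(np.isclose(A[1, 2], a_23))  # True
```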

## 2. Definition of the Gradient

$\nabla e_{(3)} = \frac{\partial e}{\partial x}i+\frac{\partial e}{\partial y}j+\frac{\partial e}{\partial z}k$

$\nabla e_{(V)} = \frac{\partial e}{\partial x_1}I_1+\frac{\partial e}{\partial x_2}I_2+\frac{\partial e}{\partial x_3}I_3+\cdots+\frac{\partial e}{\partial x_t}I_t$
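
Each component of the gradient is one partial derivative, so it can be estimated one coordinate at a time. The sketch below uses central differences; the helper `numerical_gradient` and the sample function `e` are hypothetical and exist only to illustrate the definition.

```python
import numpy as np

def numerical_gradient(f, x, h=1e-6):
    """Central-difference estimate of the gradient of a scalar function f at x."""
    x = np.asarray(x, dtype=float)
    grad = np.zeros_like(x)
    for idx in np.ndindex(x.shape):
        step = np.zeros_like(x)
        step[idx] = h                      # perturb one coordinate only
        grad[idx] = (f(x + step) - f(x - step)) / (2 * h)
    return grad

def e(v):
    # Hypothetical example function e(x, y, z) = x^2 + 3y + xz
    x, y, z = v
    return x**2 + 3*y + x*z

point = np.array([1.0, 2.0, 3.0])
# Analytic gradient: (2x + z, 3, x)
analytic = np.array([2*point[0] + point[2], 3.0, point[0]])
estimate = numerical_gradient(e, point)
print(np.allclose(estimate, analytic, atol=1e-5))
```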

## 3. Gradient Derivation in Backpropagation

$A_{m \times n} = X_{m\times k}{W_{n\times k}}^T + b_{1 \times n}\\ \;\\ e=forward(A)$

### 3.1 Gradient of the Loss e with Respect to the Matrix A

$\frac{de}{dA} = \begin{pmatrix} \partial e/ \partial a_{11}&\partial e/ \partial a_{12}&\partial e/ \partial a_{13}&\cdots& \partial e/ \partial a_{1n}\\ \partial e/ \partial a_{21}&\partial e/ \partial a_{22}&\partial e/ \partial a_{23}&\cdots& \partial e/ \partial a_{2n}\\ \partial e/ \partial a_{31}&\partial e/ \partial a_{32}&\partial e/ \partial a_{33}&\cdots& \partial e/ \partial a_{3n}\\ \vdots&\vdots&\vdots&\ddots&\vdots\\ \partial e/ \partial a_{m1}&\partial e/ \partial a_{m2}&\partial e/ \partial a_{m3}&\cdots& \partial e/ \partial a_{mn}\\ \end{pmatrix}$

$\frac{\partial e}{\partial a_{ij}} = a_{ij}'\\ \;\\ \nabla e_{(A)}= \frac{de}{dA} = \begin{pmatrix} a_{11}'& a_{12}'& a_{13}'&\cdots& a_{1n}'\\ a_{21}'& a_{22}'& a_{23}'&\cdots& a_{2n}'\\ a_{31}'& a_{32}'& a_{33}'&\cdots& a_{3n}'\\ \vdots&\vdots&\vdots&\ddots&\vdots\\ a_{m1}'& a_{m2}'& a_{m3}'&\cdots& a_{mn}' \end{pmatrix}$

$\nabla e_{(A)}= (a_{11}', a_{12}',\cdots, a_{21}', a_{22}',\cdots,a_{m1}', a_{m2}',\cdots, a_{mn}')$
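
The article leaves $forward(A)$ unspecified. To make $\nabla e_{(A)}$ concrete, the sketch below assumes a stand-in loss $e = \frac{1}{2}\sum_{i,j} a_{ij}^2$, for which $a_{ij}' = a_{ij}$. The names `loss` and `loss_grad` are hypothetical.

```python
import numpy as np

def loss(A):
    # Assumed stand-in for forward(A): e = 0.5 * sum of a_ij^2
    return 0.5 * np.sum(A ** 2)

def loss_grad(A):
    # For this particular loss, de/da_ij = a_ij
    return A.copy()

A = np.array([[1.0, -2.0, 0.5],
              [3.0,  0.0, -1.0]])   # m = 2, n = 3

dA = loss_grad(A)                   # same shape as A
flat = dA.ravel()                   # the flattened vector form of the gradient

# Finite-difference check of the single entry a_12' (0-based index [0, 1]).
h = 1e-6
A_plus, A_minus = A.copy(), A.copy()
A_plus[0, 1] += h
A_minus[0, 1] -= h
fd = (loss(A_plus) - loss(A_minus)) / (2 * h)

print(dA.shape, flat.shape)         # (2, 3) (6,)
```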

### 3.2 Gradient of an Element of A with Respect to X

$A_{m \times n} = X_{m\times k}{W_{n\times k}}^T + b_{1 \times n}\\$

$W_j=(w_{j1},w_{j2},w_{j3},\cdots,w_{jk})\\ \;\\ XW_j^T= \begin{pmatrix} a_{1j}\\ a_{2j}\\ a_{3j}\\ \vdots\\ a_{mj} \end{pmatrix}=A_{:,j}$

$\frac{d a_{ij}}{dX} = \begin{pmatrix} \partial a_{ij}/ \partial x_{11}&\partial a_{ij}/ \partial x_{12}&\partial a_{ij}/ \partial x_{13}&\cdots& \partial a_{ij}/ \partial x_{1k}\\ \partial a_{ij}/ \partial x_{21}&\partial a_{ij}/ \partial x_{22}&\partial a_{ij}/ \partial x_{23}&\cdots& \partial a_{ij}/ \partial x_{2k}\\ \partial a_{ij}/ \partial x_{31}&\partial a_{ij}/ \partial x_{32}&\partial a_{ij}/ \partial x_{33}&\cdots& \partial a_{ij}/\partial x_{3k}\\ \vdots&\vdots&\vdots&\ddots&\vdots\\ \partial a_{ij}/ \partial x_{m1}&\partial a_{ij}/ \partial x_{m2}&\partial a_{ij}/ \partial x_{m3}&\cdots& \partial a_{ij}/ \partial x_{mk}\\ \end{pmatrix}$

$\frac{\partial a_{ij}}{\partial x_{pq}} = x_{ij|pq}'\\ \;\\ \nabla {a_{ij}}_{(X)}=\frac{d a_{ij}}{dX} = \begin{pmatrix} x_{ij|11}'&x_{ij|12}'&x_{ij|13}'&\cdots&x_{ij|1k}'\\ x_{ij|21}'&x_{ij|22}'&x_{ij|23}'&\cdots&x_{ij|2k}'\\ x_{ij|31}'&x_{ij|32}'&x_{ij|33}'&\cdots&x_{ij|3k}'\\ \vdots&\vdots&\vdots&\ddots&\vdots\\ x_{ij|m1}'&x_{ij|m2}'&x_{ij|m3}'&\cdots&x_{ij|mk}'\\ \end{pmatrix}$
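
For a fixed pair $(i, j)$, the matrix $\frac{d a_{ij}}{dX}$ can be estimated entry by entry. The sketch below does this with central differences under the first definition ($W$ of shape $n \times k$); the sizes and random inputs are illustrative assumptions. Numerically, only row $i$ is nonzero, and that row equals $W_j$.

```python
import numpy as np

rng = np.random.default_rng(0)
m, k, n = 3, 4, 2                   # illustrative sizes (assumption)
X = rng.normal(size=(m, k))
W = rng.normal(size=(n, k))
b = rng.normal(size=n)

def a_ij(X_, i, j):
    """Single output element a_ij = sum_t x_it * w_jt + b_j."""
    return X_[i] @ W[j] + b[j]

i, j, h = 1, 0, 1e-6
d_aij_dX = np.zeros((m, k))
for p in range(m):
    for q in range(k):
        step = np.zeros((m, k))
        step[p, q] = h              # perturb x_pq only
        d_aij_dX[p, q] = (a_ij(X + step, i, j) - a_ij(X - step, i, j)) / (2 * h)

print(np.allclose(d_aij_dX[i], W[j], atol=1e-6))           # row i equals W_j
print(np.allclose(np.delete(d_aij_dX, i, axis=0), 0.0))    # other rows are zero
```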

### 3.3 Backpropagation with Respect to X

$a_{ij}= \sum_{t=1}^{k} x_{it}\cdot w_{jt} +b_j\\ \;\\ a_{ij}= x_{i1}w_{j1} +x_{i2}w_{j2} +\cdots+x_{iq}w_{jq} +\cdots+x_{ik}w_{jk} +b_j\\ \;\\ x_{ij|pq}'=\frac{\partial a_{ij}}{\partial x_{pq}} = \left\{ \begin{array}{rr} w_{jq}& p = i\\ 0, & p \neq i \end{array} \right.\\$

$\frac {\partial e}{\partial x_{pq}} =\sum_{i = 1}^{i=m}\sum_{j =1}^{j =n} \frac {\partial e}{\partial a_{ij}}\frac {\partial a_{ij}}{\partial x_{pq}} =\sum_{i = 1}^{i=m}\sum_{j =1}^{j =n} a_{ij}' x_{ij|pq}'\\$

$\frac {\partial e}{\partial x_{pq}}=\sum_{j =1}^{j =n} a_{pj}'w_{jq}\\ \;\\ \frac {d e}{d X}=\begin{pmatrix} \sum_{j =1}^{j =n} a_{1j}'w_{j1}&\sum_{j =1}^{j =n} a_{1j}'w_{j2}&\sum_{j =1}^{j =n} a_{1j}'w_{j3}&\cdots&\sum_{j =1}^{j =n} a_{1j}'w_{jk}\\\;\\ \sum_{j =1}^{j =n} a_{2j}'w_{j1}&\sum_{j =1}^{j =n} a_{2j}'w_{j2}&\sum_{j =1}^{j =n} a_{2j}'w_{j3}&\cdots&\sum_{j =1}^{j =n} a_{2j}'w_{jk}\\\;\\ \sum_{j =1}^{j =n} a_{3j}'w_{j1}&\sum_{j =1}^{j =n} a_{3j}'w_{j2}&\sum_{j =1}^{j =n} a_{3j}'w_{j3}&\cdots&\sum_{j =1}^{j =n} a_{3j}'w_{jk}\\ \vdots&\vdots&\vdots&\ddots&\vdots\\ \sum_{j =1}^{j =n} a_{mj}'w_{j1}&\sum_{j =1}^{j =n} a_{mj}'w_{j2}&\sum_{j =1}^{j =n} a_{mj}'w_{j3}&\cdots&\sum_{j =1}^{j =n} a_{mj}'w_{jk}\\ \end{pmatrix}$

$\frac {d e}{d X}=\begin{pmatrix} a_{11}'& a_{12}'& a_{13}'&\cdots& a_{1n}'\\ a_{21}'& a_{22}'& a_{23}'&\cdots& a_{2n}'\\ a_{31}'& a_{32}'& a_{33}'&\cdots& a_{3n}'\\ \vdots&\vdots&\vdots&\ddots&\vdots\\ a_{m1}'& a_{m2}'& a_{m3}'&\cdots& a_{mn}' \end{pmatrix} \begin{pmatrix} w_{11}&w_{12} &w_{13}&\cdots&w_{1k}\\ w_{21}&w_{22}&w_{23}&\cdots&w_{2k}\\ w_{31}&w_{32}&w_{33}&\cdots&w_{3k}\\ \vdots&\vdots&\vdots&\ddots&\vdots\\ w_{n1}&w_{n2}&w_{n3}&\cdots&w_{nk} \end{pmatrix}$

$\frac {d e}{d X} =\nabla e_{(A)}W$
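
The result $\frac{de}{dX} = \nabla e_{(A)}W$ can be checked against a finite-difference estimate. This is a sketch only: it assumes the stand-in loss $e = \frac{1}{2}\sum a_{ij}^2$ (so $\nabla e_{(A)} = A$), and the sizes and random inputs are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
m, k, n = 4, 5, 3
X = rng.normal(size=(m, k))
W = rng.normal(size=(n, k))         # definition 1: W is (n, k)
b = rng.normal(size=n)

def loss_of_X(X_):
    A = X_ @ W.T + b
    return 0.5 * np.sum(A ** 2)     # assumed loss, not from the article

A = X @ W.T + b
dA = A                              # de/dA for the assumed loss
dX = dA @ W                         # (m, n) @ (n, k) -> (m, k)

# Central-difference estimate of de/dX, one entry at a time.
h = 1e-6
dX_num = np.zeros_like(X)
for p in range(m):
    for q in range(k):
        step = np.zeros_like(X)
        step[p, q] = h
        dX_num[p, q] = (loss_of_X(X + step) - loss_of_X(X - step)) / (2 * h)

print(np.allclose(dX, dX_num, atol=1e-5))
```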

### 3.4 Backpropagation with Respect to W

$a_{ij}= \sum_{t=1}^{k} x_{it}\cdot w_{jt} +b_j\\ \;\\ a_{ij}= x_{i1}w_{j1} +x_{i2}w_{j2} +\cdots+x_{iq}w_{jq} +\cdots+x_{ik}w_{jk} +b_j\\ \;\\ w_{ij|pq}'=\frac{\partial a_{ij}}{\partial w_{pq}} = \left\{ \begin{array}{rr} x_{iq} & p = j \\ 0 & p \neq j \end{array} \right.\\\;\\ \frac {\partial e}{\partial w_{pq}} = \sum_{i = 1}^{i=m}\sum_{j =1}^{j =n} \frac {\partial e}{\partial a_{ij}}\frac {\partial a_{ij}}{\partial w_{pq}} =\sum_{i = 1}^{i=m}\sum_{j =1}^{j =n} a_{ij}' w_{ij|pq}'\\ \;\\ \frac {\partial e}{\partial w_{pq}}=\sum_{i =1}^{i =m} a_{ip}'x_{iq}\\ \;\\ \frac {d e}{d W}= \begin{pmatrix} \sum_{i =1}^{i =m} a_{i1}'x_{i1}&\sum_{i =1}^{i =m} a_{i1}'x_{i2}&\sum_{i =1}^{i =m} a_{i1}'x_{i3}&\cdots&\sum_{i =1}^{i =m} a_{i1}'x_{ik}\\ \;\\ \sum_{i =1}^{i =m} a_{i2}'x_{i1}&\sum_{i =1}^{i =m} a_{i2}'x_{i2}&\sum_{i =1}^{i =m} a_{i2}'x_{i3}&\cdots&\sum_{i =1}^{i =m} a_{i2}'x_{ik}\\ \;\\ \sum_{i =1}^{i =m} a_{i3}'x_{i1}&\sum_{i =1}^{i =m} a_{i3}'x_{i2}&\sum_{i =1}^{i =m} a_{i3}'x_{i3}&\cdots&\sum_{i =1}^{i =m} a_{i3}'x_{ik}\\ \vdots&\vdots&\vdots&\ddots&\vdots\\ \sum_{i =1}^{i =m} a_{in}'x_{i1}&\sum_{i =1}^{i =m} a_{in}'x_{i2}&\sum_{i =1}^{i =m} a_{in}'x_{i3}&\cdots&\sum_{i =1}^{i =m} a_{in}'x_{ik}\\ \end{pmatrix}\\$

$\frac {d e}{d W}= \begin{pmatrix} a_{11}'& a_{21}'& a_{31}'&\cdots& a_{m1}'\\ a_{12}'& a_{22}'& a_{32}'&\cdots& a_{m2}'\\ a_{13}'& a_{23}'& a_{33}'&\cdots& a_{m3}'\\ \vdots&\vdots&\vdots&\ddots&\vdots\\ a_{1n}'& a_{2n}'& a_{3n}'&\cdots& a_{mn}'\\ \end{pmatrix} \begin{pmatrix} x_{11}&x_{12} &x_{13}&\cdots&x_{1k}\\ x_{21}&x_{22}&x_{23}&\cdots&x_{2k}\\ x_{31}&x_{32}&x_{33}&\cdots&x_{3k}\\ \vdots&\vdots&\vdots&\ddots&\vdots\\ x_{m1}&x_{m2}&x_{m3}&\cdots&x_{mk} \end{pmatrix}$

$\frac {d e}{d W} =\nabla e_{(A)}^TX$
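
The sum $\sum_{i} a_{ip}'x_{iq}$ can be read as a sum over the $m$ samples of the outer product of the $i$-th gradient row with the $i$-th input row. The sketch below compares that reading with the matrix product $\nabla e_{(A)}^TX$. The upstream gradient `dA` is an arbitrary random stand-in and the sizes are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
m, k, n = 4, 5, 3
X = rng.normal(size=(m, k))
dA = rng.normal(size=(m, n))        # arbitrary stand-in for the upstream gradient de/dA

# Matrix form: (n, m) @ (m, k) -> (n, k), the same shape as W in definition 1.
dW = dA.T @ X

# Per-sample form: sum of outer products of gradient row i with input row i.
dW_outer = sum(np.outer(dA[i], X[i]) for i in range(m))

# Index form: de/dw_pq = sum_i a'_ip * x_iq
dW_einsum = np.einsum("ip,iq->pq", dA, X)

print(dW.shape)                     # (3, 5)
print(np.allclose(dW, dW_outer), np.allclose(dW, dW_einsum))
```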

### 3.5 Gradient of e with Respect to b

$a_{ij}= \sum_{t=1}^{k} x_{it}\cdot w_{jt} +b_j\\ \;\\ b_{ij|q}'=\frac{\partial a_{ij}}{\partial b_{q}} = \left\{ \begin{array}{rr} 1,& q = j\\ 0, & q \neq j \end{array} \right.\\ \;\\ \frac {\partial e}{\partial b_{q}} = \sum_{i = 1}^{i=m}\sum_{j =1}^{j =n} \frac {\partial e}{\partial a_{ij}}\frac {\partial a_{ij}}{\partial b_{q}} =\sum_{i = 1}^{i=m}\sum_{j =1}^{j =n} a_{ij}' b_{ij|q}'\\ \;\\ \frac {\partial e}{\partial b_{q}} = \sum_{i = 1}^{i=m} a_{iq}'\cdot 1 \\ \;\\ \frac {d e}{d b} = (\sum_{i = 1}^{i=m} a_{i1}',\sum_{i = 1}^{i=m} a_{i2}',\sum_{i = 1}^{i=m} a_{i3}', \cdots ,\sum_{i = 1}^{i=m} a_{in}')\\$

$\frac {de}{db}=sum(\nabla e_{(A)},\; axis=0)$
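
Because $b$ is broadcast to every row of $A$, its gradient collects the upstream gradient over all $m$ rows. A small sketch with a hand-picked $\nabla e_{(A)}$ (the values are illustrative only):

```python
import numpy as np

dA = np.array([[1.0, 2.0, 3.0],
               [4.0, 5.0, 6.0]])    # m = 2 rows, n = 3 columns

db = np.sum(dA, axis=0)             # column sums, one entry per b_q
db_alt = np.ones(dA.shape[0]) @ dA  # the same sum written as a row of ones times dA

print(db)                           # [5. 7. 9.]
```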

## 4. Another Definition of Affine

$A_{m\times n} = affine(X,W,b) = X_{m\times k}W_{k\times n} + b_{1 \times n} \;\\ a_{ij}= \sum_{t=1}^{k} x_{it}\cdot w_{tj} +b_j$
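
The two definitions differ only in how $W$ is stored: the $k \times n$ matrix here is the transpose of the $n \times k$ matrix in section 1. The sketch below checks that both give the same output; the names `affine_v1` and `affine_v2` are hypothetical. For orientation, PyTorch's `torch.nn.Linear` documents its weight as shape `(out_features, in_features)` and computes $y = xA^T + b$, which corresponds to the first layout.

```python
import numpy as np

def affine_v1(X, W, b):
    """Definition 1: W is (n, k), A = X @ W.T + b."""
    return X @ W.T + b

def affine_v2(X, W, b):
    """Definition 2: W is (k, n), A = X @ W + b."""
    return X @ W + b

rng = np.random.default_rng(0)
m, k, n = 4, 5, 3                   # illustrative sizes (assumption)
X = rng.normal(size=(m, k))
W1 = rng.normal(size=(n, k))        # definition 1 layout
W2 = W1.T                           # definition 2 layout
b = rng.normal(size=n)

A1 = affine_v1(X, W1, b)
A2 = affine_v2(X, W2, b)
print(np.allclose(A1, A2))          # True
```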

### 4.1 Backpropagation with Respect to X

$a_{ij}= x_{i1}w_{1j} +x_{i2}w_{2j} +\cdots+x_{iq}w_{qj} +\cdots+x_{ik}w_{kj} +b_j\\ \;\\ x_{ij|pq}'=\frac{\partial a_{ij}}{\partial x_{pq}} = \left\{ \begin{array}{rr} w_{qj}& p = i\\ 0, & p \neq i \end{array} \right.\\$

$\frac {\partial e}{\partial x_{pq}} =\sum_{i = 1}^{i=m}\sum_{j =1}^{j =n} \frac {\partial e}{\partial a_{ij}}\frac {\partial a_{ij}}{\partial x_{pq}} =\sum_{i = 1}^{i=m}\sum_{j =1}^{j =n} a_{ij}' x_{ij|pq}'\\$

$\frac {\partial e}{\partial x_{pq}}=\sum_{j =1}^{j =n} a_{pj}'w_{qj}\\ \;\\ \frac {d e}{d X}=\begin{pmatrix} \sum_{j =1}^{j =n} a_{1j}'w_{1j}&\sum_{j =1}^{j =n} a_{1j}'w_{2j}&\sum_{j =1}^{j =n} a_{1j}'w_{3j}&\cdots&\sum_{j =1}^{j =n} a_{1j}'w_{kj}\\\;\\ \sum_{j =1}^{j =n} a_{2j}'w_{1j}&\sum_{j =1}^{j =n} a_{2j}'w_{2j}&\sum_{j =1}^{j =n} a_{2j}'w_{3j}&\cdots&\sum_{j =1}^{j =n} a_{2j}'w_{kj}\\\;\\ \sum_{j =1}^{j =n} a_{3j}'w_{1j}&\sum_{j =1}^{j =n} a_{3j}'w_{2j}&\sum_{j =1}^{j =n} a_{3j}'w_{3j}&\cdots&\sum_{j =1}^{j =n} a_{3j}'w_{kj}\\ \vdots&\vdots&\vdots&\ddots&\vdots\\ \sum_{j =1}^{j =n} a_{mj}'w_{1j}&\sum_{j =1}^{j =n} a_{mj}'w_{2j}&\sum_{j =1}^{j =n} a_{mj}'w_{3j}&\cdots&\sum_{j =1}^{j =n} a_{mj}'w_{kj}\\ \end{pmatrix}$

$\frac {d e}{d X}=\begin{pmatrix} a_{11}'& a_{12}'& a_{13}'&\cdots& a_{1n}'\\ a_{21}'& a_{22}'& a_{23}'&\cdots& a_{2n}'\\ a_{31}'& a_{32}'& a_{33}'&\cdots& a_{3n}'\\ \vdots&\vdots&\vdots&\ddots&\vdots\\ a_{m1}'& a_{m2}'& a_{m3}'&\cdots& a_{mn}' \end{pmatrix} \begin{pmatrix} w_{11}&w_{21} &w_{31}&\cdots&w_{k1}\\ w_{12}&w_{22}&w_{32}&\cdots&w_{k2}\\ w_{13}&w_{23}&w_{33}&\cdots&w_{k3}\\ \vdots&\vdots&\vdots&\ddots&\vdots\\ w_{1n}&w_{2n}&w_{3n}&\cdots&w_{kn} \end{pmatrix}$

$\frac {d e}{d X} =\nabla e_{(A)}W^T$
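
A finite-difference sketch for the second layout follows. To keep the upstream gradient arbitrary, it assumes a linear probe loss $e = \sum_{i,j} g_{ij}a_{ij}$ with a fixed random matrix `G`, so that $\nabla e_{(A)} = G$ exactly; this loss is an illustrative assumption, not the article's.

```python
import numpy as np

rng = np.random.default_rng(0)
m, k, n = 3, 4, 2
X = rng.normal(size=(m, k))
W = rng.normal(size=(k, n))         # definition 2: W is (k, n)
b = rng.normal(size=n)
G = rng.normal(size=(m, n))         # arbitrary upstream gradient de/dA

def loss_of_X(X_):
    # Assumed probe loss whose gradient with respect to A is exactly G.
    return np.sum(G * (X_ @ W + b))

dX = G @ W.T                        # (m, n) @ (n, k) -> (m, k)

h = 1e-6
dX_num = np.zeros_like(X)
for p, q in np.ndindex(X.shape):
    step = np.zeros_like(X)
    step[p, q] = h
    dX_num[p, q] = (loss_of_X(X + step) - loss_of_X(X - step)) / (2 * h)

print(np.allclose(dX, dX_num, atol=1e-6))
```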

### 4.2 Backpropagation with Respect to W

$a_{ij}= x_{i1}w_{1j} +x_{i2}w_{2j} +\cdots+x_{ip}w_{pj} +\cdots+x_{ik}w_{kj} +b_j\\ \;\\ w_{ij|pq}'=\frac{\partial a_{ij}}{\partial w_{pq}} = \left\{ \begin{array}{rr} x_{ip} & q = j \\ 0 & q \neq j \end{array} \right.\\\;\\ \frac {\partial e}{\partial w_{pq}} = \sum_{i = 1}^{i=m}\sum_{j =1}^{j =n} \frac {\partial e}{\partial a_{ij}}\frac {\partial a_{ij}}{\partial w_{pq}} =\sum_{i = 1}^{i=m}\sum_{j =1}^{j =n} a_{ij}' w_{ij|pq}'\\$
$\frac {\partial e}{\partial w_{pq}}=\sum_{i =1}^{i =m} a_{iq}'x_{ip}\\ \;\\ \frac {d e}{d W}= \begin{pmatrix} \sum_{i =1}^{i =m} a_{i1}'x_{i1}&\sum_{i =1}^{i =m} a_{i2}'x_{i1}&\sum_{i =1}^{i =m} a_{i3}'x_{i1}&\cdots&\sum_{i =1}^{i =m} a_{in}'x_{i1}\\ \;\\ \sum_{i =1}^{i =m} a_{i1}'x_{i2}&\sum_{i =1}^{i =m} a_{i2}'x_{i2}&\sum_{i =1}^{i =m} a_{i3}'x_{i2}&\cdots&\sum_{i =1}^{i =m} a_{in}'x_{i2}\\ \;\\ \sum_{i =1}^{i =m} a_{i1}'x_{i3}&\sum_{i =1}^{i =m} a_{i2}'x_{i3}&\sum_{i =1}^{i =m} a_{i3}'x_{i3}&\cdots&\sum_{i =1}^{i =m} a_{in}'x_{i3}\\ \vdots&\vdots&\vdots&\ddots&\vdots\\ \sum_{i =1}^{i =m} a_{i1}'x_{ik}&\sum_{i =1}^{i =m} a_{i2}'x_{ik}&\sum_{i =1}^{i =m} a_{i3}'x_{ik}&\cdots&\sum_{i =1}^{i =m} a_{in}'x_{ik}\\ \end{pmatrix}\\$

$\frac {d e}{d W}= \begin{pmatrix} x_{11}&x_{21} &x_{31}&\cdots&x_{m1}\\ x_{12}&x_{22}&x_{32}&\cdots&x_{m2}\\ x_{13}&x_{23}&x_{33}&\cdots&x_{m3}\\ \vdots&\vdots&\vdots&\ddots&\vdots\\ x_{1k}&x_{2k}&x_{3k}&\cdots&x_{mk} \end{pmatrix} \begin{pmatrix} a_{11}'& a_{12}'& a_{13}'&\cdots& a_{1n}'\\ a_{21}'& a_{22}'& a_{23}'&\cdots& a_{2n}'\\ a_{31}'& a_{32}'& a_{33}'&\cdots& a_{3n}'\\ \vdots&\vdots&\vdots&\ddots&\vdots\\ a_{m1}'& a_{m2}'& a_{m3}'&\cdots& a_{mn}' \end{pmatrix}$

$\frac {d e}{d W} = X^T\nabla e_{(A)}$
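
Since the second layout stores $W$ transposed, its weight gradient is the transpose of the one obtained under the first definition. The sketch below checks that relation and one individual entry $\sum_i a_{iq}'x_{ip}$; `G` is an arbitrary random stand-in for $\nabla e_{(A)}$ and the sizes are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
m, k, n = 3, 4, 2
X = rng.normal(size=(m, k))
G = rng.normal(size=(m, n))         # arbitrary stand-in for de/dA

dW_v2 = X.T @ G                     # definition 2: (k, m) @ (m, n) -> (k, n)
dW_v1 = G.T @ X                     # definition 1: (n, m) @ (m, k) -> (n, k)

# Single entry de/dw_pq = sum_i a'_iq * x_ip (0-based p and q here).
p, q = 2, 1
entry = sum(G[i, q] * X[i, p] for i in range(m))

print(dW_v2.shape, dW_v1.shape)     # (4, 2) (2, 4)
print(np.allclose(dW_v2, dW_v1.T), np.isclose(dW_v2[p, q], entry))
```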

### 4.3 Gradient of e with Respect to b

$\frac {de}{db}=sum(\nabla e_{(A)},\; axis=0)$
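
The three results of this section fit together in one small layer. The class below is a minimal sketch under the second definition ($W$ of shape $k \times n$), not a reference implementation. The class name `Affine`, its attribute names, and the stand-in loss $e = \frac{1}{2}\sum a_{ij}^2$ used for the check are all assumptions.

```python
import numpy as np

class Affine:
    """Minimal layer for definition 2: A = X @ W + b, with W of shape (k, n)."""

    def __init__(self, W, b):
        self.W, self.b = W, b
        self.X = None               # cached input, needed for dW
        self.dW = None
        self.db = None

    def forward(self, X):
        self.X = X
        return X @ self.W + self.b

    def backward(self, dA):
        self.dW = self.X.T @ dA           # de/dW = X^T * grad_A
        self.db = np.sum(dA, axis=0)      # de/db = column sums of grad_A
        return dA @ self.W.T              # de/dX = grad_A * W^T

rng = np.random.default_rng(0)
m, k, n = 4, 5, 3
X = rng.normal(size=(m, k))
W = rng.normal(size=(k, n))
b = rng.normal(size=n)

layer = Affine(W, b)
A = layer.forward(X)
dA = A                              # de/dA for the assumed loss e = 0.5 * sum(A^2)
dX = layer.backward(dA)

# Central-difference check of de/db under the same assumed loss.
def loss_of_b(b_):
    return 0.5 * np.sum((X @ W + b_) ** 2)

h = 1e-6
db_num = np.array([
    (loss_of_b(b + h * np.eye(n)[q]) - loss_of_b(b - h * np.eye(n)[q])) / (2 * h)
    for q in range(n)
])
print(np.allclose(layer.db, db_num, atol=1e-5))
```

Caching `X` in `forward` is a design choice of this sketch: the weight gradient needs the layer input, so it has to be kept until `backward` is called.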