代码仓库: https://github.com/brandonlyg/cute-dl
目标
- 增加交叉熵损失函数,使框架能够支持分类任务的模型。
- 构建一个MLP模型, 在mnist数据集上执行分类任务准确率达到91%。
实现交叉熵损失函数
数学原理
分解交叉熵损失函数
交叉熵损失函数把模型的输出值当成一个离散随机变量的分布列。 设模型的输出为: Y ^ = f ( X ) \hat{Y} = f(X) Y^=f(X), 其中 f ( X ) f(X) f(X)表示模型。 Y ^ \hat{Y} Y^是一个m X n矩阵, 如下所示:
[ y ^ 11 y ^ 12 . . . y ^ 1 n y ^ 21 y ^ 22 . . . y ^ 2 n . . . . . . . . . . . . y ^ m 1 y ^ m 2 . . . y ^ m n ] \begin{bmatrix} \hat{y}_{11} & \hat{y}_{12} & ... & \hat{y}_{1n} \\ \hat{y}_{21} & \hat{y}_{22} & ... & \hat{y}_{2n} \\ ... & ... & ... & ... \\ \hat{y}_{m1} & \hat{y}_{m2} & ... & \hat{y}_{mn} \end{bmatrix} ⎣⎢⎢⎡y^11y^21...y^m1y^12y^22...y^m2............y^1ny^2n...y^mn⎦⎥⎥⎤
把这个矩阵的第i行记为 y ^ i \hat{y}_i y^i, 它是一个 R 1 X n \\R^{1Xn} R1Xn向量, 它的第j个元素记为 y ^ i j \hat{y}_{ij} y^ij。
交叉熵损失函数要求 y ^ i \hat{y}_i y^i具有如下性质:
0 < = y ^ i j < = 1 ( 1 ) ∑ j = 1 n y ^ i j = 1 , n = 2 , 3 , . . . ( 2 ) \begin{matrix} 0<=\hat{y}_{ij}<=1 & & (1)\\ \sum_{j=1}^{n} \hat{y}_{ij} = 1, & n=2,3,... & (2) \end{matrix} 0<=y^ij<=1∑j=1ny^ij=1,n=2,3,...(1)(2)
特别地,当n=1时, 只需要满足第一条性质即可。我们先考虑n > 1的情况, 这种情况下n=2等价于n=1,在工程上n=1可以看成是对n=2的优化。
模型有时候并不会保证输出值有这些性质, 这时损失函数要把 y ^ i \hat{y}_i y^i转换成一个分布列: p ^ i \hat{p}_i p^i, 转换函数的定义如下:
S i = ∑ j = 1 n e y ^ i j p ^ i j = e y ^ i j S i \begin{matrix} S_i = \sum_{j=1}^{n} e^{\hat{y}_{ij}}\\ \hat{p}_{ij} = \frac{e^{\hat{y}_{ij}}}{S_i} \end{matrix} Si=∑j=1ney^ijp^ij=Siey^ij
这里的 p ^ i \hat{p}_i p^i是可以满足要求的。函数 e y ^ i j e^{\hat{y}_{ij}} ey^ij是单调增函数,对于任意两个不同的 y ^ i a < y ^ i b \hat{y}_{ia} < \hat{y}_{ib} y^ia<y^ib, 都有: e