文章目录
1. 目的
弄清 LeNet-5 的论文原文 LeCun-98.pdf1 中的输出层的结构,以及损失函数的计算公式。
的确网络上已经有很多 LeNet-5 的现代实现了, 但过于现代导致和原版相差太多。本篇 blog 尽可能贴合论文原文, 不希望人云亦云照搬已有结果, 根据论文进行复现追求的是“一致性”。假设没有互联网上已经存在的各种 LeNet-5 的实现, 也应该能按论文原文做出实现。
2. 输出层结构
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-93L5Wh2q-1687600825416)(LeNet-5_gaussian_connection_rbf.png)]
2.1 Gaussian Connection
博客2 给出了 Gaussian Connection 的白话解释:高斯连接的输出就是输出结果,不需要套用激活函数; 高斯连接需要和欧几里得径向基函数(Euclidean Radial Basic Functions)配合使用。
Finally, the output layer is composed of Euclidean Radial Basis Function units (RBF), one for each class, with 84 inputs each. The output each RBF unit y i y_i yi is computed as follows:
y i = ∑ j ( x j − w i j ) 2 y_i = \sum_j(x_j - w_{ij})^2 yi=j∑(xj−wij)2
2.2 Gaussian Connection 的 weight 可视化
按论文原文, Gaussian connection 的 weight 只有两种取值: 1 和 -1。不妨做一个映射,把1映射到255,把-1映射到0,就得到一幅图像。没错, Gaussian connection 的 weight 只有10个,每个代表一个 ascii 打印形式的数字(0-9),大小是 7x12 = 84 像素。
首先是制作每个数字的 ascii 表示, 然后保存为 .pgm 格式的图像.
可视化:
存储 pgm 图像:
void write_pgm_image(uchar* image, int width, int height, const char* filename)
{
FILE* fout = fopen(filename, "wb");
fprintf(fout, "P5\n%d %d\n255\n", width, height);
fwrite(image, width * height, 1, fout);
fclose(fout);
}
完整代码:
uchar digits[10][84];
void make_7x12_digits()
{
uchar d[10][84] =
{
{
0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0,
0, 0, 1, 1, 1, 0, 0,
0, 1, 1, 0, 1, 1, 0,
1, 1, 0, 0, 0, 1, 1,
1, 1, 0, 0, 0, 1, 1,
1, 1, 0, 0, 0, 1, 1,
1, 1, 0, 0, 0, 1, 1,
0, 1, 1, 0, 1, 1, 0,
0, 0, 1, 1, 1, 0, 0,
0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0,
},
{
0, 0, 0, 1, 1, 0, 0,
0, 0, 1, 1, 1, 0, 0,
0, 1, 1, 1, 1, 0, 0,
0, 0, 0, 1, 1, 0, 0,
0, 0, 0, 1, 1, 0, 0,
0, 0, 0, 1, 1, 0, 0,
0, 0, 0, 1, 1, 0, 0,
0, 0, 0, 1, 1, 0, 0,
0, 0, 0, 1, 1, 0, 0,
0, 1, 1, 1, 1, 1, 1,
0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0,
},
{
0, 0, 0, 0, 0, 0, 0,
0, 1, 1, 1, 1, 0, 0,
0, 1, 1, 0, 1, 1, 1,
1, 0, 0, 0, 1, 1, 1,
0, 0, 0, 1, 1, 1, 0,
0, 0, 1, 1, 1, 0, 0,
0, 1, 1, 0, 0, 0, 0,
0, 1, 1, 0, 0, 0, 0,
1, 1, 0, 0, 0, 0, 0,
1, 1, 1, 1, 1, 1, 1,
0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0,
},
{
0, 1, 1, 1, 1, 1, 1,
0, 0, 0, 0, 0, 1, 1,
0, 0, 0, 0, 1, 1, 0,
0, 0, 0, 1, 1, 0, 0,
0, 0, 0, 1, 0, 0, 0,
0, 0, 1, 1, 1, 1, 0,
0, 0, 0, 0, 0, 1, 1,
0, 0, 0, 0, 0, 1, 1,
0, 0, 0, 0, 0, 1, 1,
1, 1, 0, 0, 0, 1, 1,
0, 1, 1, 1, 1, 1, 0,
0, 0, 0, 0, 0, 0, 0,
},
{
0, 0, 0, 0, 0, 0, 0,
0, 1, 1, 0, 1, 1, 0,
0, 1, 1, 0, 1, 1, 0,
0, 1, 1, 0, 1, 1, 0,
0, 1, 1, 0, 1, 1, 0,
1, 1, 0, 0, 1, 1, 0,
1, 1, 0, 0, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1,
0, 0, 0, 0, 1, 1, 0,
0, 0, 0, 0, 1, 1, 0,
0, 0, 0, 0, 1, 1, 0,
0, 0, 0, 0, 1, 1, 0,
},
{
0, 0, 0, 0, 0, 0, 0,
1, 1, 1, 1, 1, 1, 1,
1, 1, 0, 0, 0, 0, 0,
1, 1, 0, 0, 0, 0, 0,
1, 1, 0, 0, 0, 0, 0,
1, 1, 1, 1, 1, 0, 0,
0, 0, 1, 1, 1, 1, 1,
0, 0, 0, 0, 0, 1, 1,
0, 0, 0, 0, 0, 1, 1,
1, 1, 1, 0, 0, 1, 1,
0, 1, 1, 1, 1, 1, 0,
0, 0, 0, 0, 0, 0, 0,
},
{
0, 0, 0, 1, 1, 1, 0,
0, 1, 1, 1, 0, 0, 0,
0, 1, 0, 0, 0, 0, 0,
1, 1, 0, 0, 0, 0, 0,
1, 1, 0, 0, 0, 0, 0,
1, 1, 1, 1, 1, 1, 0,
1, 1, 0, 0, 0, 1, 1,
1, 1, 0, 0, 0, 0, 1,
1, 1, 0, 0, 0, 0, 1,
0, 1, 0, 0, 0, 1, 0,
0, 1, 1, 1, 1, 1, 0,
0, 0, 0, 0, 0, 0, 0,
},
{
1, 1, 1, 1, 1, 1, 1,
0, 0, 0, 0, 0, 1, 1,
0, 0, 0, 0, 0, 1, 1,
0, 0, 0, 0, 1, 1, 0,
0, 0, 0, 1, 1, 0, 0,
0, 0, 0, 1, 1, 0, 0,
0, 0, 1, 1, 0, 0, 0,
0, 1, 1, 0, 0, 0, 0,
0, 1, 1, 0, 0, 0, 0,
0, 1, 1, 0, 0, 0, 0,
0, 1, 1, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0,
},
{
0, 1, 1, 1, 1, 1, 0,
1, 1, 0, 0, 0, 1, 1,
1, 1, 0, 0, 0, 1, 1,
1, 1, 0, 0, 0, 1, 1,
0, 1, 1, 1, 1, 1, 0,
1, 1, 0, 0, 0, 1, 1,
1, 1, 0, 0, 0, 0, 1,
1, 1, 0, 0, 0, 0, 1,
1, 1, 0, 0, 0, 0, 1,
0, 1, 1, 1, 1, 1, 0,
0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0,
},
{
0, 1, 1, 1, 1, 1, 0,
1, 1, 0, 0, 0, 1, 1,
1, 1, 0, 0, 0, 1, 1,
1, 1, 0, 0, 0, 1, 1,
1, 1, 0, 0, 1, 1, 1,
0, 1, 1, 1, 1, 1, 1,
0, 0, 0, 0, 0, 1, 1,
0, 0, 0, 0, 0, 1, 0,
0, 0, 0, 0, 1, 1, 0,
0, 1, 1, 1, 1, 0, 0,
0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0,
}
};
memcpy(digits, d, sizeof(digits));
for (int k = 0; k < 10; k++)
{
char savename[50] = { 0 };
sprintf(savename, "RBF%d.pgm", k);
int width = 7;
int height = 12;
for (int i = 0; i < height; i++)
{
for (int j = 0; j < width; j++)
{
int idx = i * width + j;
d[k][idx] = (d[k][idx] == 1) ? 255 : 0;
}
}
write_pgm_image(d[k], width, height, savename);
}
}
3. Loss Function
关于 Loss function, 先前卡在了不理解 Loss 函数中用的符号。后来在谷歌中搜了一下这篇论文的相关解读博客3, 4,发现其实 Introduction 部分就给出了符号的说明。
3.1 当前类别判断错误时,loss function 中的项(基本项)
Z p Z^p Zp: 第 p p p 个输入
Y p Y^p Yp: Z p Z^p Zp 的被识别的类别(推理结果)
W W W: 系统中可调整的参数的集合。
D p D^p Dp: Z p Z^p Zp 的正确类别(ground truth)
推理得到结果:
Y
p
=
F
(
Z
p
,
W
)
Y^p = F(Z^p, W)
Yp=F(Zp,W)
Loss Function:
E
p
=
D
(
D
p
,
F
(
W
,
Z
p
)
E^p = D(D^p, F(W, Z^p)
Ep=D(Dp,F(W,Zp)
E ( W ) = 1 P ∑ p = 1 P y D p ( Z p , W ) E(W) = \frac{1}{P}\sum_{p=1}^{P}y_{D^p}(Z^p, W) E(W)=P1p=1∑PyDp(Zp,W)
其中 y D p y_{D^p} yDpj 表示输入样本 Z p Z^p Zp 在对应的真实类别 D p D^p Dp 上的得分. 比如,对于 3.pgm 来说, D p = 3 D^p=3 Dp=3, output 层算出来的 RBF 有10个结果: y 0 , . . . , y 9 y_0, ..., y_9 y0,...,y9, 而 y 3 y_3 y3 就是 y D p y_{D^p} yDp。
实际训练时, 样本数量 P > 1 P > 1 P>1, 我们收集每个样本的损失 y D p y_{D^p} yDp, 累加起来, 就得到 E ( W ) E(W) E(W) 了。
3.2 判断为其他类别时, loss function 中的项(惩罚项)
尽管上述 cost function 适合大多数情况, 但它缺乏3个重要的性质。
第一, 如果我们允许 RBF 层的参数改变(本来是设定为1和-1),那么 E ( W ) E(W) E(W) 会得到一个平凡的、同时完全不可接受的解。(啥意思??)在这个解中,所有 RBF 参数向量都相等, 并且F6的状态是常量,并且和参数向量相等。这种情况下, 网络开心地忽略输入, 并且所有 RBF 输入等于0。 如果 RBF 的权值不允许改变, 这种“坍缩现象”就不会出现。
第二个问题是,类别之间没有竞争。
为了解决这两个问题, 作者提出如下修正后的损失函数:
E
(
W
)
=
1
P
∑
p
=
1
P
(
y
D
p
(
Z
p
,
W
)
+
log
(
e
−
j
+
∑
i
e
−
y
i
(
Z
p
,
W
)
)
)
E(W) = \frac{1}{P}\sum_{p=1}^{P}(y_{D^p}(Z^p, W) + \log(e^{-j} + \sum_i e^{-y_i(Z^p, W)}))
E(W)=P1p=1∑P(yDp(Zp,W)+log(e−j+i∑e−yi(Zp,W)))
其中 j j j 是一个大于0的常数(但作者没有给出具体取值)。
其中 i i i 的取值,原文没说,可以理解为 D p D^p Dp 之外的类别,因为作为惩罚项它的目的是当其他类别的 RBF 输出值较小(接近0)时,惩罚项变大:
E ( W ) = 1 P ∑ p = 1 P ( y D p ( Z p , W ) + log ( e − j + ∑ i ! = D p e − y i ( Z p , W ) ) ) E(W) = \frac{1}{P}\sum_{p=1}^{P}(y_{D^p}(Z^p, W) + \log(e^{-j} + \sum_{i!=D^p} e^{-y_i(Z^p, W)})) E(W)=P1p=1∑P(yDp(Zp,W)+log(e−j+i!=Dp∑e−yi(Zp,W)))
例如对于 3.pgm 输入图, D p = 3 D^p=3 Dp=3, i ∈ { 0 , 1 , 2 , 4 , 5 , 6 , 7 , 8 , 9 } i \in \{ 0, 1, 2, 4, 5, 6, 7, 8, 9 \} i∈{0,1,2,4,5,6,7,8,9}.
4. 代码实现
4.1 输出层的前向计算
void forward_output()
{
double* x = F6_output;
double* y = RBF_output;
double w[10][84];
for (int i = 0; i < 10; i++)
{
for (int j = 0; j < 84; j++)
{
w[i][j] = (digits[i][j] == 255) ? 1 : -1;
}
}
// Euclidean Radial Basis Functions (RBF)
printf("Output layer: \n");
for (int i = 0; i < 10; i++)
{
double sum = 0;
for (int j = 0; j < 84; j++)
{
double diff = (x[j] - w[i][j]);
sum += diff * diff;
}
y[i] = sum;
printf(" y[%d] = %lf\n", i, y[i]);
}
}
4.2 Loss Function 的实现
简单起见, 先只考虑单个样本对应的损失函数。 也就是, P = 1 P=1 P=1 时,
E ( W ) = y D p ( Z p , W ) + log ( e − j + ∑ i ! = D p e − y i ( Z p , W ) ) E(W) = y_{D^p}(Z^p, W) + \log(e^{-j} + \sum_{i!=D^p} e^{-y_i(Z^p, W)}) E(W)=yDp(Zp,W)+log(e−j+i!=Dp∑e−yi(Zp,W))
C 语言实现如下:
// compute loss for single input pattern D^p
void compute_loss()
{
int correct_label = 3;
double squared_error = RBF_output[correct_label];
// compute penalty
int j = 1;
double value = m_exp(j);
for (int i = 0; i < 10; i++)
{
if (i != correct_label)
{
value += m_exp(-RBF_output[i]);
}
}
double penalty = m_log(value);
double loss = squared_error + penalty;
printf("loss is %lf\n", loss);
}
其中 m_exp()
是 exp 函数的C实现,见 scratch lenet(8): C语言实现 exp(x) 的计算5:
double m_exp(double x)
{
double d;
*((int*)&d + 0) = 0;
*((int*)&d + 1) = (int)(1512775 * x + 1072632447);
return d;
}
以及, m_log
是自然对数 ln 的C实现,见 scratch lenet(10): C语言计算log6
// logarithm for natural number `e`
// Mostly taken from https://gist.github.com/LingDong-/7e4c4cae5cbbc44400a05fba65f06f23
float m_log(float x)
{
// x = m * 2^p, m \in [1, 2]
// ln(x) = ln(m * 2^p) = ln(m) + ln(2) * p
// determine p
unsigned int bx = *(unsigned int *) (&x);
unsigned int ex = bx >> 23;
signed int p = (signed int)ex - (signed int)127;
// determine m
// exp: 00000000
// frac: 0b11111111111111111111111 = 838607
unsigned int bm = (127 << 23) | (bx & 8388607);
float m = *(float *) (&bm);
// printf("m = %f\n", m);
// determine ln(m) by Remez algorithm for m in [1, 2]
float ln_m_approx_4th_order = -1.7417939 + (2.8212026 + (-1.4699568 + (0.44717955 - 0.056570851 * m) * m) * m) * m;
//float ln_m_approx_3rd_order = -1.49278 + (2.11263 +(-0.729104 + 0.10969 * m) * m) * m;
float ln_m_approx = ln_m_approx_4th_order;
// combine the result
const float ln2 = 0.6931471806;
float res = ln_m_approx + ln2 * p;
return res;
}