scratch lenet(12): LeNet-5输出层和损失函数的计算

本文详细介绍了LeNet-5网络中输出层的结构,特别是GaussianConnection的原理和权重可视化,以及损失函数的计算,包括基本项和惩罚项。此外,还提供了C语言实现的输出层前向计算和损失函数计算的代码示例。
摘要由CSDN通过智能技术生成

1. 目的

弄清 LeNet-5 的论文原文 LeCun-98.pdf1 中的输出层的结构,以及损失函数的计算公式。

的确网络上已经有很多 LeNet-5 的现代实现了, 但过于现代导致和原版相差太多。本篇 blog 尽可能贴合论文原文, 不希望人云亦云照搬已有结果, 根据论文进行复现追求的是“一致性”。假设没有互联网上已经存在的各种 LeNet-5 的实现, 也应该能按论文原文做出实现。

2. 输出层结构

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-93L5Wh2q-1687600825416)(LeNet-5_gaussian_connection_rbf.png)]

2.1 Gaussian Connection

博客2 给出了 Gaussian Connection 的白话解释:高斯连接的输出就是输出结果,不需要套用激活函数; 高斯连接需要和欧几里得径向基函数(Euclidean Radial Basic Functions)配合使用。

Finally, the output layer is composed of Euclidean Radial Basis Function units (RBF), one for each class, with 84 inputs each. The output each RBF unit y i y_i yi is computed as follows:
y i = ∑ j ( x j − w i j ) 2 y_i = \sum_j(x_j - w_{ij})^2 yi=j(xjwij)2

2.2 Gaussian Connection 的 weight 可视化

按论文原文, Gaussian connection 的 weight 只有两种取值: 1 和 -1。不妨做一个映射,把1映射到255,把-1映射到0,就得到一幅图像。没错, Gaussian connection 的 weight 只有10个,每个代表一个 ascii 打印形式的数字(0-9),大小是 7x12 = 84 像素。

首先是制作每个数字的 ascii 表示, 然后保存为 .pgm 格式的图像.

可视化:
在这里插入图片描述

存储 pgm 图像:

void write_pgm_image(uchar* image, int width, int height, const char* filename)
{
    FILE* fout = fopen(filename, "wb");
    fprintf(fout, "P5\n%d %d\n255\n", width, height);
    fwrite(image, width * height, 1, fout);
    fclose(fout);
}

完整代码:

uchar digits[10][84];

void make_7x12_digits()
{
    uchar d[10][84] =
    {
        {
            0, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 0, 0, 0, 0,
            0, 0, 1, 1, 1, 0, 0,
            0, 1, 1, 0, 1, 1, 0,
            1, 1, 0, 0, 0, 1, 1,
            1, 1, 0, 0, 0, 1, 1,
            1, 1, 0, 0, 0, 1, 1,
            1, 1, 0, 0, 0, 1, 1,
            0, 1, 1, 0, 1, 1, 0,
            0, 0, 1, 1, 1, 0, 0,
            0, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 0, 0, 0, 0,
        },
        
        {
            0, 0, 0, 1, 1, 0, 0,
            0, 0, 1, 1, 1, 0, 0,
            0, 1, 1, 1, 1, 0, 0,
            0, 0, 0, 1, 1, 0, 0,
            0, 0, 0, 1, 1, 0, 0,
            0, 0, 0, 1, 1, 0, 0,
            0, 0, 0, 1, 1, 0, 0,
            0, 0, 0, 1, 1, 0, 0,
            0, 0, 0, 1, 1, 0, 0,
            0, 1, 1, 1, 1, 1, 1,
            0, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 0, 0, 0, 0,
        },
        
        {
            0, 0, 0, 0, 0, 0, 0,
            0, 1, 1, 1, 1, 0, 0,
            0, 1, 1, 0, 1, 1, 1,
            1, 0, 0, 0, 1, 1, 1,
            0, 0, 0, 1, 1, 1, 0,
            0, 0, 1, 1, 1, 0, 0,
            0, 1, 1, 0, 0, 0, 0,
            0, 1, 1, 0, 0, 0, 0,
            1, 1, 0, 0, 0, 0, 0,
            1, 1, 1, 1, 1, 1, 1,
            0, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 0, 0, 0, 0,
        },

        {
            0, 1, 1, 1, 1, 1, 1,
            0, 0, 0, 0, 0, 1, 1,
            0, 0, 0, 0, 1, 1, 0,
            0, 0, 0, 1, 1, 0, 0,
            0, 0, 0, 1, 0, 0, 0,
            0, 0, 1, 1, 1, 1, 0,
            0, 0, 0, 0, 0, 1, 1,
            0, 0, 0, 0, 0, 1, 1,
            0, 0, 0, 0, 0, 1, 1,
            1, 1, 0, 0, 0, 1, 1,
            0, 1, 1, 1, 1, 1, 0,
            0, 0, 0, 0, 0, 0, 0,
        },
        
        {
            0, 0, 0, 0, 0, 0, 0,
            0, 1, 1, 0, 1, 1, 0,
            0, 1, 1, 0, 1, 1, 0,
            0, 1, 1, 0, 1, 1, 0,
            0, 1, 1, 0, 1, 1, 0,
            1, 1, 0, 0, 1, 1, 0,
            1, 1, 0, 0, 1, 1, 1,
            1, 1, 1, 1, 1, 1, 1,
            0, 0, 0, 0, 1, 1, 0,
            0, 0, 0, 0, 1, 1, 0,
            0, 0, 0, 0, 1, 1, 0,
            0, 0, 0, 0, 1, 1, 0,
        },

        {
            0, 0, 0, 0, 0, 0, 0,
            1, 1, 1, 1, 1, 1, 1,
            1, 1, 0, 0, 0, 0, 0,
            1, 1, 0, 0, 0, 0, 0,
            1, 1, 0, 0, 0, 0, 0,
            1, 1, 1, 1, 1, 0, 0,
            0, 0, 1, 1, 1, 1, 1,
            0, 0, 0, 0, 0, 1, 1,
            0, 0, 0, 0, 0, 1, 1,
            1, 1, 1, 0, 0, 1, 1,
            0, 1, 1, 1, 1, 1, 0,
            0, 0, 0, 0, 0, 0, 0,
        },
     
        {
            0, 0, 0, 1, 1, 1, 0,
            0, 1, 1, 1, 0, 0, 0,
            0, 1, 0, 0, 0, 0, 0,
            1, 1, 0, 0, 0, 0, 0,
            1, 1, 0, 0, 0, 0, 0,
            1, 1, 1, 1, 1, 1, 0,
            1, 1, 0, 0, 0, 1, 1,
            1, 1, 0, 0, 0, 0, 1,
            1, 1, 0, 0, 0, 0, 1,
            0, 1, 0, 0, 0, 1, 0,
            0, 1, 1, 1, 1, 1, 0,
            0, 0, 0, 0, 0, 0, 0,
        },
         
        {
            1, 1, 1, 1, 1, 1, 1,
            0, 0, 0, 0, 0, 1, 1,
            0, 0, 0, 0, 0, 1, 1,
            0, 0, 0, 0, 1, 1, 0,
            0, 0, 0, 1, 1, 0, 0,
            0, 0, 0, 1, 1, 0, 0,
            0, 0, 1, 1, 0, 0, 0,
            0, 1, 1, 0, 0, 0, 0,
            0, 1, 1, 0, 0, 0, 0,
            0, 1, 1, 0, 0, 0, 0,
            0, 1, 1, 0, 0, 0, 0,
            0, 0, 0, 0, 0, 0, 0,
        },
         
        {
            0, 1, 1, 1, 1, 1, 0,
            1, 1, 0, 0, 0, 1, 1,
            1, 1, 0, 0, 0, 1, 1,
            1, 1, 0, 0, 0, 1, 1,
            0, 1, 1, 1, 1, 1, 0,
            1, 1, 0, 0, 0, 1, 1,
            1, 1, 0, 0, 0, 0, 1,
            1, 1, 0, 0, 0, 0, 1,
            1, 1, 0, 0, 0, 0, 1,
            0, 1, 1, 1, 1, 1, 0,
            0, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 0, 0, 0, 0,
        },

        {
            0, 1, 1, 1, 1, 1, 0,
            1, 1, 0, 0, 0, 1, 1,
            1, 1, 0, 0, 0, 1, 1,
            1, 1, 0, 0, 0, 1, 1,
            1, 1, 0, 0, 1, 1, 1,
            0, 1, 1, 1, 1, 1, 1,
            0, 0, 0, 0, 0, 1, 1,
            0, 0, 0, 0, 0, 1, 0,
            0, 0, 0, 0, 1, 1, 0,
            0, 1, 1, 1, 1, 0, 0,
            0, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 0, 0, 0, 0,
        }
    };
    memcpy(digits, d, sizeof(digits));

    for (int k = 0; k < 10; k++)
    {
        char savename[50] = { 0 };
        sprintf(savename, "RBF%d.pgm", k);

        int width = 7;
        int height = 12;
        for (int i = 0; i < height; i++)
        {
            for (int j = 0; j < width; j++)
            {
                int idx = i * width + j;
                d[k][idx] = (d[k][idx] == 1) ? 255 : 0;
            }
        }

        write_pgm_image(d[k], width, height, savename);
    }
}

3. Loss Function

关于 Loss function, 先前卡在了不理解 Loss 函数中用的符号。后来在谷歌中搜了一下这篇论文的相关解读博客3, 4,发现其实 Introduction 部分就给出了符号的说明。

3.1 当前类别判断错误时,loss function 中的项(基本项)

Z p Z^p Zp: 第 p p p 个输入

Y p Y^p Yp: Z p Z^p Zp 的被识别的类别(推理结果)

W W W: 系统中可调整的参数的集合。

D p D^p Dp: Z p Z^p Zp 的正确类别(ground truth)

推理得到结果:
Y p = F ( Z p , W ) Y^p = F(Z^p, W) Yp=F(Zp,W)

Loss Function:
E p = D ( D p , F ( W , Z p ) E^p = D(D^p, F(W, Z^p) Ep=D(Dp,F(W,Zp)

E ( W ) = 1 P ∑ p = 1 P y D p ( Z p , W ) E(W) = \frac{1}{P}\sum_{p=1}^{P}y_{D^p}(Z^p, W) E(W)=P1p=1PyDp(Zp,W)

其中 y D p y_{D^p} yDpj 表示输入样本 Z p Z^p Zp 在对应的真实类别 D p D^p Dp 上的得分. 比如,对于 3.pgm 来说, D p = 3 D^p=3 Dp=3, output 层算出来的 RBF 有10个结果: y 0 , . . . , y 9 y_0, ..., y_9 y0,...,y9, 而 y 3 y_3 y3 就是 y D p y_{D^p} yDp

实际训练时, 样本数量 P > 1 P > 1 P>1, 我们收集每个样本的损失 y D p y_{D^p} yDp, 累加起来, 就得到 E ( W ) E(W) E(W) 了。

3.2 判断为其他类别时, loss function 中的项(惩罚项)

尽管上述 cost function 适合大多数情况, 但它缺乏3个重要的性质。

第一, 如果我们允许 RBF 层的参数改变(本来是设定为1和-1),那么 E ( W ) E(W) E(W) 会得到一个平凡的、同时完全不可接受的解。(啥意思??)在这个解中,所有 RBF 参数向量都相等, 并且F6的状态是常量,并且和参数向量相等。这种情况下, 网络开心地忽略输入, 并且所有 RBF 输入等于0。 如果 RBF 的权值不允许改变, 这种“坍缩现象”就不会出现。

第二个问题是,类别之间没有竞争。

为了解决这两个问题, 作者提出如下修正后的损失函数:
E ( W ) = 1 P ∑ p = 1 P ( y D p ( Z p , W ) + log ⁡ ( e − j + ∑ i e − y i ( Z p , W ) ) ) E(W) = \frac{1}{P}\sum_{p=1}^{P}(y_{D^p}(Z^p, W) + \log(e^{-j} + \sum_i e^{-y_i(Z^p, W)})) E(W)=P1p=1P(yDp(Zp,W)+log(ej+ieyi(Zp,W)))

其中 j j j 是一个大于0的常数(但作者没有给出具体取值)。

其中 i i i 的取值,原文没说,可以理解为 D p D^p Dp 之外的类别,因为作为惩罚项它的目的是当其他类别的 RBF 输出值较小(接近0)时,惩罚项变大:

E ( W ) = 1 P ∑ p = 1 P ( y D p ( Z p , W ) + log ⁡ ( e − j + ∑ i ! = D p e − y i ( Z p , W ) ) ) E(W) = \frac{1}{P}\sum_{p=1}^{P}(y_{D^p}(Z^p, W) + \log(e^{-j} + \sum_{i!=D^p} e^{-y_i(Z^p, W)})) E(W)=P1p=1P(yDp(Zp,W)+log(ej+i!=Dpeyi(Zp,W)))

例如对于 3.pgm 输入图, D p = 3 D^p=3 Dp=3, i ∈ { 0 , 1 , 2 , 4 , 5 , 6 , 7 , 8 , 9 } i \in \{ 0, 1, 2, 4, 5, 6, 7, 8, 9 \} i{0,1,2,4,5,6,7,8,9}.

4. 代码实现

4.1 输出层的前向计算


void forward_output()
{
    double* x = F6_output;
    double* y = RBF_output;
    double w[10][84];

    for (int i = 0; i < 10; i++)
    {
        for (int j = 0; j < 84; j++)
        {
            w[i][j] = (digits[i][j] == 255) ? 1 : -1;
        }
    }

    // Euclidean Radial Basis Functions (RBF)
    printf("Output layer: \n");
    for (int i = 0; i < 10; i++)
    {
        double sum = 0;
        for (int j = 0; j < 84; j++)
        {
            double diff = (x[j] - w[i][j]);
            sum += diff * diff;
        }
        y[i] = sum;
        printf("  y[%d] = %lf\n", i, y[i]);
    }
}

4.2 Loss Function 的实现

简单起见, 先只考虑单个样本对应的损失函数。 也就是, P = 1 P=1 P=1 时,

E ( W ) = y D p ( Z p , W ) + log ⁡ ( e − j + ∑ i ! = D p e − y i ( Z p , W ) ) E(W) = y_{D^p}(Z^p, W) + \log(e^{-j} + \sum_{i!=D^p} e^{-y_i(Z^p, W)}) E(W)=yDp(Zp,W)+log(ej+i!=Dpeyi(Zp,W))

C 语言实现如下:

// compute loss for single input pattern D^p
void compute_loss()
{
    int correct_label = 3;
    double squared_error = RBF_output[correct_label];

    // compute penalty
    int j = 1;
    double value = m_exp(j);
    for (int i = 0; i < 10; i++)
    {
        if (i != correct_label)
        {
            value += m_exp(-RBF_output[i]);
        }
    }
    double penalty = m_log(value);
    double loss = squared_error + penalty;

    printf("loss is %lf\n", loss);
}

其中 m_exp() 是 exp 函数的C实现,见 scratch lenet(8): C语言实现 exp(x) 的计算5:

double m_exp(double x)
{
    double d;
    *((int*)&d + 0) = 0;
    *((int*)&d + 1) = (int)(1512775 * x + 1072632447);
    return d;
}

以及, m_log 是自然对数 ln 的C实现,见 scratch lenet(10): C语言计算log6

// logarithm for natural number `e`
// Mostly taken from https://gist.github.com/LingDong-/7e4c4cae5cbbc44400a05fba65f06f23
float m_log(float x)
{
    // x = m * 2^p, m \in [1, 2]
    // ln(x) = ln(m * 2^p) = ln(m) + ln(2) * p
    
    // determine p
    unsigned int bx = *(unsigned int *) (&x);
    unsigned int ex = bx >> 23;
    signed int p = (signed int)ex - (signed int)127;

    // determine m
    // exp:  00000000
    // frac: 0b11111111111111111111111 = 838607
    unsigned int bm = (127 << 23) | (bx & 8388607);
    float m = *(float *) (&bm);
    // printf("m = %f\n", m); 

    // determine ln(m) by Remez algorithm for m in [1, 2]
    float ln_m_approx_4th_order = -1.7417939 + (2.8212026 + (-1.4699568 + (0.44717955 - 0.056570851 * m) * m) * m) * m;
    //float ln_m_approx_3rd_order = -1.49278 + (2.11263 +(-0.729104 + 0.10969 * m) * m) * m;
    float ln_m_approx = ln_m_approx_4th_order;

    // combine the result
    const float ln2 = 0.6931471806;
    float res = ln_m_approx + ln2 * p;
    return res;
}

References


  1. lecun-98.pdf ↩︎

  2. ADS DataScience - 高斯连接(Gaussian Connection):LetNet-5论文里使用的高斯连接到底是个啥? ↩︎

  3. 论文笔记:Gradient-Based Learning Applied to Document Recognition ↩︎

  4. 論文筆記 - LeCun 1998 - Gradient-Based Learning Applied to Document Recognition ↩︎

  5. scratch lenet(8): C语言实现 exp(x) 的计算 ↩︎

  6. scratch lenet(10): C语言计算log ↩︎

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值