Lumos框架实现XOR问题建模

本文介绍了如何利用Lumos深度学习框架构建神经网络模型来解决经典的非线性问题——XOR问题。模型包含多层全连接层,激活函数为ReLU,损失函数采用均方误差。文章详细展示了数据预处理、模型构建、训练和测试的过程,最后验证了模型能成功拟合XOR函数。
摘要由CSDN通过智能技术生成

Lumos框架实现XOR问题建模

通过前馈神经网络,构建XOR函数模型。使用[Lumos](LumosNet (github.com))深度学习框架,实现神经网络构建,训练和测试

异或(XOR)

异或函数XOR,是两个二进制数a,b的运算,当且仅当其中一个值为1时,XOR结果为1,其余结果为0

异或

标签数据 [a, b]
1[1, 0] [0, 1]
0[1, 1] [0, 0]

异或问题是典型的非线性问题,于逻辑与,逻辑或相比较

逻辑与

标签数据 [a, b]
1[1, 1]
0[1, 0] [0, 1] [0, 0]

逻辑或

标签数据 [a, b]
1[1, 0] [0, 1] [1, 1]
0[0, 0]

异或,逻辑与,逻辑或的散点图如下

在这里插入图片描述

可以看出,逻辑与和逻辑或的数据分布可以用一个线性函数进行分割,而异或无法用单一线性函数进行划分,所以XOR函数是一个典型的非线性函数

模型构建

数据集

数据集下载XOR

将异或的四个数据作为训练和测试数据,我们希望构建的模型能够完全拟合

数据标签
[0, 0]0
[0, 1]1
[1, 0]1
[1, 1]0

网络结构

Connect         Layer    :    [output=   4, bias=1, active=relu]
Connect         Layer    :    [output=   2, bias=1, active=relu]
Connect         Layer    :    [output=   1, bias=1, active=relu]
Mse             Layer    :    [output=   1]

一共三层全连接层,神经元个数分别为4,2,1

全连接层带有偏置项,采用relu激活函数

损失函数采用Mse(均方差)
M S E = 1 n S S E = 1 n ∑ i = 1 n ( y i ^ − y i ) 2 y i ^ 预测结果, y i 真实标签 MSE=\frac{1}{n}SSE=\frac{1}{n} \sum_{i=1}^{n}(\hat{y_{i}}-y_i)^{2} \\ \hat{y_{i}}预测结果,y_i 真实标签 MSE=n1SSE=n1i=1n(yi^yi)2yi^预测结果,yi真实标签

代码构建

标签处理函数

[lumos](LumosNet (github.com))框架支持自定义标签预处理,使用独热编码处理标签

void xor_label2truth(char **label, float *truth)
{
    int x = atoi(label[0]);
    one_hot_encoding(1, x, truth);
}

网络构建

首先创建graph,并将所有layer添加至graph中

lumos接受数据必须是图片形式,所以添加im2col层将图像数据转化为一维向量

Graph *graph = create_graph("Lumos", 5);
Layer *l1 = make_im2col_layer(1);
Layer *l2 = make_connect_layer(4, 1, "relu");
Layer *l3 = make_connect_layer(2, 1, "relu");
Layer *l4 = make_connect_layer(1, 1, "relu");
Layer *l5 = make_mse_layer(1);
append_layer2grpah(graph, l1);
append_layer2grpah(graph, l2);
append_layer2grpah(graph, l3);
append_layer2grpah(graph, l4);
append_layer2grpah(graph, l5);

权重初始化器

我们采用KaimingHe初始化

Initializer init = he_initializer();

创建会话

创建会话,并绑定网络模型

Session *sess = create_session("cpu", init);
bind_graph(sess, graph);

创建训练场景

指定训练数据,训练batch设置为4,训练500轮,学习率 0.01

create_train_scene(sess, 1, 2, 1, 1, 1, xor_label2truth, "./xor/data.txt", "./xor/label.txt");
init_train_scene(sess, 500, 4, 2, NULL);
session_train(sess, 0.01, "./xorw.w");

创建测试会话和场景

Session *t_sess = create_session("cpu", init);
bind_graph(t_sess, graph);
create_test_scene(t_sess, 1, 2, 1, 1, 1, xor_label2truth, "./xor/test.txt", 		                       "./xor/label.txt");
init_test_scene(t_sess, "./xorw.w");
session_test(t_sess, xor_process_test_information);

测试结果展示

lumos框架支持自定义结果展示,打印测试结果和真实标签数据,以及Loss值

void xor_process_test_information(char **label, float *truth, float *predict, float loss, char *data_path)
{
    fprintf(stderr, "Test Data Path: %s\n", data_path);
    fprintf(stderr, "Label:   %s\n", label[0]);
    fprintf(stderr, "Truth:   %f\n", truth[0]);
    fprintf(stderr, "Predict: %f\n", predict[0]);
    fprintf(stderr, "Loss:    %f\n\n", loss);
}

完整代码

#include <stdio.h>
#include <stdlib.h>

#include "lumos.h"


void xor_label2truth(char **label, float *truth)
{
    int x = atoi(label[0]);
    one_hot_encoding(1, x, truth);
}

void xor_process_test_information(char **label, float *truth, float *predict, float loss, char *data_path)
{
    fprintf(stderr, "Test Data Path: %s\n", data_path);
    fprintf(stderr, "Label:   %s\n", label[0]);
    fprintf(stderr, "Truth:   %f\n", truth[0]);
    fprintf(stderr, "Predict: %f\n", predict[0]);
    fprintf(stderr, "Loss:    %f\n\n", loss);
}

void xor () {
    Graph *graph = create_graph("Lumos", 5);
    Layer *l1 = make_im2col_layer(1);
    Layer *l2 = make_connect_layer(4, 1, "relu");
    Layer *l3 = make_connect_layer(2, 1, "relu");
    Layer *l4 = make_connect_layer(1, 1, "relu");
    Layer *l5 = make_mse_layer(1);
    append_layer2grpah(graph, l1);
    append_layer2grpah(graph, l2);
    append_layer2grpah(graph, l3);
    append_layer2grpah(graph, l4);
    append_layer2grpah(graph, l5);

    Initializer init = he_initializer();
    Session *sess = create_session("cpu", init);
    bind_graph(sess, graph);
    create_train_scene(sess, 1, 2, 1, 1, 1, xor_label2truth, "./xor/data.txt", "./xor/label.txt");
    init_train_scene(sess, 500, 4, 2, NULL);
    session_train(sess, 0.01, "./xorw.w");

    Session *t_sess = create_session("cpu", init);
    bind_graph(t_sess, graph);
    create_test_scene(t_sess, 1, 2, 1, 1, 1, xor_label2truth, "./xor/test.txt", "./xor/label.txt");
    init_test_scene(t_sess, "./xorw.w");
    session_test(t_sess, xor_process_test_information);
}

int main(){
    xor();
    return 0;
}

训练及结果

使用如下命令编译代码

gcc -fopenmp xor.c -I/usr/local/lumos/include/ -o main -L/usr/local/lumos/lib -llumos

编译完成后运行

可以看到,打印出的网络结构

[Lumos]         max   5  Layers
Im2col          Layer    :    [flag=1]
Connect         Layer    :    [output=   4, bias=1, active=relu]
Connect         Layer    :    [output=   2, bias=1, active=relu]
Connect         Layer    :    [output=   1, bias=1, active=relu]
Mse             Layer    :    [output=   1]

[Lumos]                     Inputs         Outputs
Im2col          Layer      2*  1*  1 ==>   1*  2*  1
Connect         Layer      1*  2*  1 ==>   1*  4*  1
Connect         Layer      1*  4*  1 ==>   1*  2*  1
Connect         Layer      1*  2*  1 ==>   1*  1*  1
Mse             Layer      1*  1*  1 ==>   1*  1*  1

最终得到如下结果

Session Start To Detect Test Cases
Test Data Path: ./xor/data/00.png
Label:   0
Truth:   0.000000
Predict: 0.129547
Loss:    0.016782

Test Data Path: ./xor/data/01.png
Label:   1
Truth:   1.000000
Predict: 0.904523
Loss:    0.009116

Test Data Path: ./xor/data/11.png
Label:   0
Truth:   0.000000
Predict: 0.104283
Loss:    0.010875

Test Data Path: ./xor/data/10.png
Label:   1
Truth:   1.000000
Predict: 0.942409
Loss:    0.003317
数据测试结果真实标签Loss
[0, 0]0.1295470.00.016782
[0, 1]0.9045231.00.009116
[1, 1]0.1042830.00.010875
[1, 0]0.9424091.00.003317

完全符合预期,与XOR函数完全拟合

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

猫猫虫(——)

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值