C++ 实现全连接神经网络算法识别 Mnist 手写数字
完整项目代码可从 https://github.com/hfq0219/mnist
下载。
/**
*@Author: fengqi
*@Email: 2607546441@qq.com
*/
本程序使用全连接神经网络进行手写数字识别的训练和预测。当然修改一下输入和输出节点数,调整网络层数,也可用于其他多分类或回归问题。
代码结构参考了 yolo(You Only Look Once) 项目源码框架 darknet.
目录文件介绍:
--mnist/ 存放的是 mnist 数据集原始二进制文件;
--obj/ 存放的是编译生成的 .obj 文件;
--backup/ 存放的是每轮训练过程中生成的权重文件
--testData/ 是运行 ./data 程序读取 mnist 测试数据集生成的测试图片(如 0.jpg, 1.jpg...)和对应的标签文件 testLabel.txt;
--trainData/ 是运行 ./data 程序读取 mnist 训练数据集生成的训练图片(如 0.jpg, 1.jpg...)和对应的标签文件 trainLabel.txt;
--layer.cpp/.h 是全连接层的类定义及实现,主要是分配层的计算数据存储空间和前向计算反向传播以及参数更新函数定义;
--network.cpp/.h 是网络的类定义及实现,主要是定义了网络中,全连接层的添加,网络的前向传播,反向传播等函数;
--mnist.cpp 是读取 mnist/ 下的二进制文件,生成相应的图片,便于可视化和图片读写;
--main.cpp 是主函数入口文件,里面实现了网络训练及验证,以及预测功能,并且实现了在训练网络完成后保存网络各层的权重到文件里,
方便下次训练或预测时随时载入权重,不用重新训练网络;
使用方法介绍:
--本程序使用 Makefile 进行项目管理构建,只需在终端输入 make 命令,即可生成 data 和 run 两个可执行文件;
--注意,本程序生成和读写图片使用了 opencv,所以请确保电脑上安装并配置好了 opencv 开发环境。
1、运行 ./data 可读取 mnist/ 数据文件,生成训练和测试用的图片;
2、运行 ./run train 可训练网络,网络训练完成后,会生成 mnist.weight 网络权重文件;
3、运行 ./run test mnist.weight 可进行图片识别预测,只需输入图片文件名,按 ctrl-c 停止即可;
主要参数介绍:
--程序里的主要需要修改的参数有,训练迭代次数 epoches 和学习率 learning rate, 在构建 network 对象时传入;
--全连接层的个数及各层神经元数量和激活函数类型,可通过 network->addLayer(int node,ACTIVATION activate) 调整;
--学习率变化调整,默认调整方式是每次迭代减小 0.01,当小于 0.01 时,固定使用 0.01 作为学习率;
主入口函数:
/**
* @Author: fengqi
* @Email: 2607546441@qq.com
*/
#include <iostream>
#include <time.h>
#include <string>
#include <opencv2/opencv.hpp>
#include "network.h"
using namespace std;
using namespace cv;
void saveWeight(string file,Network *network){
//保存各层权重到文件
ofstream outfile(file);
for(int i=0;i<network->mNumLayers;i++){
Layer *layer=network->mLayers[i];
for(int m=0;m<layer->mNumNodes;m++){
for(int n=0;n<layer->mNumInputNodes+1;n++){
outfile<<layer->mWeights[m][n]<<" ";
}
}
}
outfile.close();
cout<<"save weight file to <"<<file<<"> done."<<endl;
}
void loadWeight(string file,Network *network){
//加载权重文件
ifstream infile(file);
if(!infile.is_open()){
cout<<"open weight file failed!"<<endl;
exit(-1);
}
for(int i=0;i<network->mNumLayers;i++){
Layer *layer=network->mLayers[i];
for(int m=0;m<layer->mNumNodes;m++){
for(int n=0;n<layer->mNumInputNodes+1;n++){
infile>>layer->mWeights[m][n];
}
}
}
infile.close();
cout<<"load weight from <"<<file<<"> done."<<endl;
}
float train(Network *network,string path,int imageSize, int numImages) //训练网络,使用训练数据集
{
srand(time(0));
float *temp = new float[imageSize];
string la=path;
ifstream labelFile(path.append("trainLabel.txt")); //标签文件
int label;
for (int i = 0; i < numImages; i++)
{
if(i%(numImages/10)==0){
//每 6000 张图片统计错误率,并显示训练进度
network->mErrorSum=0;
cout << setfill('=') << setw(2) << ">"<<(i/(numImages/10))*10<<"%"<<flush;
}
if(i==numImages-1)
cout<<"====>100%"<< endl;
int k=rand()%numImages; //随机选取图片训练
string l=la;
Mat x=imread(l.append(to_string(k)).append(".jpg"),0); //使用 opencv 读取图片
if(!x.data){
cout<<"read image error."<<endl;return -1;}
for(int m=0;m<x.rows;m++){
for(int n=0;n<x.cols;n++){
float a=(x.at<uchar>(m,n))/255.0; //归一化
temp[m*x.cols+n]=a;
}
}
labelFile.seekg(2*k); //标签和图片对应
labelFile>>label;
network->compute(temp,label); //每次训练一张图片
}
cout << "the error is:" << network->mErrorSum/(numImages/10);
labelFile.close();
delete [] temp;
return network->mErrorSum;
}
int validate(Network *network,string path,int imageSize, int numImages) //验证网络准确率,使用测试数据集
{
int ok_cnt = 0;
float* temp = new float[imageSize];
string la=path;
ifstream labelFile(path.append("testLabel.txt")); //标签文件
int label,idx=0;
for (int i = 0; i < numImages; i++)
{
if(i%(numImages/10)==0) //显示进度
cout << setfill('=') << setw(2) << ">"<<(i/(numImages/10))*10<<"%"<<flush;
if(i==numImages-1)
cout<<"====>100%"<< endl;
string l=la;
Mat x=imread(l.append