【C#】数字图像识别

这是一个用于训练BP神经网络完成数字识别任务的小程序,较为简单,可以在此基础上进行修改,完成常见的分类任务。

这个东西该怎么使用?

首先需要明确:这是一个训练神经网络的工具。它当然也可以用来做识别,但当时的目的只是在PC上训练一个BP神经网络,然后移植到单片机上,因此求出BP网络的权重文件才是该软件的目的所在,GUI上也就没有添加单独的识别操作按钮。如果需要,可以参考test方法自己实现一下。

整个程序可以在这里下载到:C#实现BP神经网络数字图像识别源码,也可以在github上下载:BPNetwork

解压后会得到下面的目录,下面分别介绍一下:

  1. 测试样本:测试集,20*20的灰度图像,包含0~9;
  2. 训练样本:训练集,20*20的灰度图像,包含0~9;
  3. 训练成功后的矩阵:最终我个人训练的BP网络的权重矩阵;
  4. BPNetwork.exe:可执行程序;
  5. BPNetwork.rar:整个项目源码

具体原理:

  1. 将训练集和测试集中每一个20*20的图像拉伸为一个长度为400的数组;
  2. 利用train方法训练得到权重文件;
  3. 利用test方法测试所得到的BP神经网络的识别准确率。
using System;
using System.IO;
using System.Text;

namespace BPNetwork
{
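    /// <summary>
    /// Three-layer (input / hidden / output) back-propagation network with sigmoid units.
    /// <c>train</c> runs one pass of per-sample weight updates; <c>test</c> classifies a
    /// single sample using the weight files saved on disk.
    /// </summary>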
    public class BpNet
    {
        /// <summary>
        /// Number of input nodes.
        /// </summary>
        public int inNum;
        /// <summary>
        /// Number of hidden-layer nodes; assigned in the constructor from computeHideNum.
        /// </summary>
        int hideNum;
        /// <summary>
        /// Number of output-layer nodes.
        /// </summary>
        public int outNum;
        /// <summary>
        /// Number of samples; assigned by train from the first dimension of its input array.
        /// </summary>
        public int sampleNum;

        // Random source used by the constructor to initialise w and v.
        Random R;
        /// <summary>
        /// Input-layer values: the current sample after division by in_rate.
        /// </summary>
        double[] x;
        /// <summary>
        /// Hidden-layer outputs (sigmoid of o1 plus b1).
        /// </summary>
        double[] x1;
        /// <summary>
        /// Output-layer outputs (sigmoid of o2 plus b2).
        /// </summary>
        double[] x2;
        /// <summary>
        /// Weighted input sum of each hidden node; the bias b1 is added later, inside the sigmoid.
        /// </summary>
        double[] o1;
        /// <summary>
        /// Weighted input sum of each output node; the bias b2 is added later, inside the sigmoid.
        /// </summary>
        double[] o2;
        /// <summary>
        /// Input-to-hidden weight matrix, dimensioned [inNum, hideNum].
        /// </summary>
        public double[,] w;
        /// <summary>
        /// Hidden-to-output weight matrix, dimensioned [hideNum, outNum].
        /// </summary>
        public double[,] v;
        /// <summary>
        /// Buffer with the same dimensions as w. It is allocated in the constructor but no
        /// method of this class writes to it (the adjustWV calls in train are commented out).
        /// </summary>
        public double[,] dw;
        /// <summary>
        /// Buffer with the same dimensions as v. It is allocated in the constructor but no
        /// method of this class writes to it (the adjustWV calls in train are commented out).
        /// </summary>
        public double[,] dv;

        /// <summary>
        /// Hidden-layer bias ("threshold") vector, one entry per hidden node.
        /// </summary>
        public double[] b1;
        /// <summary>
        /// Output-layer bias ("threshold") vector, one entry per output node.
        /// </summary>
        public double[] b2;
        /// <summary>
        /// Buffer with the same length as b1; allocated in the constructor, not written by this class.
        /// </summary>
        public double[] db1;
        /// <summary>
        /// Buffer with the same length as b2; allocated in the constructor, not written by this class.
        /// </summary>
        public double[] db2;

        /// <summary>
        /// Hidden-layer error terms (deltas) of the sample currently being trained.
        /// </summary>
        double[] pp;
        /// <summary>
        /// Output-layer error terms (deltas) of the sample currently being trained.
        /// </summary>
        double[] qq;
        /// <summary>
        /// Target ("teacher") output values of the current sample after division by in_rate.
        /// </summary>
        double[] yd;
        /// <summary>
        /// Error of the last train call: the square root of the squared output error summed
        /// over all samples and output nodes (not a mean).
        /// </summary>
        public double e;
        /// <summary>
        /// Normalisation divisor: the largest absolute value found in the data, set by train and test.
        /// </summary>
        double in_rate;

        /// <summary>
        /// Estimates the hidden-layer size from the input and output layer sizes with the
        /// empirical formula sqrt(0.43*m*n + 0.12*n*n + 2.54*m + 0.77*n + 0.35) + 0.51,
        /// rounded to the nearest integer.
        /// </summary>
        /// <param name="m">Number of input-layer nodes.</param>
        /// <param name="n">Number of output-layer nodes.</param>
        /// <returns>The rounded number of hidden-layer nodes.</returns>
        public int computeHideNum(int m, int n)
        {
            double s = Math.Sqrt(0.43 * m * n + 0.12 * n * n + 2.54 * m + 0.77 * n + 0.35) + 0.51;

            // Convert.ToInt32(double) already rounds to the nearest integer (ties go to the
            // even neighbour), so the difference between s and the converted value can never
            // exceed 0.5. The former "(s - ss) > 0.5 ? ss + 1 : ss" adjustment was therefore
            // unreachable and has been dropped; the returned value is unchanged.
            return Convert.ToInt32(s);
        }

        /// <summary>
        /// Creates a network for the given layer sizes: derives the hidden-layer size,
        /// allocates every working array and fills both weight matrices with random values.
        /// </summary>
        /// <param name="innum">Number of input nodes.</param>
        /// <param name="outnum">Number of output nodes.</param>
        public BpNet(int innum, int outnum)
        {
            R = new Random();

            // Layer sizes; the hidden size is derived from the other two.
            inNum = innum;
            outNum = outnum;
            hideNum = computeHideNum(innum, outnum);

            // Layer values and weighted sums.
            x = new double[inNum];
            o1 = new double[hideNum];
            x1 = new double[hideNum];
            o2 = new double[outNum];
            x2 = new double[outNum];

            // Weights, biases and their (currently unused) companion buffers.
            w = new double[inNum, hideNum];
            dw = new double[inNum, hideNum];
            v = new double[hideNum, outNum];
            dv = new double[hideNum, outNum];
            b1 = new double[hideNum];
            db1 = new double[hideNum];
            b2 = new double[outNum];
            db2 = new double[outNum];

            // Error terms and target values.
            pp = new double[hideNum];
            qq = new double[outNum];
            yd = new double[outNum];

            // w is filled before v, each row by row, so values are drawn from R in the
            // same sequence as before.
            randomizeWeights(w);
            randomizeWeights(v);

            e = 0.0;
            in_rate = 1.0;
        }

        /// <summary>
        /// Fills a weight matrix, row by row, with random values in the range [-0.5, 0.5).
        /// </summary>
        /// <param name="weights">Matrix to fill in place.</param>
        private void randomizeWeights(double[,] weights)
        {
            int rows = weights.GetLength(0);
            int cols = weights.GetLength(1);
            for (int r = 0; r < rows; r++)
            {
                for (int c = 0; c < cols; c++)
                {
                    weights[r, c] = (R.NextDouble() * 2 - 1.0) / 2;
                }
            }
        }

        /// <summary>
        /// Runs one pass over the whole training set, updating w, v, b1 and b2 in place
        /// after every single sample (per-sample back-propagation).
        /// </summary>
        /// <param name="p">Training inputs, one sample per row, with at least inNum columns.</param>
        /// <param name="t">Target outputs, one sample per row, with at least outNum columns and at least as many rows as <paramref name="p"/>.</param>
        /// <param name="rate">Learning rate applied to every weight and bias update.</param>
        /// <remarks>
        /// Inputs and targets are both divided by the largest absolute value found in either
        /// array. On return, e holds the square root of the squared output error summed over
        /// all samples and output nodes of this pass.
        /// </remarks>
        /// <exception cref="ArgumentNullException"><paramref name="p"/> or <paramref name="t"/> is null.</exception>
        /// <exception cref="ArgumentException">The arrays are too small for the network or for each other.</exception>
        public void train(double[,] p, double[,] t, double rate)
        {
            // Validate up front so a shape mismatch cannot abort the pass half-way through,
            // leaving the weights partially updated.
            if (p == null)
            {
                throw new ArgumentNullException("p");
            }
            if (t == null)
            {
                throw new ArgumentNullException("t");
            }
            if (p.GetLength(1) < inNum)
            {
                throw new ArgumentException("Each training sample needs at least inNum values.", "p");
            }
            if (t.GetLength(0) < p.GetLength(0) || t.GetLength(1) < outNum)
            {
                throw new ArgumentException("Targets need one row per sample and at least outNum values per row.", "t");
            }

            this.sampleNum = p.GetLength(0);
            e = 0.0;

            // Largest absolute value across inputs and targets.
            double pMax = 0.0;
            for (int isamp = 0; isamp < sampleNum; isamp++)
            {
                for (int i = 0; i < inNum; i++)
                {
                    if (Math.Abs(p[isamp, i]) > pMax)
                    {
                        pMax = Math.Abs(p[isamp, i]);
                    }
                }

                for (int j = 0; j < outNum; j++)
                {
                    if (Math.Abs(t[isamp, j]) > pMax)
                    {
                        pMax = Math.Abs(t[isamp, j]);
                    }
                }
            }

            // An all-zero data set would otherwise divide by zero and fill the network with
            // NaN; fall back to the neutral divisor used by the constructor.
            in_rate = pMax > 0.0 ? pMax : 1.0;

            for (int isamp = 0; isamp < sampleNum; isamp++)
            {
                // Normalise the current sample and its targets.
                for (int i = 0; i < inNum; i++)
                {
                    x[i] = p[isamp, i] / in_rate;
                }
                for (int i = 0; i < outNum; i++)
                {
                    yd[i] = t[isamp, i] / in_rate;
                }

                // Forward pass: hidden layer.
                for (int j = 0; j < hideNum; j++)
                {
                    o1[j] = 0.0;
                    for (int i = 0; i < inNum; i++)
                    {
                        o1[j] += w[i, j] * x[i];
                    }
                    x1[j] = 1.0 / (1.0 + Math.Exp(-o1[j] - b1[j]));
                }

                // Forward pass: output layer.
                for (int k = 0; k < outNum; k++)
                {
                    o2[k] = 0.0;
                    for (int j = 0; j < hideNum; j++)
                    {
                        o2[k] += v[j, k] * x1[j];
                    }
                    x2[k] = 1.0 / (1.0 + Math.Exp(-o2[k] - b2[k]));
                }

                // Output-layer error terms, accumulated squared error, and update of v.
                // NOTE(review): v is updated here, before the hidden-layer error below is
                // computed from it, so pp sees the already-updated v. Textbook
                // back-propagation uses the pre-update weights; the order is kept so that
                // training results stay numerically identical to earlier runs.
                for (int k = 0; k < outNum; k++)
                {
                    qq[k] = (yd[k] - x2[k]) * x2[k] * (1.0 - x2[k]);
                    e += (yd[k] - x2[k]) * (yd[k] - x2[k]);
                    for (int j = 0; j < hideNum; j++)
                    {
                        v[j, k] += rate * qq[k] * x1[j];
                    }
                }

                // Hidden-layer error terms and update of w.
                for (int j = 0; j < hideNum; j++)
                {
                    pp[j] = 0.0;
                    for (int k = 0; k < outNum; k++)
                    {
                        pp[j] += qq[k] * v[j, k];
                    }
                    pp[j] = pp[j] * x1[j] * (1 - x1[j]);

                    for (int i = 0; i < inNum; i++)
                    {
                        w[i, j] += rate * pp[j] * x[i];
                    }
                }

                // Bias updates.
                for (int k = 0; k < outNum; k++)
                {
                    b2[k] += rate * qq[k];
                }
                for (int j = 0; j < hideNum; j++)
                {
                    b1[j] += rate * pp[j];
                }
            }

            e = Math.Sqrt(e);
        }
        
        /// <summary>
        /// Classifies a single sample using the weights stored in w.txt, v.txt, b1.txt and
        /// b2.txt in the current directory.
        /// </summary>
        /// <param name="p">Sample to classify; at least inNum values.</param>
        /// <returns>Index of the output node with the largest activation.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="p"/> is null.</exception>
        /// <exception cref="ArgumentException"><paramref name="p"/> has fewer than inNum values.</exception>
        public int test(double[] p)
        {
            if (p == null)
            {
                throw new ArgumentNullException("p");
            }
            if (p.Length < inNum)
            {
                throw new ArgumentException("The sample needs at least inNum values.", "p");
            }

            // These locals deliberately shadow the fields of the same name: classification
            // always uses the weights saved on disk, not the ones held in memory.
            // NOTE(review): the four files are re-read on every call, which is slow when
            // many samples are tested in a row.
            double[,] w = new double[inNum, hideNum];
            double[,] v = new double[hideNum, outNum];
            double[] b1 = new double[hideNum];
            double[] b2 = new double[outNum];
            readMatrixW(w, "w.txt");
            readMatrixW(v, "v.txt");
            readMatrixB(b1, "b1.txt");
            readMatrixB(b2, "b2.txt");

            // Normalise by the largest absolute input value.
            double pMax = 0.0;
            for (int i = 0; i < inNum; i++)
            {
                if (Math.Abs(p[i]) > pMax)
                {
                    pMax = Math.Abs(p[i]);
                }
            }
            // An all-zero sample would otherwise divide by zero, turn every activation into
            // NaN and always be reported as class 0.
            in_rate = pMax > 0.0 ? pMax : 1.0;
            for (int i = 0; i < inNum; i++)
            {
                x[i] = p[i] / in_rate;
            }

            // Forward pass: hidden layer.
            for (int j = 0; j < hideNum; j++)
            {
                o1[j] = 0.0;
                for (int i = 0; i < inNum; i++)
                {
                    o1[j] += w[i, j] * x[i];
                }
                x1[j] = 1.0 / (1.0 + Math.Exp(-o1[j] - b1[j]));
            }

            // Forward pass: output layer.
            for (int k = 0; k < outNum; k++)
            {
                o2[k] = 0.0;
                for (int j = 0; j < hideNum; j++)
                {
                    o2[k] += v[j, k] * x1[j];
                }
                x2[k] = 1.0 / (1.0 + Math.Exp(-o2[k] - b2[k]));
            }

            // Arg-max over the outputs; index 0 is the starting candidate, so the scan
            // begins at 1. The first maximum wins on ties.
            double max = x2[0];
            int maxi = 0;
            for (int i = 1; i < outNum; i++)
            {
                if (x2[i] > max)
                {
                    max = x2[i];
                    maxi = i;
                }
            }
            return maxi;
        }

        /// <summary>
        /// Adds <paramref name="dw"/> to <paramref name="w"/> element by element, in place.
        /// </summary>
        /// <param name="w">Matrix that receives the increments.</param>
        /// <param name="dw">Increments; read at every index that exists in <paramref name="w"/>.</param>
        public void adjustWV(double[,] w, double[,] dw)
        {
            int rows = w.GetLength(0);
            int cols = w.GetLength(1);

            for (int r = 0; r < rows; r++)
            {
                for (int c = 0; c < cols; c++)
                {
                    w[r, c] = w[r, c] + dw[r, c];
                }
            }
        }

        /// <summary>
        /// Adds <paramref name="dw"/> to <paramref name="w"/> element by element, in place.
        /// </summary>
        /// <param name="w">Vector that receives the increments.</param>
        /// <param name="dw">Increments; read at every index that exists in <paramref name="w"/>.</param>
        public void adjustWV(double[] w, double[] dw)
        {
            int index = 0;
            while (index < w.Length)
            {
                w[index] = w[index] + dw[index];
                index++;
            }
        }

        /// <summary>
        /// Writes a weight matrix (w or v) to a text file in transposed order: each text
        /// line holds one column of the matrix, values separated (and followed) by a space.
        /// </summary>
        /// <param name="w">Matrix to save.</param>
        /// <param name="filename">Path of the file to create or overwrite.</param>
        /// <remarks>
        /// Values are written with 15 decimal places using the current culture's number
        /// format, which is also what readMatrixW uses when parsing them back.
        /// </remarks>
        /// <exception cref="ArgumentNullException"><paramref name="w"/> is null.</exception>
        public void saveMatrix(double[,] w, string filename)
        {
            // Checked before the file is created so a null matrix no longer leaves an
            // empty file behind.
            if (w == null)
            {
                throw new ArgumentNullException("w");
            }

            // "using" releases the file handle even when a write fails; the explicit
            // Close() used before was skipped whenever an exception was thrown.
            using (StreamWriter sw = File.CreateText(filename))
            {
                for (int i = 0; i < w.GetLength(1); i++)
                {
                    for (int j = 0; j < w.GetLength(0); j++)
                    {
                        sw.Write(w[j, i].ToString("0.000000000000000") + " ");
                    }
                    sw.WriteLine();
                }
            }
        }

        /// <summary>
        /// Writes a bias vector (b1 or b2) to a text file as a single line of values, each
        /// followed by a space.
        /// </summary>
        /// <param name="b">Bias vector to save.</param>
        /// <param name="filename">Path of the file to create or overwrite.</param>
        /// <remarks>
        /// Values are formatted with the default double-to-string conversion of the current
        /// culture, which is also what readMatrixB uses when parsing them back.
        /// </remarks>
        /// <exception cref="ArgumentNullException"><paramref name="b"/> is null.</exception>
        public void saveMatrix(double[] b, string filename)
        {
            // Checked before the file is created so a null vector no longer leaves an
            // empty file behind.
            if (b == null)
            {
                throw new ArgumentNullException("b");
            }

            // "using" releases the file handle even when a write fails.
            using (StreamWriter sw = File.CreateText(filename))
            {
                for (int i = 0; i < b.Length; i++)
                {
                    sw.Write(b[i] + " ");
                }
            }
        }

        /// <summary>
        /// Loads a weight matrix (w or v) from a text file written by saveMatrix: text line
        /// i supplies column i of the matrix, values separated by spaces.
        /// </summary>
        /// <param name="w">Matrix to fill in place; must be large enough for the file's contents.</param>
        /// <param name="filename">Path of the file to read.</param>
        /// <remarks>
        /// Failures (missing file, malformed number, too many values) are reported on the
        /// console and otherwise ignored, leaving <paramref name="w"/> partially filled;
        /// this best-effort behaviour is unchanged.
        /// </remarks>
        public void readMatrixW(double[,] w, string filename)
        {
            try
            {
                // The file is produced by File.CreateText (UTF-8) and holds only ASCII
                // digits, so it is read with the default UTF-8 reader. The previous
                // "gb2312" code page is not available by default on .NET Core / .NET 5+,
                // where looking it up throws and the matrix silently stayed all zero.
                // "using" also closes the reader when parsing fails half-way.
                using (StreamReader sr = new StreamReader(filename))
                {
                    string line;
                    int i = 0;

                    while ((line = sr.ReadLine()) != null)
                    {
                        // Empty entries are dropped so repeated spaces do not produce ""
                        // tokens that Convert.ToDouble would reject.
                        string[] s1 = line.Trim().Split(new char[] { ' ' }, StringSplitOptions.RemoveEmptyEntries);
                        if (s1.Length == 0)
                        {
                            // Blank line: carries no column, so it does not advance i.
                            continue;
                        }

                        for (int j = 0; j < s1.Length; j++)
                        {
                            w[j, i] = Convert.ToDouble(s1[j]);
                        }
                        i++;
                    }
                }
            }
            catch (Exception ex)
            {
                Console.WriteLine("The file could not be read:");
                Console.WriteLine(ex.Message);
            }
        }

        /// <summary>
        /// Loads a bias vector (b1 or b2) from the first line of a text file written by
        /// saveMatrix; values are separated by spaces.
        /// </summary>
        /// <param name="b">Vector to fill in place; must be large enough for the line's contents.</param>
        /// <param name="filename">Path of the file to read.</param>
        /// <remarks>
        /// Failures (missing file, malformed number, too many values) are reported on the
        /// console and otherwise ignored, leaving <paramref name="b"/> partially filled;
        /// this best-effort behaviour is unchanged.
        /// </remarks>
        public void readMatrixB(double[] b, string filename)
        {
            try
            {
                // The file is produced by File.CreateText (UTF-8) and holds only ASCII
                // digits, so it is read with the default UTF-8 reader. The previous
                // "gb2312" code page is not available by default on .NET Core / .NET 5+.
                // "using" also closes the reader when parsing fails half-way.
                using (StreamReader sr = new StreamReader(filename))
                {
                    string line = sr.ReadLine();
                    if (line != null)
                    {
                        // Empty entries are dropped so repeated spaces do not produce ""
                        // tokens that Convert.ToDouble would reject.
                        string[] strs = line.Trim().Split(new char[] { ' ' }, StringSplitOptions.RemoveEmptyEntries);
                        for (int i = 0; i < strs.Length; i++)
                        {
                            b[i] = Convert.ToDouble(strs[i]);
                        }
                    }
                }
            }
            catch (Exception ex)
            {
                Console.WriteLine("The file could not be read:");
                Console.WriteLine(ex.Message);
            }
        }
    }
}
using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Data;
using System.Drawing;
using System.Drawing.Imaging;
using System.IO;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using System.Windows.Forms;

namespace BPNetwork
{
    public partial class MainFrm : Form
    {
        /// <summary>
        /// Builds the main window: fixed-size frame, train/test buttons disabled until a
        /// sample folder is chosen, and a default learning rate of 0.3.
        /// </summary>
        public MainFrm()
        {
            InitializeComponent();

            // Fixed-size window: non-resizable border, no maximize or minimize button.
            this.FormBorderStyle = FormBorderStyle.FixedSingle;
            MaximizeBox = false;
            MinimizeBox = false;

            // Both actions stay disabled until the matching sample folder has been selected.
            btnTrain.Enabled = false;
            btnTest.Enabled = false;

            double defaultLearnRate = 0.3;
            txtLearnRate.Text = defaultLearnRate.ToString();

            lblMessage.Text = "请先载入测试或训练样本";
        }

        /// <summary>
        /// Training toggle state: set to 1 by btnTrain_Click when a run is started and back
        /// to 0 when the run is stopped or has completed.
        /// </summary>
        private static int flag = 0;
        /// <summary>
        /// Non-zero once a training-sample folder has been selected (set in menuOpenTrain_Click).
        /// </summary>
        private static int flag2 = 0;
        /// <summary>
        /// Non-zero once a test-sample folder has been selected (set in menuOpenTest_Click).
        /// </summary>
        private static int flag3 = 0;
        // Background worker that runs the training loop; created in btnTrain_Click.
        private BackgroundWorker bw;

        /// <summary>
        /// Train/stop toggle: starts a background training run, or requests cancellation of
        /// the run that is currently active.
        /// </summary>
        /// <param name="sender">The train button.</param>
        /// <param name="e">Unused.</param>
        private void btnTrain_Click(object sender, EventArgs e)
        {
            // Only the training folder is required here (flag2); the test folder is not.
            if (flag2 == 0)
            {
                MessageBox.Show("请点击文件,选择要训练文件所在的目录!", "文件未载入");
                return;
            }

            // Decide from the worker itself rather than from the flag: after "stop" the
            // flag is reset at once, but the worker keeps running until it notices the
            // request. Starting a new run in that window would put two trainings in
            // flight, both writing the same weight files, so the click only repeats the
            // cancellation request.
            if (bw != null && bw.IsBusy)
            {
                bw.CancelAsync();
                this.btnTrain.Text = "训练";
                flag = 0;
                return;
            }

            bw = new BackgroundWorker();
            bw.WorkerSupportsCancellation = true; // required for CancelAsync
            bw.DoWork += Bw_DoWork;
            bw.RunWorkerCompleted += Bw_RunWorkerCompleted;

            bw.RunWorkerAsync();
            this.btnTrain.Text = "停止";
            flag = 1;
        }

        /// <summary>
        /// Runs on the UI thread when the training worker ends: shows the window again,
        /// reports the outcome and resets the train button.
        /// </summary>
        /// <param name="sender">The worker that finished.</param>
        /// <param name="e">Completion data; Error and Cancelled decide which message is shown.</param>
        private void Bw_RunWorkerCompleted(object sender, RunWorkerCompletedEventArgs e)
        {
            // Make the window visible again in case it was hidden while training ran.
            this.Show();

            if (e.Error != null)
            {
                // An exception that escaped the DoWork handler used to be dropped silently
                // and reported as success.
                MessageBox.Show("出错了" + e.Error.Message);
            }
            else if (e.Cancelled)
            {
                // Only reached when the DoWork handler sets DoWorkEventArgs.Cancel.
                MessageBox.Show("训练已取消", "提示");
            }
            else
            {
                MessageBox.Show("训练成功", "提示");
            }

            this.btnTrain.Text = "训练";
            flag = 0;
        }

        /// <summary>
        /// Folder holding the training samples; expected to contain one sub-folder per digit, named 0 to 9.
        /// </summary>
        private static string train_path;
        /// <summary>
        /// Folder holding the test samples; expected to contain one sub-folder per digit, named 0 to 9.
        /// </summary>
        private static string test_path;

        /// <summary>
        /// Background training job: loads every image below train_path\0 .. train_path\9,
        /// trains a 400-input / 10-output network until the error drops to 0.01, 50000
        /// passes have run or cancellation is requested, then saves the weights.
        /// </summary>
        /// <param name="sender">The BackgroundWorker executing this handler.</param>
        /// <param name="e">Work arguments; Cancel is set when the run was stopped by the user.</param>
        private void Bw_DoWork(object sender, DoWorkEventArgs e)
        {
            const int inputSize = 400;  // 20*20 pixels per sample
            const int classCount = 10;  // digits 0-9

            // Use the worker that raised the event; the bw field may already point at a
            // newer worker.
            BackgroundWorker worker = (BackgroundWorker)sender;

            // This handler runs on a worker thread, so every UI update is posted to the
            // UI thread instead of touching the controls directly.
            Action<string> showStatus = delegate (string text)
            {
                if (this.IsHandleCreated && !this.IsDisposed)
                {
                    this.BeginInvoke(new Action(() => this.lblMessage.Text = text));
                }
            };

            BpNet bp = new BpNet(inputSize, classCount);
            int study = 0; // completed training passes

            try
            {
                // Read the learning rate on the UI thread.
                string lrText = (string)this.Invoke(new Func<string>(() => this.txtLearnRate.Text));
                double lr = Double.Parse(lrText.Trim());

                // Collect (file, digit) pairs; a list keeps them in discovery order.
                List<KeyValuePair<string, int>> samples = new List<KeyValuePair<string, int>>();
                for (int digit = 0; digit < classCount; digit++)
                {
                    string dir = Path.Combine(train_path, digit.ToString());
                    foreach (string file in Directory.GetFiles(dir))
                    {
                        samples.Add(new KeyValuePair<string, int>(file, digit));
                    }
                }
                if (samples.Count == 0)
                {
                    throw new InvalidDataException("训练目录中没有找到任何样本");
                }

                double[,] input = new double[samples.Count, inputSize];
                double[,] output = new double[samples.Count, classCount];

                int count = 0;
                foreach (KeyValuePair<string, int> item in samples)
                {
                    // Dispose each bitmap right away; otherwise every sample keeps its
                    // GDI+ handle and file lock until the garbage collector runs.
                    using (Bitmap bmp = new Bitmap(item.Key))
                    {
                        int width = bmp.Width;
                        int height = bmp.Height;
                        if (width * height != inputSize)
                        {
                            throw new InvalidDataException("样本图像必须包含" + inputSize + "个像素: " + item.Key);
                        }

                        // Rows are stored bottom-up (image row k lands in row height-1-k).
                        // This gives the same layout as the earlier copy-then-swap code,
                        // without its fixed 20-element row buffer.
                        for (int k = 0; k < height; k++)
                        {
                            for (int l = 0; l < width; l++)
                            {
                                input[count, (height - 1 - k) * width + l] = bmp.GetPixel(l, k).R;
                            }
                        }
                    }

                    output[count, item.Value] = 1; // one-hot target for this sample's digit
                    count++;
                }

                do
                {
                    if (worker.CancellationPending)
                    {
                        e.Cancel = true; // lets RunWorkerCompleted see the cancellation
                        break;
                    }

                    bp.train(input, output, lr);
                    study++;
                    showStatus("第" + study + "次训练的误差: " + bp.e);
                } while (bp.e > 0.01 && study < 50000);
            }
            catch (Exception ex)
            {
                MessageBox.Show("出错了" + ex.Message);
            }
            finally
            {
                // Weights are still saved after an error or a cancellation, but only once
                // at least one pass has completed. Saving after a failure during loading
                // would overwrite previously trained files with the random initial weights.
                if (study > 0)
                {
                    bp.saveMatrix(bp.w, "w.txt");
                    bp.saveMatrix(bp.v, "v.txt");
                    bp.saveMatrix(bp.b1, "b1.txt");
                    bp.saveMatrix(bp.b2, "b2.txt");
                }
                showStatus("训练终止!");
            }
        }

        /// <summary>
        /// Starts a background test run over the selected test folder, provided the folder
        /// has been chosen and the four weight files exist in the current directory.
        /// </summary>
        /// <param name="sender">The test button.</param>
        /// <param name="e">Unused.</param>
        private void btnTest_Click(object sender, EventArgs e)
        {
            if (flag3 == 0)
            {
                MessageBox.Show("请点击文件,选择要测试文件所在的目录!", "文件未载入");
                return;
            }

            // Missing weight files are a different problem from a missing test folder.
            // They get their own message, and the folder selection (flag3) is kept so the
            // user does not have to pick the folder again after training.
            if (!(File.Exists("w.txt") && File.Exists("v.txt") && File.Exists("b1.txt") && File.Exists("b2.txt")))
            {
                MessageBox.Show("未找到权值文件(w.txt、v.txt、b1.txt、b2.txt),请先完成训练!", "文件未载入");
                return;
            }

            // Clear the results of the previous test run.
            this.lbTestResult.Items.Clear();

            // Disabled while the run is active so a second click cannot start an
            // overlapping run that interleaves its results with the first.
            this.btnTest.Enabled = false;

            BackgroundWorker bw1 = new BackgroundWorker();
            bw1.DoWork += Bw1_DoWork;
            bw1.RunWorkerCompleted += delegate
            {
                this.btnTest.Enabled = true;
                bw1.Dispose();
            };

            bw1.RunWorkerAsync();
        }

        /// <summary>
        /// Background test job: classifies every image below test_path\0 .. test_path\9 and
        /// adds one result line per digit with that digit's recognition rate.
        /// </summary>
        /// <param name="sender">The BackgroundWorker executing this handler.</param>
        /// <param name="e">Unused.</param>
        private void Bw1_DoWork(object sender, DoWorkEventArgs e)
        {
            const int inputSize = 400;  // 20*20 pixels per sample
            const int classCount = 10;  // digits 0-9

            // This handler runs on a worker thread, so every UI update is posted to the
            // UI thread instead of touching the controls directly.
            Action<Action> onUi = delegate (Action action)
            {
                if (this.IsHandleCreated && !this.IsDisposed)
                {
                    this.BeginInvoke(action);
                }
            };

            try
            {
                BpNet bp = new BpNet(inputSize, classCount);
                double[] input = new double[inputSize];

                for (int i = 0; i < classCount; i++)
                {
                    int right_count = 0;
                    string[] files = Directory.GetFiles(Path.Combine(test_path, i.ToString()));

                    foreach (string file in files)
                    {
                        // Dispose each bitmap right away to release its GDI+ handle and file lock.
                        using (Bitmap bmp = new Bitmap(file))
                        {
                            int width = bmp.Width;
                            int height = bmp.Height;
                            if (width * height != inputSize)
                            {
                                throw new InvalidDataException("样本图像必须包含" + inputSize + "个像素: " + file);
                            }

                            // Rows are stored bottom-up (image row k lands in row
                            // height-1-k), the same layout the training code feeds to
                            // the network.
                            for (int k = 0; k < height; k++)
                            {
                                for (int l = 0; l < width; l++)
                                {
                                    input[(height - 1 - k) * width + l] = bmp.GetPixel(l, k).R;
                                }
                            }
                        }

                        if (i == bp.test(input))
                        {
                            right_count++;
                        }
                    }

                    // An empty folder used to produce "NaN%" through a 0/0 division.
                    double rate = files.Length == 0 ? 0.0 : 1.0 * right_count / files.Length * 100;
                    string line = files.Length + "个" + i + "样本识别成功率:" + rate.ToString("0.00") + "%";
                    onUi(() => this.lbTestResult.Items.Add(line));
                }

                onUi(() => this.lblMessage.Text = "测试成功!");
            }
            catch (Exception ex)
            {
                MessageBox.Show("出错了" + ex.Message);
            }
        }

        /// <summary>
        /// Lets the user pick the training-sample folder and enables the train button once
        /// a folder has actually been chosen.
        /// </summary>
        /// <param name="sender">The menu item.</param>
        /// <param name="e">Unused.</param>
        private void menuOpenTrain_Click(object sender, EventArgs e)
        {
            // The dialog owns native resources, so it is disposed once it has been used.
            using (FolderBrowserDialog path = new FolderBrowserDialog())
            {
                // Cancelling used to store an empty path and still report success; now a
                // cancelled dialog leaves the previous selection and state untouched.
                if (path.ShowDialog() != DialogResult.OK || string.IsNullOrEmpty(path.SelectedPath))
                {
                    return;
                }
                train_path = path.SelectedPath;
            }

            flag2 = 1;
            this.btnTrain.Enabled = true;
            this.lblMessage.Text = "训练样本载入成功";
        }

        /// <summary>
        /// Lets the user pick the test-sample folder and enables the test button once a
        /// folder has actually been chosen.
        /// </summary>
        /// <param name="sender">The menu item.</param>
        /// <param name="e">Unused.</param>
        private void menuOpenTest_Click(object sender, EventArgs e)
        {
            // The dialog owns native resources, so it is disposed once it has been used.
            using (FolderBrowserDialog path = new FolderBrowserDialog())
            {
                // Cancelling used to store an empty path and still report success; now a
                // cancelled dialog leaves the previous selection and state untouched.
                if (path.ShowDialog() != DialogResult.OK || string.IsNullOrEmpty(path.SelectedPath))
                {
                    return;
                }
                test_path = path.SelectedPath;
            }

            flag3 = 1;
            this.btnTest.Enabled = true;
            this.lblMessage.Text = "测试样本载入成功";
        }

        // 1 while the window is pinned on top of other windows, 0 otherwise.
        private static int flag1 = 0;

        /// <summary>
        /// Toggles the always-on-top state of the window and updates the menu caption to
        /// describe the action the next click will perform.
        /// </summary>
        /// <param name="sender">The menu item.</param>
        /// <param name="e">Unused.</param>
        private void menuStay_Click(object sender, EventArgs e)
        {
            // Pin when currently unpinned, and vice versa.
            bool pin = (flag1 == 0);

            this.TopMost = pin;
            this.menuStay.Text = pin ? "取消窗口保持在前" : "窗口保持在前";
            flag1 = pin ? 1 : 0;
        }

        /// <summary>
        /// Hides the main window so the program keeps running in the background.
        /// </summary>
        /// <param name="sender">The menu item.</param>
        /// <param name="e">Unused.</param>
        private void menuRunBackground_Click(object sender, EventArgs e)
        {
            Hide();
        }
    }
}
  • 3
    点赞
  • 35
    收藏
    觉得还不错? 一键收藏
  • 27
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 27
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值