BP神经网络应用于手写数字识别--matlab程序

数据集和完整的程序下载见更新版:神经网络用于字符识别更新版

二:BP神经网络应用于字符识别

  字符包括汉字,字母,数字和一些符号。汉字有几千个,字母有几十个,数字的类最少只有10个,所以选择简单的手写数字字符来实现。结合三个相关的程序和论文,一个是语音特征的分类(不调用神经网络工具箱相关函数实现),另外两个是关于手写数字识别的。处理的数据集是放在10个文件夹里,文件夹的名称对应存放的手写数字图片的数字,每个数字500,每张图片的像素统一为28*28,如下

 

这5000张随机选取4500张进行训练,剩下的500张用来测试。为了能对BP神经网络有更深入的了解,选择一步步详细实现。

BP神经网络的特点:信号前向传递,信号反向传播。若输出存在误差,根据误差调整权值和阈值,使网络的输出接近预期。

在用BP神经网络进行预测之前要训练网络

训练过程

1.网络初始化:各个参数的确定包括输入,输出,隐含层的节点数,输入和隐含,隐含和输出层之间的权值,隐含,输出层的阈值,学习速度和激励函数。

2.计算隐含层输出

3.计算输出层输出

4.误差计算

5.权值更新

6.阈值更新

7.判断迭代是否结束

模型建立:

BP神经网络构建-BP神经网络训练-BP神经网络分类

1.确定神经网络的输入,输出。

输入是BP神经网络很重要的方面,输入的数据是手写字符经过预处理和特征提取后的数据。预处理有二值化,裁剪掉空白的区域,然后再统一大小为70*50为特征提取做准备。特征提取采用的是粗网格特征提取,把图像分成35个区域,每个区域100像素,统计区域中1像素所占的比例。经过预处理特征提取后,28*28图像转成1*35的特征矢量。提取完5000张图片后,依次把所有的特征存于一个矩阵(5000*35)中,最后在加上第36行,用来存放原图片的真值。于是最后得到是包含特征向量和真值的矩阵(5000*36),特征向量是神经网络的输入,真值是其输出。

2.神经的网络的训练

用matlab的rands函数来实现网络权值的初始化,网络结构为输入层35,隐藏层34,输出层10,学习速率为0.1,隐藏层激励函数为sigmoid函数。随机抽取4500张图片提取特征后输入,按照公式计算隐含层和输出层输出,误差,更新网络权值。

3.神经网络的预测

训练好神经网络之后,用随机抽取的500个数字字符对网络进行预测,输入特征向量,计算隐含层和输出层输出,得到最后预测的数据。同时计算每个数字的正确率和全体的正确率。最后得到的总体正确率为0.8004。

程序:

[plain]  view plain  copy
  在CODE上查看代码片 派生到我的代码片
  1. clc;  
  2. clearall;  
  3. closeall;  
  4. Files= dir('E:/Matlabdip/word_recognition/testdatabase');  
  5. LengthFiles= length(Files);  
  6. %========读取存在testdatabase下0-10个文件的全部图片========%  
  7. %========存放在number中,number{1}是数字0有500张========%  
  8. fori = 3:LengthFiles;     
  9.     if strcmp(Files(i).name,'.')||strcmp(Files(i).name,'..')  
  10.     else  
  11.        number{i-2}=BatchReadImg(strcat('E:/Matlabdip/word_recognition/testdatabase','/',Files(i).name),0);  
  12.     end   
  13. end  
  14. charvec1=zeros(35,5000);  
  15. %========对读取的图片预处理(二值化-裁剪-特征提取)========%  
  16.  for i=1:10  
  17.      for j=1:500  
  18.         I1=number{1,i}{1,j};  
  19.              Ibw = im2bw(I1,graythresh(I1));  
  20.          bw2 = edu_imgcrop(Ibw);%对图像进行裁剪,使边框完全贴紧字符  
  21.    
  22.         charvec = edu_imgresize(bw2);%提取特征统计每个小区域中图像象素所占百分比作为特征数据  
  23.         charvec1(:,(i-1)*500+j)=charvec;  
  24.      end  
  25.  end  
  26. index=[zeros(1,500),zeros(1,500)+1,...  
  27.    zeros(1,500)+2,zeros(1,500)+3,zeros(1,500)+4,zeros(1,500)+5,zeros(1,500)+6,zeros(1,500)+7,zeros(1,500)+8,zeros(1,500)+9];  
  28. charvec1(36,:)=index;  
  29. %=========BP神经网络创建,训练和测试========%  
  30. %从1到5000间随机排序(在[0,1]之间产生5000个随机数)==将数据顺序打乱  
  31. k=rand(1,5000);  
  32. [m,n]=sort(k);  
  33.    
  34.    
  35. %输入输出数据  
  36. input=charvec1(1:35,:);  
  37. output1=charvec1(36,:);  
  38.    
  39. %把输出从1维变成10维  
  40. fori=1:5000  
  41.     switch output1(i)  
  42.         case 0  
  43.             output(:,i)=[1 0 0 0 0 0 0 0 0 0]';  
  44.         case 1  
  45.             output(:,i)=[0 1 0 0 0 0 0 0 0 0]';  
  46.         case 2  
  47.             output(:,i)=[0 0 1 0 0 0 0 0 0 0]';  
  48.         case 3  
  49.             output(:,i)=[0 0 0 1 0 0 0 0 0 0]';  
  50.         case 4  
  51.             output(:,i)=[0 0 0 0 1 0 0 0 0 0]';  
  52.         case 5  
  53.             output(:,i)=[0 0 0 0 0 1 0 0 0 0]';  
  54.         case 6  
  55.             output(:,i)=[0 0 0 0 0 0 1 0 0 0]';  
  56.         case 7  
  57.             output(:,i)=[0 0 0 0 0 0 0 1 0 0]';  
  58.         case 8  
  59.             output(:,i)=[0 0 0 0 0 0 0 0 1 0]';  
  60.         case 9  
  61.             output(:,i)=[0 0 0 0 0 0 0 0 0 1]';  
  62.     end  
  63. end  
  64.    
  65. %随机提取4500个样本为训练样本,500个样本为预测样本  
  66. input_train=input(:,n(1:4500));  
  67. output_train=output(:,n(1:4500));  
  68. input_test=input(:,n(4501:5000));  
  69. output_test=output(:,n(4501:5000));  
  70.    
  71. % %输入数据归一化  
  72. %[inputn,inputps]=mapminmax(input_train);  
  73.    
  74. %% 网络结构初始化  
  75. innum=35;  
  76. midnum=34;  
  77. outnum=10;  
  78.    
  79.    
  80. %权值初始化  
  81. w1=rands(midnum,innum);%输入到隐藏  
  82. b1=rands(midnum,1);  
  83. w2=rands(midnum,outnum);%隐藏到输出  
  84. b2=rands(outnum,1);  
  85.    
  86. w2_1=w2;w2_2=w2_1;  
  87. w1_1=w1;w1_2=w1_1;  
  88. b1_1=b1;b1_2=b1_1;  
  89. b2_1=b2;b2_2=b2_1;  
  90.    
  91. %学习率  
  92. xite=0.1;  
  93. alfa=0.01;  
  94.    
  95. %% 网络训练  
  96. %for ii=1:10  
  97. %     E(ii)=0;  
  98.     for i=1:1:4500  
  99.        %% 网络预测输出  
  100.         x=input_train(:,i);  
  101.         % 隐含层输出  
  102.         for j=1:1:midnum  
  103.            I(j)=input_train(:,i)'*w1(j,:)'+b1(j);  
  104.             Iout(j)=1/(1+exp(double(-I(j))));  
  105.         end  
  106.         % 输出层输出  
  107.         yn=w2'*Iout'+b2;  
  108.          
  109.        %% 权值阀值修正  
  110.         %计算误差  
  111.         e=output_train(:,i)-yn;      
  112. %         E(ii)=E(ii)+sum(abs(e));  
  113.          
  114.         %计算权值变化率  
  115.         dw2=e*Iout;  
  116.         db2=e';  
  117.         %=======由于采用的是sigmoid单元,所以要对每个输出单元以及隐藏单元计算误差项======%  
  118.         for j=1:1:midnum  
  119.             S=1/(1+exp(double(-I(j))));  
  120.             FI(j)=S*(1-S);  
  121.         end       
  122.         for k=1:1:innum  
  123.             for j=1:1:midnum  
  124.                dw1(k,j)=FI(j)*x(k)*(w2(j,:)*e);%  
  125.                 db1(j)=FI(j)*(w2(j,:)*e);  
  126.             end  
  127.         end  
  128.             
  129.         w1=w1_1+xite*dw1';  
  130.         b1=b1_1+xite*db1';  
  131.         w2=w2_1+xite*dw2';  
  132.         b2=b2_1+xite*db2';  
  133.          
  134.         w1_2=w1_1;w1_1=w1;  
  135.         w2_2=w2_1;w2_1=w2;  
  136.         b1_2=b1_1;b1_1=b1;  
  137.         b2_2=b2_1;b2_1=b2;  
  138.     end  
  139. %end  
  140.    
  141.    
  142. %%% 语音特征信号分类  
  143. %inputn_test=mapminmax('apply',input_test,inputps);  
  144.    
  145.     for i=1:500%1500  
  146.         %隐含层输出  
  147.         for j=1:1:midnum  
  148.            I(j)=input_test(:,i)'*w1(j,:)'+b1(j);  
  149.             Iout(j)=1/(1+exp(double(-I(j))));  
  150.         end  
  151.         %输出层输出  
  152.         fore(:,i)=w2'*Iout'+b2;  
  153.     end  
  154.    
  155.    
  156.    
  157.    
  158. %% 结果分析  
  159. %根据网络输出找出数据属于哪类  
  160. fori=1:500  
  161.    output_fore(i)=find(fore(:,i)==max(fore(:,i)))-1;  
  162. end  
  163.    
  164. %BP网络预测误差  
  165. error=output_fore'-output1(n(4501:5000))';  
  166.    
  167.    
  168.    
  169. %画出预测数字和实际数字的分类图  
  170. figure(1)  
  171. plot(output_fore,'r')  
  172. holdon  
  173. plot(output1(n(4501:5000))','b')  
  174. legend('预测数字','实际数字')  
  175.    
  176. %画出误差图  
  177. figure(2)  
  178. plot(error)  
  179. title('BP网络分类误差','fontsize',12)  
  180. xlabel('输入数字','fontsize',12)  
  181. ylabel('分类误差','fontsize',12)  
  182.    
  183. k=zeros(1,10);   
  184. %找出判断错误的分类属于哪一类  
  185. fori=1:500  
  186.     if error(i)~=0  
  187.         [b,c]=max(output_test(:,i));  
  188.         switch c-1  
  189.             case 1  
  190.                 k(1)=k(1)+1;  
  191.             case 2  
  192.                 k(2)=k(2)+1;  
  193.             case 3  
  194.                 k(3)=k(3)+1;  
  195.             case 4  
  196.                 k(4)=k(4)+1;  
  197.             case 5  
  198.                 k(5)=k(5)+1;  
  199.             case 6  
  200.                 k(6)=k(6)+1;  
  201.             case 7  
  202.                 k(7)=k(7)+1;  
  203.             case 8  
  204.                 k(8)=k(8)+1;  
  205.             case 9  
  206.                 k(9)=k(9)+1;  
  207.             case 0  
  208.                 k(10)=k(10)+1;  
  209.         end  
  210.     end  
  211. end  
  212.    
  213. %找出每类的个体和  
  214. kk=zeros(1,10);  
  215. fori=1:500  
  216.     [b,c]=max(output_test(:,i));  
  217.     switch c-1  
  218.             case 1  
  219.                 kk(1)=kk(1)+1;  
  220.             case 2  
  221.                 kk(2)=kk(2)+1;  
  222.             case 3  
  223.                 kk(3)=kk(3)+1;  
  224.             case 4  
  225.                 kk(4)=kk(4)+1;  
  226.             case 5  
  227.                 kk(5)=kk(5)+1;  
  228.             case 6  
  229.                 kk(6)=kk(6)+1;  
  230.             case 7  
  231.                 kk(7)=kk(7)+1;  
  232.             case 8  
  233.                 kk(8)=k(8)+1;  
  234.             case 9  
  235.                 kk(9)=kk(9)+1;  
  236.             case 0  
  237.                 kk(10)=kk(10)+1;  
  238.     end  
  239. end  
  240.    
  241. %正确率  
  242. rightridio=(kk-k)./kk  
  243.    
  244. right=(sum(kk(:))-sum(k(:)))/sum(kk(:));  
  245.    
  246.    
  247. 用到到函数:BatchImg  
  248.    
  249.    
  250. function [imglist]=BatchReadImg(rootpath,grayflag)  
  251. if nargin<2  
  252.    disp('Not enough parameters!');  
  253.    return;  
  254. end  
  255.    
  256. filelist=dir(rootpath);  
  257. [filenum,temp]=size(filelist);  
  258. tempind=0;  
  259. imglist=cell(0);  
  260. for i=1:filenum  
  261.      
  262.    if strcmp(filelist(i).name,'.')|| strcmp(filelist(i).name,'..')||strcmp(filelist(i).name,'Desktop_1.ini')||strcmp(filelist(i).name,'Desktop_2.ini')  
  263.          
  264.    else  
  265.        tempind=tempind+1;  
  266.        imglist{tempind}=imread(strcat(rootpath,'/',filelist(i).name));  
  267.    end  
  268. end  
  269. if grayflag==1  
  270.    tempcount=size(imglist);  
  271.    for j=1:tempcount(2)  
  272.        imglist{j}=rgb2gray(imglist{j});  
  273.    end  
  274. end  
  275. edu_imgcrop:  
  276. function bw2 = edu_imgcrop(bw)  
  277.    
  278. %找到图像边界  
  279. [y2temp x2temp] = size(bw);  
  280. x1=1;  
  281. y1=1;  
  282. x2=x2temp;  
  283. y2=y2temp;  
  284.    
  285. % 找左边空白  
  286. cntB=1;  
  287. while (sum(bw(:,cntB))==y2temp)  
  288.    x1=x1+1;  
  289.    cntB=cntB+1;  
  290. end  
  291.    
  292. % 左边  
  293. cntB=1;  
  294. while (sum(bw(cntB,:))==x2temp)  
  295.    y1=y1+1;  
  296.    cntB=cntB+1;  
  297. end  
  298.    
  299. % 上边  
  300. cntB=x2temp;  
  301. while (sum(bw(:,cntB))==y2temp)  
  302.    x2=x2-1;  
  303.    cntB=cntB-1;  
  304. end  
  305.    
  306. % 下边  
  307. cntB=y2temp;  
  308. while (sum(bw(cntB,:))==x2temp)  
  309.    y2=y2-1;  
  310.    cntB=cntB-1;  
  311. end  
  312. bw2=imcrop(bw,[x1,y1,(x2-x1),(y2-y1)]);  
  313.    
  314.    
  315. edu_imgresize:  
  316. function lett = edu_imgresize(bw2)  
  317. % ======提取特征,转成5*7的特征矢量,把图像中每10*10的点进行划分相加,进行相加成一个点=====%  
  318. %======即统计每个小区域中图像象素所占百分比作为特征数据====%  
  319. bw_7050=imresize(bw2,[70,50]);  
  320. for cnt=1:7  
  321.    for cnt2=1:5  
  322.        Atemp=sum(bw_7050(((cnt*10-9):(cnt*10)),((cnt2*10-9):(cnt2*10))));%10*10box  
  323.        lett((cnt-1)*5+cnt2)=sum(Atemp);  
  324.    end  
  325. end  
  326. lett=((100-lett)/100);  
  327. lett=lett';  

完整的数据集和可以运行的Matlab代码下载

 邮箱联系:dawnminghuang@gmail.com

  • 7
    点赞
  • 21
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值