从0开始训练识别手写数字的神经网络

训练识别手写数字的神经网络

——基于matlab

一.前言

最近在学习神经网络,自己用matlab搭建了一个神经网络(我仅仅是为了自己熟悉这个过程,实际上,matlab自身带有很强大的神经网络工具箱,实际应用的时候可以不造轮子),能够识别手写数字。在这里把学习的心得和过程写出来,加深自己的理解,同时帮助有需要的朋友们。

邮箱:liliangwei@sjtu.edu.cn

微信:302566290

二.必要的知识

1.数学知识

完成所有这些内容需要掌握的数学知识有:矩阵、偏导数。

 

2.编程知识

你需要掌握Matlab的编程技巧。在这里我着重介绍一下我用到的一个比较特殊的函数(官方文档,看不懂英文的朋友可以百度,因为我matlab是英文版的,所以出来的都是英文版的,另外,我代码中所有的注释也都是英文,希望大家谅解,不过需要提醒大家的是,汉语和英语的理解能力在今后的社会将会越来越重要,因为中美两个国家将在各个领域发挥巨大的作用):

--cell:

 

Create cell array.
    cell(N) is an N-by-N cell array of empty matrices.
 
    cell(M,N) or cell([M,N]) is an M-by-N cell array of empty
    matrices.
 
    cell(M,N,P,...) or cell([M N P ...]) is an M-by-N-by-P-by-...
    cell array of empty matrices.
 
    cell(SIZE(A)) is a cell array the same size as A containing
    all empty matrices.

 

--reshape

reshape Reshape array.
    reshape(X,M,N) or reshape(X,[M,N]) returns the M-by-N matrix 
    whose elements are taken columnwise from X. An error results 
    if X does not have M*N elements.
 
    reshape(X,M,N,P,...) or reshape(X,[M,N,P,...]) returns an 
    N-D array with the same elements as X but reshaped to have 
    the size M-by-N-by-P-by-.... The product of the specified
    dimensions, M*N*P*..., must be the same as NUMEL(X).
 
    reshape(X,...,[],...) calculates the length of the dimension
    represented by [], such that the product of the dimensions 
    equals NUMEL(X). The value of NUMEL(X) must be evenly divisible 
    by the product of the specified dimensions. You can use only one 
    occurrence of [].

三.神经网络

关于神经网络的知识我就不在这里说了,在网上大家可以很清楚的看到。

四.代码实现

我一共写了5个函数以及一个脚本文件,如下:

1.sigmoid.m

这个函数是计算神经网络的激活函数,(我使用的是sigmoid函数),代码如下:

function res=sigmoid(z);
res=(1+exp(-z)).^(-1);
end

如果传入的参数是一个矩阵,那么它返回的也是对矩阵当中每一个元素求sigmoid之后所形成的新的矩阵。

2.sigmGrad.m

这个函数是对sigmoid的函数求导数,同样,当传入的参数是矩阵时它会返回一个新的矩阵。代码如下:

function res=sigmGrad(z);
res=sigmoid(z).*(1-sigmoid(z));
end

3.display_pixel.m

这个函数和神经网络本身没有关系,但是我写这么一个函数是为了把数据样本可视化。也就是说,它会把手写字体直观地展示出来。

代码如下:

function display_pixel(X,pixel_width)
[m,n]=size(X);
if ~exist('pixel_width','var')
    if sqrt(n)~=round(sqrt(n))
        fprintf('Please input the number of pixel of each digit');
    else
        pixel_width=round(sqrt(n));
    end
end
colormap(gray);
num_eg_each_cls=round(m/10);
show_mat=[];
for i =1:10
    show_mat_temp=[];
    for j=1:num_eg_each_cls
        new_sqr=reshape(X((i-1)*num_eg_each_cls+j,:),pixel_width,pixel_width);
        show_mat_temp=[show_mat_temp new_sqr];
    end
    show_mat=[show_mat;show_mat_temp];
end
h=imagesc(show_mat);
end

效果图如下:


 

 

这个函数需要传入两个参数,第一个X是我们想要可视化的数据,第二个参数是每一个数字所占据的像素点的个数,如果这个参数缺失,那么我默认传入数据矩阵的第二维的平方根是这个宽度。

4.weights_random.m

这个函数是用来初始化神经网络的各个参数的,代码如下:

 

function w=weights_random(input_units,hidden_units,hidden_layers,output_units);
epsl=sqrt(6)/sqrt(input_units+output_units);
w=[];
for i =1:hidden_layers+1
    if i==1
        temp=-epsl*rand(hidden_units,input_units+1)+2*epsl;
        w=[w;temp(:)];
    elseif i<=hidden_layers
        temp=-epsl*rand(hidden_units,hidden_units+1)+2*epsl;
        w=[w;temp(:)];
    else
        temp=-epsl*rand(output_units,hidden_units+1)+2*epsl;
        w=[w;temp(:)];
    end
end
end

这个函数需要传入四个参数,

input_units     是神经网络输入层的unit个数,也就是训练样本的feature个数。
hidden_units    是神经网络隐层unit的个数
hidden_layers   是神经网络隐层的层数
output_units    是神经网络输出层的unit的个数,也就是分类问题标签的个数,在这里我们需要对十种手写数字进行识别, 因此这个个数是10(0——9)

最后这个函数会输出一个长向量(1维)。这样做的目的是当我们在寻求最优解的时候不选择手动实现梯度下降而选择使用matlab自带的fmincg函数时,可以直接传入这个参数。

5.cost_grad.m

这个函数是我们神经网络的关键,因为在一个神经网络中,我们要知道代价函数以及每一个参数关于代价函数的偏导数。代价函数很好理解,就是说对于一个给定的输入,我们这个神经网络会做出一个预测,那么怎么评判我们的神经网络是否是一个较好的模型呢?我们引入代价函数,这个代价函数反应的是模型预测结果和实际结果之间的偏差。我们希望这个代价函数尽可能的小,但是这样又会存在过拟合的情况,因此我们还要引入正则项,这里我就不展开了。而求解的方法就是根据每个参数关于代价函数的偏导数完成的。

代码如下:

function [Cost,Grad_Theta]=cost_grad(theta,input_units,hidden_units,hidden_layers,...
    output_units,X,y,lambda);
X=double(X);   %convert type
y=double(y);
Cost=0;
Theta=cell(hidden_layers+1,1);    %unroll the long vector of theta
for i=1:hidden_layers+1     %to extract the Theta(i,j) from the long vector of Theta
    if hidden_layers>1
        if i==1
            Theta{i}=reshape(theta(1:hidden_units*(input_units+1)),hidden_units,input_units+1);
        elseif i<=hidden_layers
            Theta{i}=reshape(theta(hidden_units*(input_units+1)+(i-2)*hidden_units*(hidden_units+1)+1:...
                hidden_units*(input_units+1)+(i-1)*hidden_units*(hidden_units+1)),hidden_units,hidden_units+1);
        else
            Theta{i}=reshape(theta(hidden_units*(input_units+1)+(hidden_layers-1)*hidden_units*(hidden_units+1)+1:...
                end),output_units,hidden_units+1);
        end
    else
        Theta{1}=reshape(theta(1:hidden_units*(input_units+1)),hidden_units,input_units+1);
        Theta{2}=reshape(theta(hidden_units*(input_units+1)+1:end),output_units,hidden_units+1);
    end
end
%forward propagation for calculating a 

[m,n]=size(X);
%celldisp(a);
%celldisp(z);
a{1}=X;
a{1}=[ones(m,1) a{1}];
for i=1:hidden_layers+1
    if i~=hidden_layers+1
        z{i+1}=a{i}*Theta{i}';      %input values of (i+1)th layer
        a{i+1}=sigmoid(z{i+1});
        a{i+1}=[ones(size(a{i+1},1),1) a{i+1}];
    else 
        z{i+1}=a{i}*Theta{i}';       %input values of (i+1)th layer
        a{i+1}=sigmoid(z{i+1});
    end
end
%generate the y matrix
y_mat=zeros(m,output_units);
for i=1:output_units
    if i==10
        y_mat(:,i)=(y==10);
    else
        y_mat(:,i)=(y==i);
    end
end
%cost function
for i=1:m
    Cost_tem(i)=-1/m*(y_mat(i,:)*log(a{hidden_layers+2}(i,:))'+(1-y_mat(i,:))*log(1-a{hidden_layers+2}(i,:))');
end
Cost=sum(Cost_tem);
%regularization for cost function
for l=1:hidden_layers+1
    for i=1:size(Theta{l},1)
        for j=2:size(Theta{l},2)
            Cost=Cost+lambda/(2*m)*Theta{l}(i,j)^2;
        end
    end
end
%back propagation for calculating T
T=cell(hidden_layers+1,1);
Grad_Theta=cell(size(Theta));
%celldisp(T);
%celldisp(Grad_Theta);
theta_tem=Theta;
for k=2:hidden_layers+1
    theta_tem{k}(:,1)=[];
end
for k=1:size(Theta,1)
    Grad_Theta{k}=zeros(size(Theta{k}));
end
for i=1:m
    for l=0:hidden_layers
        if l==0
            T{hidden_layers+2-l}=a{hidden_layers+2-l}(i,:)'-y_mat(i,:)';
            Grad_Theta{1+hidden_layers-l}=Grad_Theta{1+hidden_layers-l}+T{hidden_layers+2-l}*a{1+hidden_layers-l}(i,:);
        else
            %size(theta_tem{hidden_layers+2-l}),size(T{hidden_layers+3-l}),size(z{hidden_layers+2-l})
            T{hidden_layers+2-l}=theta_tem{hidden_layers+2-l}'*T{hidden_layers+3-l}.*sigmGrad(z{hidden_layers+2-l}(i,:)');
            Grad_Theta{1+hidden_layers-l}=Grad_Theta{1+hidden_layers-l}+T{hidden_layers+2-l}*a{1+hidden_layers-l}(i,:);
        end
        
    end
            
end
sigmGrad(z{hidden_layers+2-l}(i,:));
z{hidden_layers+2-l}(1,:);
%calculate the partial derivatives of the cost function 
tem=[];
%regularization for Grad_Theta
for l=1:hidden_layers+1
    for i=1:size(Theta{l},1)
        for j=2:size(Theta{l},2)
            Grad_Theta{l}(i,j)=Grad_Theta{l}(i,j)+lambda/m*Theta{l}(i,j);
        end
    end
end
for i=1:size(Grad_Theta,1)
    tem=[tem;Grad_Theta{i}(:)];
end
Grad_Theta=1/m*tem;
end

这个函数同样需要传入一些参数。

theta 这个是神经网络的各个参数,是一个一维的长向量,相当于我们对神经网络的每一个参数进行赋值 

 

input_units,hidden_units,hidden_layers,output_units 同上
X  这个是训练样本

 

y  这个是训练样本的标签

 

lambda  这个是正则化项前面的系数

 

这个函数返回两个结果:

第一个是在输入的参数情况之下,我们这个神经网络的代价函数,当然我们希望这个值越小越好。

第二个是每一个参数关于代价函数的偏导数,这也是一个长向量。
我们在使用梯度下降或者CG或者BFGS算法来优化代价函数的时候会用到这两个返回值。

6.NN_digit_rec.m

这个脚本文件就是把之前所有的函数运行一遍,大家很容易就可以看明白,直接上代码吧:

clear ; close all; clc
load('my_train.mat');
load('y.mat');
load('testX.mat');
load('testy.mat');
load('ex4weights.mat');
%initialisation


input_units=784;
hidden_units=30;
hidden_layers=1;
output_units=10;
theta=weights_random(input_units,hidden_units,hidden_layers,output_units);  %this returns a column vector

lambda=1;
alpha=0.3;    %learning rate
%Using gradient descent to train coefficients
[Cost,Grad_Theta]=cost_grad(theta,input_units,hidden_units,hidden_layers,output_units,my_train,y,lambda);
for i=1:10000
    for j=1:length(theta)
        theta(j)=theta(j)-alpha*Grad_Theta(j);
    end                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                         
    [Cost(i),Grad_Theta]=cost_grad(theta,input_units,hidden_units,hidden_layers,output_units,my_train,y,lambda);
    fprintf('----Iteration%d,Cost=%f----',i,Cost(i)) 
    if i>1
        if Cost(i)-Cost(i-1)>0
            fprintf('Increase %f%%.\n',(Cost(i)-Cost(i-1))/Cost(i-1))
        else
            fprintf('Reduce %f%%.\n',-(Cost(i)-Cost(i-1))/Cost(i-1))
        end
    else 
        fprintf('\n')
    end
end
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        

%Using advanced optimizaton algorithm  

% options=optimset('Gradobj','on','MaxIter',1000);
% [opttheta,Cost,flag]=fmincg(@(t)cost_grad(t,input_units,hidden_units,hidden_layers,output_units,X,y,lambda),theta,...
%     options);


%predict
rate=my_predict(theta,testX,testy,input_units,hidden_units,hidden_layers,output_units);
fprintf('The accuracy rate is %f%%.\n',rate)

fprintf('Theta=')
theta

运行效果如下:


当经过一段时间的迭代以后,我们就会得到一组合适的参数,我们的神经网络也就算训练好了。

 

五.源代码

这个神经网络的源码全部都在我的github上,地址如下:

李良伟的github

把整个文件下载下来(附带有训练和测试样本数据),用matlab或者octave都可以直接跑。

  • 1
    点赞
  • 16
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值