DeepLearningToolBox学习——RBM(Restrict Boltzmann Machine)II

下载地址:DeepLearningToolBox

     学习RBM代码之前,需要一些基本的RBM的知识。

     网上有很多参考资料,我借鉴一篇写的很好的系列文章,看下来也差不多能看懂了,博客地址:http://blog.csdn.net/itplus/article/details/19168937


 目录如下

(一)预备知识

(二)网络结构

(三)能量函数和概率分布

(四)对数似然函数

(五)梯度计算公式

(六)对比散度算法

(七)RBM 训练算法

(八)RBM 的评估


通过学习上面的系列文章,基本上理解了RBM的原理,接下来动手学习toolbox中对应的RBM第二部分代码。

本文一下内容转自博客:原文地址


1.mnistclassify.m


MINST手写数字库的识别部分和前面的降维部分其实很相似。首先它也是预训练整个网络,只不过在MINST识别时,预训练的网络部分需要包括输出softmax部分,且这部分预训练时是用的有监督方法的。在微调部分的不同体现在:MINST降维部分是用的无监督方法,即数据的标签为原始的输入数据。而MINST识别部分数据的标签为训练样本的实际标签

在进行MINST手写数字体识别的时候,需要计算加入了softmax部分的网络的代价函数,作者的程序中给出了2个函数。其中第一个函数用于预训练softmax分类器:

  function [f, df] = CG_CLASSIFY_INIT(VV,Dim,w3probs,target);

  该函数是专门针对softmax分类器那部分预训练用的,因为一开始的rbm预训练部分没有包括输出层softmax网络。输入参数VV表示整个网络的权值向量(也包括了softmax那一部分),Dim为sofmmax对应部分的2层网络节点个数的向量,w3probs为训练softmax所用的样本集,target为对应样本集的标签。f和df分别为softmax网络的代价函数和代价函数的偏导数。

  另一个才是真正的计算网络微调的代价函数:

  function [f, df] = CG_CLASSIFY(VV,Dim,XX,target);

  函数输入值VV代表网络的参数向量,Dim为每层网络的节点数向量,XX为训练样本集,target为训练样本集的标签,f和df分别为整个网络的代价函数以及代价函数的偏导数。

    实验结果:

  作者采用的1个输入层,3个隐含层和一个softmax分类层的输出层,网络的节点数依次为:784-500-500-2000-10。

  其最终识别的错误率为:1.2%.


% Version 1.000
%
% Code provided by Ruslan Salakhutdinov and Geoff Hinton  
%
% Permission is granted for anyone to copy, use, modify, or distribute this
% program and accompanying programs and documents for any purpose, provided
% this copyright notice is retained and prominently displayed, along with
% a note saying that the original programs are available from our 
% web page. 
% The programs and documents are distributed without any warranty, express or
% implied.  As the programs were written for research purposes only, they have
% not been tested to the degree that would be advisable in any important
% application.  All use of these programs is entirely at the user's own risk.


% This program pretrains a deep autoencoder for MNIST dataset
% You can set the maximum number of epochs for pretraining each layer
% and you can set the architecture of the multilayer net.
clc
clear all
close all

maxepoch=5; %50
numhid=500; numpen=500; numpen2=2000; 

fprintf(1,'Converting Raw files into Matlab format \n');
converter; 

fprintf(1,'Pretraining a deep autoencoder. \n');
fprintf(1,'The Science paper used 50 epochs. This uses %3i \n', maxepoch);

makebatches;
[numcases numdims numbatches]=size(batchdata);

fprintf(1,'Pretraining Layer 1 with RBM: %d-%d \n',numdims,numhid);
restart=1;
rbm;
hidrecbiases=hidbiases; 
save mnistvhclassify vishid hidrecbiases visbiases;

fprintf(1,'\nPretraining Layer 2 with RBM: %d-%d \n',numhid,numpen);
batchdata=batchposhidprobs;
numhid=numpen;
restart=1;
rbm;
hidpen=vishid; penrecbiases=hidbiases; hidgenbiases=visbiases;
save mnisthpclassify hidpen penrecbiases hidgenbiases;

fprintf(1,'\nPretraining Layer 3 with RBM: %d-%d \n',numpen,numpen2);
batchdata=batchposhidprobs;
numhid=numpen2;
restart=1;
rbm;
hidpen2=vishid; penrecbiases2=hidbiases; hidgenbiases2=visbiases;
save mnisthp2classify hidpen2 penrecbiases2 hidgenbiases2;

backpropclassify; 

2. backpropclassfiy.m

maxepoch=200;
fprintf(1,'\nTraining discriminative model on MNIST by minimizing cross entropy error. \n');
fprintf(1,'60 batches of 1000 cases each. \n');

load mnistvhclassify %载入3个rbm网络的预训练好了的权值
load mnisthpclassify
load mnisthp2classify

makebatches;
[numcases numdims numbatches]=size(batchdata);
N=numcases; 

%%%% PREINITIALIZE WEIGHTS OF THE DISCRIMINATIVE MODEL%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

w1=[vishid; hidrecbiases];
w2=[hidpen; penrecbiases];
w3=[hidpen2; penrecbiases2];
w_class = 0.1*randn(size(w3,2)+1,10); %因为要分类,所以最后一层直接输出10个节点,类似softmax分类器
 

%%%%%%%%%% END OF PREINITIALIZATIO OF WEIGHTS  %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

l1=size(w1,1)-1;
l2=size(w2,1)-1;
l3=size(w3,1)-1;
l4=size(w_class,1)-1;
l5=10; 
test_err=[];
train_err=[];


for epoch = 1:maxepoch %200

%%%%%%%%%%%%%%%%%%%% COMPUTE TRAINING MISCLASSIFICATION ERROR %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
err=0; 
err_cr=0;
counter=0;
[numcases numdims numbatches]=size(batchdata);
N=numcases;
 for batch = 1:numbatches
  data = [batchdata(:,:,batch)];
  target = [batchtargets(:,:,batch)];
  data = [data ones(N,1)];
  w1probs = 1./(1 + exp(-data*w1)); w1probs = [w1probs  ones(N,1)];
  w2probs = 1./(1 + exp(-w1probs*w2)); w2probs = [w2probs ones(N,1)];
  w3probs = 1./(1 + exp(-w2probs*w3)); w3probs = [w3probs  ones(N,1)];
  targetout = exp(w3probs*w_class);
  targetout = targetout./repmat(sum(targetout,2),1,10); %softmax分类器

  [I J]=max(targetout,[],2);%J是索引值
  [I1 J1]=max(target,[],2);
  counter=counter+length(find(J==J1));% length(find(J==J1))表示为预测值和网络输出值相等的个数
  err_cr = err_cr- sum(sum( target(:,1:end).*log(targetout))) ;
 end
 train_err(epoch)=(numcases*numbatches-counter);%每次迭代的训练误差
 train_crerr(epoch)=err_cr/numbatches;

%%%%%%%%%%%%%% END OF COMPUTING TRAINING MISCLASSIFICATION ERROR %%%%%%%%%%%%%%%%%%%%%%%%%%%%%

%%%%%%%%%%%%%%%%%%%% COMPUTE TEST MISCLASSIFICATION ERROR %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
err=0;
err_cr=0;
counter=0;
[testnumcases testnumdims testnumbatches]=size(testbatchdata);
N=testnumcases;
for batch = 1:testnumbatches
  data = [testbatchdata(:,:,batch)];
  target = [testbatchtargets(:,:,batch)];
  data = [data ones(N,1)];
  w1probs = 1./(1 + exp(-data*w1)); w1probs = [w1probs  ones(N,1)];
  w2probs = 1./(1 + exp(-w1probs*w2)); w2probs = [w2probs ones(N,1)];
  w3probs = 1./(1 + exp(-w2probs*w3)); w3probs = [w3probs  ones(N,1)];
  targetout = exp(w3probs*w_class);
  targetout = targetout./repmat(sum(targetout,2),1,10);

  [I J]=max(targetout,[],2);
  [I1 J1]=max(target,[],2);
  counter=counter+length(find(J==J1));
  err_cr = err_cr- sum(sum( target(:,1:end).*log(targetout))) ;
end
 test_err(epoch)=(testnumcases*testnumbatches-counter); %测试样本的误差,这都是在预训练基础上得到的结果
 test_crerr(epoch)=err_cr/testnumbatches;
 fprintf(1,'Before epoch %d Train # misclassified: %d (from %d). Test # misclassified: %d (from %d) \t \t \n',...
            epoch,train_err(epoch),numcases*numbatches,test_err(epoch),testnumcases*testnumbatches);

%%%%%%%%%%%%%% END OF COMPUTING TEST MISCLASSIFICATION ERROR %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

 tt=0;
 for batch = 1:numbatches/10
 fprintf(1,'epoch %d batch %d\r',epoch,batch);

%%%%%%%%%%% COMBINE 10 MINIBATCHES INTO 1 LARGER MINIBATCH %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
 tt=tt+1; 
 data=[];
 targets=[]; 
 for kk=1:10
  data=[data 
        batchdata(:,:,(tt-1)*10+kk)]; 
  targets=[targets
        batchtargets(:,:,(tt-1)*10+kk)];
 end 

%%%%%%%%%%%%%%% PERFORM CONJUGATE GRADIENT WITH 3 LINESEARCHES %%%%%%%%%%%%%%%%%%%%%%%%%%%%%
  max_iter=3;

  if epoch<6  % First update top-level weights holding other weights fixed. 前6次迭代都是针对softmax部分的预训练
    N = size(data,1);
    XX = [data ones(N,1)];
    w1probs = 1./(1 + exp(-XX*w1)); w1probs = [w1probs  ones(N,1)];
    w2probs = 1./(1 + exp(-w1probs*w2)); w2probs = [w2probs ones(N,1)];
    w3probs = 1./(1 + exp(-w2probs*w3)); %w3probs = [w3probs  ones(N,1)];

    VV = [w_class(:)']';
    Dim = [l4; l5];
    [X, fX] = minimize(VV,'CG_CLASSIFY_INIT',max_iter,Dim,w3probs,targets);
    w_class = reshape(X,l4+1,l5);

  else
    VV = [w1(:)' w2(:)' w3(:)' w_class(:)']';
    Dim = [l1; l2; l3; l4; l5];
    [X, fX] = minimize(VV,'CG_CLASSIFY',max_iter,Dim,data,targets);

    w1 = reshape(X(1:(l1+1)*l2),l1+1,l2);
    xxx = (l1+1)*l2;
    w2 = reshape(X(xxx+1:xxx+(l2+1)*l3),l2+1,l3);
    xxx = xxx+(l2+1)*l3;
    w3 = reshape(X(xxx+1:xxx+(l3+1)*l4),l3+1,l4);
    xxx = xxx+(l3+1)*l4;
    w_class = reshape(X(xxx+1:xxx+(l4+1)*l5),l4+1,l5);

  end
%%%%%%%%%%%%%%% END OF CONJUGATE GRADIENT WITH 3 LINESEARCHES %%%%%%%%%%%%%%%%%%%%%%%%%%%%%

 end

 save mnistclassify_weights w1 w2 w3 w_class
 save mnistclassify_error test_err test_crerr train_err train_crerr;

end


3. CG_CLASSIFY_INIT.m

function [f, df] = CG_CLASSIFY_INIT(VV,Dim,w3probs,target);%只有2层网络
l1 = Dim(1);
l2 = Dim(2);
N = size(w3probs,1);%N为训练样本的个数
% Do decomversion.
  w_class = reshape(VV,l1+1,l2);
  w3probs = [w3probs  ones(N,1)];  

  targetout = exp(w3probs*w_class);
  targetout = targetout./repmat(sum(targetout,2),1,10);
  f = -sum(sum( target(:,1:end).*log(targetout))) ;%f位softmax分类器的误差函数
IO = (targetout-target(:,1:end));
Ix_class=IO; 
dw_class =  w3probs'*Ix_class; %偏导值

df = [dw_class(:)']';

4.CG_CLASSIFY.m

function [f, df] = CG_CLASSIFY(VV,Dim,XX,target);

l1 = Dim(1);
l2 = Dim(2);
l3= Dim(3);
l4= Dim(4);
l5= Dim(5);
N = size(XX,1);

% Do decomversion.
 w1 = reshape(VV(1:(l1+1)*l2),l1+1,l2);
 xxx = (l1+1)*l2;
 w2 = reshape(VV(xxx+1:xxx+(l2+1)*l3),l2+1,l3);
 xxx = xxx+(l2+1)*l3;
 w3 = reshape(VV(xxx+1:xxx+(l3+1)*l4),l3+1,l4);
 xxx = xxx+(l3+1)*l4;
 w_class = reshape(VV(xxx+1:xxx+(l4+1)*l5),l4+1,l5);


  XX = [XX ones(N,1)];
  w1probs = 1./(1 + exp(-XX*w1)); w1probs = [w1probs  ones(N,1)];
  w2probs = 1./(1 + exp(-w1probs*w2)); w2probs = [w2probs ones(N,1)];
  w3probs = 1./(1 + exp(-w2probs*w3)); w3probs = [w3probs  ones(N,1)];

  targetout = exp(w3probs*w_class);
  targetout = targetout./repmat(sum(targetout,2),1,10);
  f = -sum(sum( target(:,1:end).*log(targetout))) ;

IO = (targetout-target(:,1:end));
Ix_class=IO; 
dw_class =  w3probs'*Ix_class; 

Ix3 = (Ix_class*w_class').*w3probs.*(1-w3probs);
Ix3 = Ix3(:,1:end-1);
dw3 =  w2probs'*Ix3;

Ix2 = (Ix3*w3').*w2probs.*(1-w2probs); 
Ix2 = Ix2(:,1:end-1);
dw2 =  w1probs'*Ix2;

Ix1 = (Ix2*w2').*w1probs.*(1-w1probs); 
Ix1 = Ix1(:,1:end-1);
dw1 =  XX'*Ix1;

df = [dw1(:)' dw2(:)' dw3(:)' dw_class(:)']';

实验总结:

   1. 终于阅读了一个RBM的源码了,以前看那些各种公式的理论,现在有了对应的code,读对应的code起来就是爽!

   2. 这里由于用的是整个图片进行训练(不是用的它们的patch部分),所以没有对应的convolution和pooling,因此预训练网络结构时下一个rbm网络的输入就是上一个rbm网络的输出,且当没有加入softmax时的微调阶段用的依旧是无监督的学习(此时的标签依旧为原始的输入数据);而当加入了softmax后的微调部分用的就是训练样本的真实标签了,因为此时需要进行分类。

   3. 深度越深,则网络的微调时间越长,需要很多时间收敛,即使是进行了预训练。

   4. 暂时还没弄懂要是针对大图片采用covolution训练时,第二层网络的数据来源是什么,有可能和上面的一样,是上层网络的输出(但是此时微调怎么办呢,不用标签数据?)也有可能是大图片经过第一层网络covolution,pooling后的输出值(如果是这样的话,网络的代价函数就不好弄了,因为里面有convolution和pooling操作)。



  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值