解读hinton对 RBM的matlab实现
关于RBM(受限波尔玆曼机)
受限波尔玆曼机是生成式模型。输入数据可以根据概率生成出来。
RBM通常用contrastive divergence来进行训练,这是Hinton在2002年提出来的。将在后续的博文中对其进行介绍。
RBM由两层组成,一层可见层,一层隐藏层。由隐藏层对可见层数据进行特征提取。可见层可以为二值或实数值,隐藏层为二值。网络的能量由以下公式定义:
E(v,h)=−∑i∈visibleaivi−∑j∈hiddenbjhj−∑i,jvihjwij...(1)
其中
vi
和
hj
是可见层单元
i
和隐藏层单元
所有可能的v,h组合出现联合概率的概率定义为:
p(v,h)=1Ze−E(v,h)...(2)
Z是所有可能的隐藏层和输入层组合v,h的和
Z=∑v,he−E(v,h)...(3)
网络输出可见层v的概率是边缘概率:
p(v)=1Z∑he−E(v,h)…(4)
网络得到一个给定训练图片的概率可以通过调整权重和偏置,降低这个图片的能量,增加其他图片的能量来提升。
∂logp(v)∂wi,j=<vihj>data−<vihj>model...(5)
Δwi,j=ϵ(<vihj>data−<vihj>model)...(6)
其中
ϵ
为学习速率
给定随机选择的训练数据v,隐藏层单元j被置为1的概率为
p(hj=1|v)=σ(bj+∑iviwij)...(7)
给定隐藏层,可见层单元
i
输出为1的概率为
一次Gabbs Sampling包括根据公式(7)更新隐藏层状态然后根据公式8更新可见层状态。Gabbs Sampling主要是用来计算
<vihj>model
<script type="math/tex" id="MathJax-Element-98">
_{model}</script>。进行n轮Gabbs Sampling计算
<vihj>model
<script type="math/tex" id="MathJax-Element-99">
_{model}</script>的方法被称为
CDn
Matlab代码解读
% 本程序用来训练受限波尔玆曼机(RBM)
% 可见层是二进制的输入,隐藏层是也是二进制的
% 使用对称权重来连接可见层和隐藏层
% 使用1步Contrastive Divergence进行训练
% 本程序假设一下变量已经存在
% maxepoch -- 最大训练迭代次数
% numhid -- 隐藏层节点的数量
% batchdata -- 以batch为单位的训练数据,训练数据的维度为(numcases numdims numbatches)
% restart --如果restart为1,则训练从头开始
epsilonw = 0.1; % 权重的学习速率
epsilonvb = 0.1; % 可见层偏置的学习速率
epsilonhb = 0.1; % 隐藏层偏置的学习速率
weightcost = 0.0002; % 权重惩罚系数
initialmomentum = 0.5; % 初始冲量
finalmomentum = 0.9; % 最终冲量
% 每个batch样例数量 可见层输入大小 batch的大小
[numcases numdims numbatches]=size(batchdata);
if restart ==1,
restart=0;
epoch=1;
% 初始化对称权重和可见层与隐藏层的偏置.
vishid = 0.1*randn(numdims, numhid);
hidbiases = zeros(1,numhid);
visbiases = zeros(1,numdims);
% Positive隐藏层的概率
poshidprobs = zeros(numcases,numhid);
% negative隐藏层的概率
neghidprobs = zeros(numcases,numhid);
% positive term
posprods = zeros(numdims,numhid);
% negative term
negprods = zeros(numdims,numhid);
% 权重的delta
vishidinc = zeros(numdims,numhid);
% 隐藏层偏置的delta
hidbiasinc = zeros(1,numhid);
% 可见层偏置的delta
visbiasinc = zeros(1,numdims);
% 每个batch Positive隐藏层的概率
batchposhidprobs=zeros(numcases,numhid,numbatches);
end
% 根据设置的最大迭代次数进行迭代训练
for epoch = epoch:maxepoch,
fprintf(1,'epoch %d\r',epoch);
errsum=0;
% 每个batch分别进行训练
for batch = 1:numbatches,
fprintf(1,'epoch %d batch %d\r',epoch,batch);
%%%%%%%%% 开始 POSITIVE 阶段 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
data = batchdata(:,:,batch);
% 隐藏层的概率为hj = sigmoid(sigma(Vi*Wij)+ hbias_i)
poshidprobs = 1./(1 + exp(-data*vishid - repmat(hidbiases,numcases,1)));
% 记录该batch下的隐藏层的概率
batchposhidprobs(:,:,batch)=poshidprobs;
% positive term : (num_dim num_cases) * (num_cases num_hid)=(num_dim num_cases)
% posprods[i][j] = 每个样本的可见层输入*样本隐藏层概率的和
posprods = data' * poshidprobs;
% 所有样本隐藏层概率的和
poshidact = sum(poshidprobs);
% 所有样本输入数据的和
posvisact = sum(data);
%%%%%%%%% POSITIVE 阶段的Gibbs采样->隐藏层的状态 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
poshidstates = poshidprobs > rand(numcases,numhid);
%%%%%%%%% 开始 NEGATIVE 阶段 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% 重建可见层输入的到negdata
% sigmoid(隐藏层状态(二值) * 权重的转置 + 可见层偏置)
negdata = 1./(1 + exp(-poshidstates*vishid' - repmat(visbiases,numcases,1)));
% 由重建的negdata inference negative 隐藏层概率
neghidprobs = 1./(1 + exp(-negdata*vishid - repmat(hidbiases,numcases,1)));
% negative term (num_dim num_hid)
negprods = negdata'*neghidprobs;
% 将样例的隐藏层求和
neghidact = sum(neghidprobs);
% 将样例的可见层求和
negvisact = sum(negdata);
%%%%%%%%% NEGATIVE 阶段结束:计算重建可见层和实际可见层输入的平方和误差 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
err= sum(sum( (data-negdata).^2 ));
errsum = err + errsum;
% 如果迭代次数超过5次,冲量改为最终冲量
if epoch>5,
momentum=finalmomentum;
else
momentum=initialmomentum;
end;
%%%%%%%%% 更新权重和偏置 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% 权重更新的公式见文章开头,这里的权重更新增加了冲量来应对局部最优解问题
% 加入了权重的惩罚系数来达到稀疏性的要求
vishidinc = momentum*vishidinc + ...
epsilonw*( (posprods-negprods)/numcases - weightcost*vishid);
% 偏置的delta就是可见层和隐藏层postive的概率-negative的概率
visbiasinc = momentum*visbiasinc + (epsilonvb/numcases)*(posvisact-negvisact);
hidbiasinc = momentum*hidbiasinc + (epsilonhb/numcases)*(poshidact-neghidact);
% 更新权重和偏置
vishid = vishid + vishidinc;
visbiases = visbiases + visbiasinc;
hidbiases = hidbiases + hidbiasinc;
%%%%%%%%%%%%%%%% END OF UPDATES %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
end
fprintf(1, 'epoch %4i error %6.1f \n', epoch, errsum);
end;