受限玻尔兹曼机RBM实现及能量值思考——matlab实现

版权声明:欢迎大家一起交流,有错误谢谢指正~~~多句嘴,不要复制代码,因为CSDN排版问题,有些东西会自动加入乱糟糟的字符,最好是自己手写代码。格外注意被“踩”的博客,可能有很大问题,请自行查找大牛们的教程,以免被误导。最后,在确认博客理论正确性的前提下,随意转载,知识大家分享。 https://blog.csdn.net/zb1165048017/article/details/50876220

网址:http://www.cs.toronto.edu/~hinton/MatlabForSciencePaper.html

这个代码主要是在mnist上做手写数字识别的代码,贴出来的目的主要是想研究一下在迭代过程中能量的变化情况。

1. 标准能量函数

标准的能量函数的表达式为:


那么就将这个表达式放在每一批次迭代的末尾,然后将所有批次迭代一次的结果累加就是当前迭代次数的能量值。

我写的表达式如下:

  energy=energy-sum(sum(negdata*vishid.*poshidstates))-...
      sum(sum(negdata.*repmat(visbiases,size(negdata,1),1)))-...
      sum(sum(poshidstates.*repmat(hidbiases,size(poshidstates,1),1)));
具体位置截图如下


【注】肯能还有更高效简洁的写法,只要保重RBM对应的二部图中的每一条连接线都有计算入能量函数中去就行。

第一层迭代了五十次的结果如下

Pretraining Layer 1 with RBM: 784-500 
epoch 1
epoch    1 error 906233.0 energy -11099587.215172 
epoch 2
epoch    2 error 557528.5 energy -30047032.045595 
epoch 3
epoch    3 error 496027.4 energy -53112598.240393 
epoch 4
epoch    4 error 468267.6 energy -78539687.671057 
epoch 5
epoch    5 error 452678.9 energy -105299865.601975 
epoch 6
epoch    6 error 478165.1 energy -129373309.014318 
epoch 7
epoch    7 error 454438.2 energy -152106055.038886 
epoch 8
epoch    8 error 438489.1 energy -174308595.910380 
epoch 9
epoch    9 error 428747.2 energy -195910590.668272 
epoch 10
epoch   10 error 421810.9 energy -217028311.853859 
epoch 11
epoch   11 error 416392.9 energy -237610724.670849 
epoch 12
epoch   12 error 411178.5 energy -257759919.144502 
epoch 13
epoch   13 error 408115.5 energy -277438854.842965 
epoch 14
epoch   14 error 405091.1 energy -296657959.425424 
epoch 15
epoch   15 error 403533.2 energy -315397871.648862 
epoch 16
epoch   16 error 400616.4 energy -333710720.884443 
epoch 17
epoch   17 error 398482.7 energy -351563160.004860 
epoch 18
epoch   18 error 396461.9 energy -369009936.125511 
epoch 19
epoch   19 error 395664.3 energy -386036922.651997 
epoch 20
epoch   20 error 394134.4 energy -402656725.716640 
epoch 21
epoch   21 error 392316.9 energy -418988195.512326 
epoch 22
epoch   22 error 391985.3 energy -435041246.765398 
epoch 23
epoch   23 error 390845.3 energy -450774503.774346 
epoch 24
epoch   24 error 389345.9 energy -466241495.153597 
epoch 25
epoch   25 error 388229.8 energy -481430774.200754 
epoch 26
epoch   26 error 387980.9 energy -496349569.785631 
epoch 27
epoch   27 error 386886.4 energy -510986252.435888 
epoch 28
epoch   28 error 385948.6 energy -525397051.912323 
epoch 29
epoch   29 error 385386.0 energy -539599081.300994 
epoch 30
epoch   30 error 385348.8 energy -553574564.640536 
epoch 31
epoch   31 error 385159.7 energy -567319607.941765 
epoch 32
epoch   32 error 384260.7 energy -580846734.110570 
epoch 33
epoch   33 error 383174.4 energy -594214877.616236 
epoch 34
epoch   34 error 382934.3 energy -607405266.358145 
epoch 35
epoch   35 error 382697.1 energy -620430109.937463 
epoch 36
epoch   36 error 382095.5 energy -633260002.931164 
epoch 37
epoch   37 error 381486.8 energy -645950143.694427 
epoch 38
epoch   38 error 380661.3 energy -658521662.217831 
epoch 39
epoch   39 error 379594.0 energy -670919982.496346 
epoch 40
epoch   40 error 379996.7 energy -683233935.943210 
epoch 41
epoch   41 error 379679.9 energy -695420009.453536 
epoch 42
epoch   42 error 379632.1 energy -707458130.547251 
epoch 43
epoch   43 error 379168.7 energy -719365334.031710 
epoch 44
epoch   44 error 378787.6 energy -731171831.263565 
epoch 45
epoch   45 error 378450.0 energy -742883647.550836 
epoch 46
epoch   46 error 379054.0 energy -754479143.953231 
epoch 47
epoch   47 error 377713.5 energy -765939049.101519 
epoch 48
epoch   48 error 377721.1 energy -777334241.922596 
epoch 49
epoch   49 error 376576.0 energy -788647805.279271 
epoch 50
epoch   50 error 377725.3 energy -799861904.731226
可以有两点发现:

1、能量是负值

2、能量是不断减少,或者更形象点是逐渐降低的,也就是说绝对值是不断增大的,越来越远离0

这个也就说明了能量越低,模型越稳定

【注】其实我刚开始一直是以为能量值不断趋近于0的,实验结果证明了我的想法是错的。读者有任何想法也可以在评论区讨论讨论。有其他比较有价值的规律谢谢大家一起来完善

为了进一步验证这个结论,下面列出第二层的能量值变化

Pretraining Layer 2 with RBM: 500-500 
epoch 1
epoch    1 error 987086.0 energy -9062710.051405 
epoch 2
epoch    2 error 665834.8 energy -22633241.117738 
epoch 3
epoch    3 error 610662.3 energy -36830053.019285 
epoch 4
epoch    4 error 572091.6 energy -51281490.520769 
epoch 5
epoch    5 error 540614.1 energy -65915030.623658 
epoch 6
epoch    6 error 508557.3 energy -79431130.648043 
epoch 7
epoch    7 error 434258.1 energy -92467710.762025 
epoch 8
epoch    8 error 399902.2 energy -105430274.786403 
epoch 9
epoch    9 error 383352.3 energy -118255093.037551 
epoch 10
epoch   10 error 374529.1 energy -130904210.951242 
epoch 11
epoch   11 error 368514.4 energy -143354512.397248 
epoch 12
epoch   12 error 363563.4 energy -155627227.699558 
epoch 13
epoch   13 error 360383.7 energy -167717570.110705 
epoch 14
epoch   14 error 358861.5 energy -179615072.973138 
epoch 15
epoch   15 error 356927.2 energy -191293665.350483 
epoch 16
epoch   16 error 354678.5 energy -202753482.269224 
epoch 17
epoch   17 error 353843.7 energy -214037964.189715 
epoch 18
epoch   18 error 352791.3 energy -225155932.564075 
epoch 19
epoch   19 error 351835.7 energy -236104963.909265 
epoch 20
epoch   20 error 351010.4 energy -246873022.282188 
epoch 21
epoch   21 error 350847.1 energy -257470705.809567 
epoch 22
epoch   22 error 349911.6 energy -267900561.925812 
epoch 23
epoch   23 error 349589.7 energy -278184302.605534 
epoch 24
epoch   24 error 349400.3 energy -288317945.532343 
epoch 25
epoch   25 error 349051.0 energy -298312328.865682 
epoch 26
epoch   26 error 348249.4 energy -308184434.905750 
epoch 27
epoch   27 error 348383.1 energy -317909079.840016 
epoch 28
epoch   28 error 347618.3 energy -327508125.959673 
epoch 29
epoch   29 error 347152.3 energy -336984710.426838 
epoch 30
epoch   30 error 346707.3 energy -346366295.669933 
epoch 31
epoch   31 error 347245.8 energy -355642353.726434 
epoch 32
epoch   32 error 346502.2 energy -364798534.538484 
epoch 33
epoch   33 error 346450.0 energy -373854688.256938 
epoch 34
epoch   34 error 346358.1 energy -382818178.531560 
epoch 35
epoch   35 error 346985.5 energy -391695106.532281 
epoch 36
epoch   36 error 346818.4 energy -400478029.118356 
epoch 37
epoch   37 error 346446.7 energy -409183885.047261 
epoch 38
epoch   38 error 346409.3 energy -417793622.658515 
epoch 39
epoch   39 error 346342.5 energy -426331454.717778 
epoch 40
epoch   40 error 345685.5 energy -434801072.747726 
epoch 41
epoch   41 error 345374.3 energy -443199522.909101 
epoch 42
epoch   42 error 345590.6 energy -451522224.505664 
epoch 43
epoch   43 error 345611.4 energy -459769819.711364 
epoch 44
epoch   44 error 345459.2 energy -467949310.324651 
epoch 45
epoch   45 error 345785.0 energy -476052215.833106 
epoch 46
epoch   46 error 345640.1 energy -484099406.063227 
epoch 47
epoch   47 error 344962.8 energy -492075533.907423 
epoch 48
epoch   48 error 344727.7 energy -500000744.350654 
epoch 49
epoch   49 error 345307.3 energy -507865250.165273 
epoch 50
epoch   50 error 345707.6 energy -515675253.128295 
可以发现能量值依旧遵循上面两条规律。

2. 自由能量函数

如果看RBM相关论文,经常会提到一个东东叫做自由能量函数(free energy)


这个是一篇关于用玻尔兹曼机做分类的文章中的一个式子,可以去我的CSDN资源中下载

这里面的softplus也是一种函数类型,类似于ReLu、sigmoid等,Softplus函数是Logistic-Sigmoid函数原函数。


按照这个公式,可以抽取一下普通的RBM中的能量函数,在一位作者的博客中有相关解释,这里搬过来


其实上式在机器之心的某篇文章,戳这里也有提到, 有兴趣可以瞅瞅

当隐藏层单元值为二值情况时候,很容易证明出


代码书写如下:

energy=energy-sum(sum(negdata.*repmat(visbiases,size(negdata,1),1)))-...
    sum(sum( log(1+exp(negdata*vishid+repmat(hidbiases,size(poshidstates,1),1))) ));
迭代第一次的结果:

Pretraining Layer 1 with RBM: 784-500 
epoch 1
epoch    1 error 905276.0 energy -19881397.779961 
epoch 2
epoch    2 error 558779.6 energy -47693109.210766 
epoch 3
epoch    3 error 495834.6 energy -79830190.885475 
epoch 4
epoch    4 error 468697.6 energy -114398712.137404 
epoch 5
epoch    5 error 453039.3 energy -150310200.560388 
epoch 6
epoch    6 error 478686.3 energy -182605142.786704 
epoch 7
epoch    7 error 451919.2 energy -213180793.202729 
epoch 8
epoch    8 error 438010.1 energy -242816234.242592 
epoch 9
epoch    9 error 427760.4 energy -271590223.327619 
epoch 10
epoch   10 error 420847.2 energy -299660032.178816 
epoch 11
epoch   11 error 415645.8 energy -326954499.696240 
epoch 12
epoch   12 error 410386.1 energy -353688087.921110 
epoch 13
epoch   13 error 408491.9 energy -379773439.302978 
epoch 14
epoch   14 error 404256.8 energy -405219040.048392 
epoch 15
epoch   15 error 402607.1 energy -430136766.920731 
epoch 16
epoch   16 error 398868.0 energy -454524480.565422 
epoch 17
epoch   17 error 398208.7 energy -478411816.895757 
epoch 18
epoch   18 error 396493.4 energy -501819653.242841 
epoch 19
epoch   19 error 394605.5 energy -524762199.874034 
epoch 20
epoch   20 error 392981.9 energy -547265960.648713 
epoch 21
epoch   21 error 392338.4 energy -569340099.227584 
epoch 22
epoch   22 error 390601.4 energy -591085154.512036 
epoch 23
epoch   23 error 389879.6 energy -612461539.974077 
epoch 24
epoch   24 error 388782.1 energy -633468451.054830 
epoch 25
epoch   25 error 388364.5 energy -654125425.131981 
epoch 26
epoch   26 error 386876.1 energy -674518443.101429 
epoch 27
epoch   27 error 386544.7 energy -694619635.858860 
epoch 28
epoch   28 error 385006.0 energy -714467061.612437 
epoch 29
epoch   29 error 384898.6 energy -734014091.415583 
epoch 30
epoch   30 error 384137.1 energy -753300736.046747 
epoch 31
epoch   31 error 383938.5 energy -772303792.788062 
epoch 32
epoch   32 error 383642.5 energy -791090169.476486 
epoch 33
epoch   33 error 382299.7 energy -809661333.174433 
epoch 34
epoch   34 error 381459.4 energy -828056903.650036 
epoch 35
epoch   35 error 382378.4 energy -846229339.194885 
epoch 36
epoch   36 error 381321.8 energy -864236457.069615 
epoch 37
epoch   37 error 380576.9 energy -882130475.692142 
epoch 38
epoch   38 error 380819.1 energy -899805881.559751 
epoch 39
epoch   39 error 380575.8 energy -917291008.281778 
epoch 40
epoch   40 error 380441.0 energy -934597287.769714 
epoch 41
epoch   41 error 378924.6 energy -951799519.728817 
epoch 42
epoch   42 error 379242.5 energy -968837996.416408 
epoch 43
epoch   43 error 378084.0 energy -985763999.393868 
epoch 44
epoch   44 error 377760.6 energy -1002485955.994587 
epoch 45
epoch   45 error 378065.3 energy -1019135408.410493 
epoch 46
epoch   46 error 378113.7 energy -1035609968.466859 
epoch 47
epoch   47 error 378137.6 energy -1051977206.619751 
epoch 48
epoch   48 error 377804.7 energy -1068256355.969882 
epoch 49
epoch   49 error 377078.1 energy -1084361351.832399 
epoch 50
epoch   50 error 376941.1 energy -1100352833.515859 
迭代第二次的结果

Pretraining Layer 2 with RBM: 500-500 
epoch 1
epoch    1 error 993889.2 energy -15205761.985310 
epoch 2
epoch    2 error 676743.8 energy -34794600.316635 
epoch 3
epoch    3 error 622237.7 energy -54860642.213566 
epoch 4
epoch    4 error 585614.3 energy -75025441.023491 
epoch 5
epoch    5 error 555836.9 energy -95270049.773320 
epoch 6
epoch    6 error 521063.7 energy -113064624.687646 
epoch 7
epoch    7 error 442315.3 energy -129533185.056863 
epoch 8
epoch    8 error 407058.7 energy -145427948.901373 
epoch 9
epoch    9 error 389512.7 energy -160926615.310918 
epoch 10
epoch   10 error 379349.7 energy -176038677.575432 
epoch 11
epoch   11 error 372263.0 energy -190819909.913752 
epoch 12
epoch   12 error 368580.7 energy -205272520.010386 
epoch 13
epoch   13 error 365315.1 energy -219429209.242861 
epoch 14
epoch   14 error 363198.0 energy -233282955.187790 
epoch 15
epoch   15 error 360865.5 energy -246872028.136193 
epoch 16
epoch   16 error 360440.2 energy -260191385.979074 
epoch 17
epoch   17 error 358262.1 energy -273277908.363102 
epoch 18
epoch   18 error 356881.8 energy -286135114.134431 
epoch 19
epoch   19 error 356049.6 energy -298750619.512386 
epoch 20
epoch   20 error 356009.3 energy -311137943.591194 
epoch 21
epoch   21 error 355319.0 energy -323321281.109947 
epoch 22
epoch   22 error 354783.1 energy -335312699.884763 
epoch 23
epoch   23 error 354477.3 energy -347119931.104351 
epoch 24
epoch   24 error 354285.7 energy -358745715.833463 
epoch 25
epoch   25 error 353366.3 energy -370214880.549267 
epoch 26
epoch   26 error 352662.9 energy -381532009.670373 
epoch 27
epoch   27 error 353588.7 energy -392675673.518439 
epoch 28
epoch   28 error 352470.9 energy -403697030.809657 
epoch 29
epoch   29 error 353054.5 energy -414589450.592768 
epoch 30
epoch   30 error 352482.0 energy -425354947.125350 
epoch 31
epoch   31 error 351719.4 energy -435996803.022372 
epoch 32
epoch   32 error 351546.8 energy -446524221.283293 
epoch 33
epoch   33 error 351439.6 energy -456927518.928294 
epoch 34
epoch   34 error 351518.5 energy -467237727.772437 
epoch 35
epoch   35 error 351350.9 energy -477451966.520637 
epoch 36
epoch   36 error 350986.0 energy -487572943.439880 
epoch 37
epoch   37 error 350939.2 energy -497598679.794740 
epoch 38
epoch   38 error 350509.6 energy -507538226.850344 
epoch 39
epoch   39 error 350470.0 energy -517375214.715247 
epoch 40
epoch   40 error 351177.7 energy -527125526.454195 
epoch 41
epoch   41 error 350052.4 energy -536812334.692637 
epoch 42
epoch   42 error 350338.3 energy -546408162.241688 
epoch 43
epoch   43 error 349897.6 energy -555922938.066808 
epoch 44
epoch   44 error 350848.5 energy -565355955.788417 
epoch 45
epoch   45 error 350459.8 energy -574712019.150728 
epoch 46
epoch   46 error 349903.4 energy -583997810.590286 
epoch 47
epoch   47 error 350130.9 energy -593219406.465314 
epoch 48
epoch   48 error 349957.5 energy -602379132.214209 

回顾一下便能发现,判断玻尔兹曼机的参数设置是否合理或者训练是否朝着正常方向的两个可能的判别标准就是:

1、重构误差:一般是均方差,是正值、曲线向X轴逼近

2、能量值或者自由能量值:负值、远离X轴方向收敛

当然如果是判别式的RBM,可以利用CNN的判别方法,观察验证集的分类错误率。


文章代码下载地址:http://download.csdn.net/detail/zb1165048017/9750247

数据集即为:链接:http://pan.baidu.com/s/1pK9YZIb 密码:qw3w

【注】博文是博主自己研究的,可能有错误,如若有任何问题,请在评论区标注,大家一起研究。

展开阅读全文

没有更多推荐了,返回首页