深度学习Matlab工具箱代码详解

相关资源打包下载: http://download.csdn.net/download/tianjing0805/9939566

  最近研究了几天深度学习的Matlab工具箱代码,发现作者给出的源码中注释实在是少得可怜,为了方便大家阅读,特对代码进行了注释,与大家分享。

  在阅读Matlab工具箱代码之前,建议大家阅读几篇CNN方面的两篇经典材料,对卷积神经网络Matlab工具箱代码的理解有很大帮助。

  (1)《Notes on Convolutional Neural Networks》,这篇文章是与Matlab工具箱代码配套的文献,不过文献中在下采样层也有两种训练参数,在工具箱中的下采样层并没有可训练参数,直接进行下采样操作。

     (2)《CNN学习-薛开宇》,这是与《Notes on Convolutional Neural Networks》内容及其相似的一份中文PPT资料,对卷积神经网络的介绍也是通俗易懂。

     (3)深度学习的Matlab工具箱Github下载地址:https://github.com/rasmusbergpalm/DeepLearnToolbox

     接下来给出一个工具箱中CNN程序在Mnist数据库上的示例程序:

[python] view plain copy
print ?
  1. %%=========================================================================  
  2. % 主要功能:在mnist数据库上做实验,验证工具箱的有效性  
  3. % 算法流程:1)载入训练样本和测试样本  
  4. %          2)设置CNN参数,并进行训练  
  5. %          3)进行检测cnntest()  
  6. % 注意事项:1)由于直接将所有测试样本输入会导致内存溢出,故采用一次只测试一个训练样本的测试方法  
  7. %%=========================================================================  
  8. %%  
  9. %%%%%%%%%%%%%%%%%%%%加载数据集%%%%%%%%%%%%%%%%%%%%  
  10. load mnist_uint8;  
  11. train_x = double(reshape(train_x’,28,28,60000))/255;  
  12. test_x  = double(reshape(test_x’,28,28,10000))/255;  
  13. train_y = double(train_y’);  
  14. test_y  = double(test_y’);  
  15.   
  16. %%  
  17. %%=========================================================================  
  18. %%%%%%%%%%%%%%%%%%%%设置卷积神经网络参数%%%%%%%%%%%%%%%%%%%%  
  19. % 主要功能:训练一个6c-2s-12c-2s形式的卷积神经网络,预期性能如下:  
  20. %          1)迭代一次需要200秒左右,错误率大约为11%  
  21. %          2)迭代一百次后错误率大约为1.2%  
  22. % 算法流程:1)构建神经网络并进行训练,以CNN结构体的形式保存  
  23. %          2)用已知的训练样本进行测试  
  24. % 注意事项:1)之前在测试的时候提示内存溢出,后来莫名其妙的又不溢出了,估计到了系统的内存临界值  
  25. %%=========================================================================  
  26. rand(’state’,0)  
  27. cnn.layers = {  
  28.     struct(’type’‘i’)                                    %输入层  
  29.     struct(’type’‘c’‘outputmaps’6‘kernelsize’5)  %卷积层  
  30.     struct(’type’’s’‘scale’2)                        %下采样层  
  31.     struct(’type’‘c’‘outputmaps’12‘kernelsize’5) %卷积层  
  32.     struct(’type’’s’‘scale’2)                        %下采样层  
  33.     };  
  34. cnn            = cnnsetup(cnn, train_x, train_y);  
  35. opts.alpha     = 1;  
  36. opts.batchsize = 50;  
  37. opts.numepochs = 5;  
  38. cnn            = cnntrain(cnn, train_x, train_y, opts);  
  39. save CNN_5 cnn;  
  40.   
  41. load CNN_5;  
  42. [er, bad]  = cnntest(cnn, test_x, test_y);  
  43. figure; plot(cnn.rL);  
  44. assert(er<0.12‘Too big error’);  
%%=========================================================================
% 主要功能:在mnist数据库上做实验,验证工具箱的有效性
% 算法流程:1)载入训练样本和测试样本
%          2)设置CNN参数,并进行训练
%          3)进行检测cnntest()
% 注意事项:1)由于直接将所有测试样本输入会导致内存溢出,故采用一次只测试一个训练样本的测试方法
%%=========================================================================
%%
%%%%%%%%%%%%%%%%%%%%加载数据集%%%%%%%%%%%%%%%%%%%%
load mnist_uint8;
train_x = double(reshape(train_x',28,28,60000))/255;
test_x  = double(reshape(test_x',28,28,10000))/255;
train_y = double(train_y');
test_y  = double(test_y');

%%
%%=========================================================================
%%%%%%%%%%%%%%%%%%%%设置卷积神经网络参数%%%%%%%%%%%%%%%%%%%%
% 主要功能:训练一个6c-2s-12c-2s形式的卷积神经网络,预期性能如下:
%          1)迭代一次需要200秒左右,错误率大约为11%
%          2)迭代一百次后错误率大约为1.2%
% 算法流程:1)构建神经网络并进行训练,以CNN结构体的形式保存
%          2)用已知的训练样本进行测试
% 注意事项:1)之前在测试的时候提示内存溢出,后来莫名其妙的又不溢出了,估计到了系统的内存临界值
%%=========================================================================
rand('state',0)
cnn.layers = {
    struct('type', 'i')                                    %输入层
    struct('type', 'c', 'outputmaps', 6, 'kernelsize', 5)  %卷积层
    struct('type', 's', 'scale', 2)                        %下采样层
    struct('type', 'c', 'outputmaps', 12, 'kernelsize', 5) %卷积层
    struct('type', 's', 'scale', 2)                        %下采样层
    };
cnn            = cnnsetup(cnn, train_x, train_y);
opts.alpha     = 1;
opts.batchsize = 50;
opts.numepochs = 5;
cnn            = cnntrain(cnn, train_x, train_y, opts);
save CNN_5 cnn;

load CNN_5;
[er, bad]  = cnntest(cnn, test_x, test_y);
figure; plot(cnn.rL);
assert(er<0.12, 'Too big error');

接下来给出工具箱中有关CNN部分程序注释的网址:
  (1)深度学习Matlab工具箱代码注释——cnnsetup.m
  (2)深度学习Matlab工具箱代码注释——cnntrain.m
  (3)深度学习Matlab工具箱代码注释——cnnff.m
  (4)深度学习Matlab工具箱代码注释——cnnbp.m
  (5)深度学习Matlab工具箱代码注释——cnnapplygrads.m

原文链接

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值