深度学习Matlab工具箱代码注释——cnntrain.m

  1. %%=========================================================================  
  2. %函数名称:cnntrain()  
  3. %输入参数:net,神经网络;x,训练数据矩阵;y,训练数据的标签矩阵;opts,神经网络的相关训练参数  
  4. %输出参数:net,训练完成的卷积神经网络  
  5. %算法流程:1)将样本打乱,随机选择进行训练;  
  6. %         2)取出样本,通过cnnff2()函数计算当前网络权值和网络输入下网络的输出  
  7. %         3)通过BP算法计算误差对网络权值的导数  
  8. %         4)得到误差对权值的导数后,就通过权值更新方法去更新权值  
  9. %注意事项:1)使用BP算法计算梯度  
  10. %%=========================================================================  
  11. function net = cnntrain(net, x, y, opts)  
  12. m = size(x, 3);                      %m保存的是训练样本个数  
  13. disp(['样本总个数=' num2str(m)]);  
  14. numbatches = m / opts.batchsize;     %numbatches表示每次迭代中所选取的训练样本数  
  15. if rem(numbatches, 1) ~= 0           %如果numbatches不是整数,则程序发生错误  
  16.     error('numbatches not integer');  
  17. end  
  18.   
  19. %%=====================================================================  
  20. %主要功能:CNN网络的迭代训练  
  21. %实现步骤:1)通过randperm()函数将原来的样本顺序打乱,再挑出一些样本来进行训练  
  22. %         2)取出样本,通过cnnff2()函数计算当前网络权值和网络输入下网络的输出  
  23. %         3)通过BP算法计算误差对网络权值的导数  
  24. %         4)得到误差对权值的导数后,就通过权值更新方法去更新权值  
  25. %注意事项:1)P = randperm(N),返回[1, N]之间所有整数的一个随机的序列,相当于把原来的样本排列打乱,  
  26. %            再挑出一些样本来训练  
  27. %         2)采用累积误差的计算方式来评估当前网络性能,即当前误差 = 以前误差 * 0.99 + 本次误差 * 0.01  
  28. %            使得网络尽可能收敛到全局最优  
  29. %%=====================================================================  
  30. net.rL = [];                         %代价函数值,也就是误差值  
  31. for i = 1 : opts.numepochs           %对于每次迭代  
  32.     disp(['epoch ' num2str(i) '/' num2str(opts.numepochs)]);  
  33.     tic;                             %使用tic和toc来统计程序运行时间  
  34.       
  35.     %%%%%%%%%%%%%%%%%%%%取出打乱顺序后的batchsize个样本和对应的标签 %%%%%%%%%%%%%%%%%%%%  
  36.     kk = randperm(m);                 
  37.     for l = 1 : numbatches  
  38.         batch_x = x(:, :, kk((l - 1) * opts.batchsize + 1 : l * opts.batchsize));  
  39.         batch_y = y(:,    kk((l - 1) * opts.batchsize + 1 : l * opts.batchsize));  
  40.           
  41.         %%%%%%%%%%%%%%%%%%%%在当前的网络权值和网络输入下计算网络的输出(特征向量)%%%%%%%%%%%%%%%%%%%%  
  42.         net = cnnff(net, batch_x); %卷积神经网络的前馈运算  
  43.           
  44.         %%%%%%%%%%%%%%%%%%%%通过对应的样本标签用bp算法来得到误差对网络权值的导数%%%%%%%%%%%%%%%%%%%%  
  45.         net = cnnbp(net, batch_y); %卷积神经网络的BP算法  
  46.           
  47.         %%%%%%%%%%%%%%%%%%%%通过权值更新方法去更新权值%%%%%%%%%%%%%%%%%%%%  
  48.         net = cnnapplygrads(net, opts);  
  49.           
  50.         if isempty(net.rL)  
  51.             net.rL(1) = net.L;     %代价函数值,也就是均方误差值 ,在cnnbp.m中计算初始值 net.L = 1/2* sum(net.e(:) .^ 2) / size(net.e, 2);         
  52.         end  
  53.         net.rL(end + 1) = 0.99 * net.rL(end) + 0.01 * net.L; %采用累积的方式计算累积误差  
  54.     end  
  55.     toc;  
  56. end  
  57. end  

  • 1
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值