matlab卷积神经网络代码_卷积神经网络(四):反向传播过程的代码实现

本文主要讲解卷积神经网络(CNN)反向传播过程的matlab代码实现。

01

简介

CNN主要由三种层堆叠而成,即卷积层、池化层和全连接层,在《 卷积神经网络(三 ):反向 传播过程》中又推导了这三种层的误差反向传播公式。因此,CNN反向传播的代码主要由这三种层的反向传播代码构成。

02

代码实现

在CNN反向传播时,输出层的误差(目标函数)会依次从后往前经过全连接层、池化层和卷积层。在编写全连接层、池化层和卷积层的代码时,均是先求出误差关于各层输入的导数,然后再计算参数的导数。 假设本文研究的是分类问题,输出层采用softmax函数,目标函数定义为交叉熵损失+参数正则化。 (一)定义并计算目标函数

梯度下降要优化的目标函数,主要分为两部分:一部分是由于分类器输出结果和真实结果的差异引起的误差,另一部分是对权重w的正则约束。

logp = log(probs);index = sub2ind(size(logp),mb_labels',1:size(probs,2));ceCost = -sum(logp(index));wCost = lambda/2 * (sum(Wd(:).^2)+sum(Wc(:).^2));cost = ceCost/numImages + wCost;

(二)softmax层

交叉熵损失函数关于softmax层输入的导数为:

869783241323afa9f76810d7399c4df4.png

即直接用预测结果减去真实结果。如果采用是平方差损失函数,则平方差损失函数关于softmax层输入的导数为(需要分情况讨论):

3a57ef0909b0af56d3b4d28698091c29.png

本文采用分类问题常用的交叉熵损失函数。
output = zeros(size(probs));output(index) = 1;DeltaSoftmax = probs - output;

注:笔者做过利用Levenberg–Marquardt算法优化网络结构,此时即需要平方差损失函数。

(三)全连接层

全连接层的误差反向传播是将 DeltaSoftmax乘以各层的权重以及点乘激活函数的导数。
Wd_grad = (1./numImages) .* DeltaSoftmax*activationsPooled'+lambda*Wd;bd_grad = (1./numImages) .* sum(DeltaSoftmax,2);

(四)池化层

在求出误差关于第一个全连接层 的导数后,需要将 该结果还原成最后一个池化层输出的形状。如果采用的是平均池化,则误差在池化区域内的所有元素上均分;如果采用的是最大池化,则误差只由最大元素负责。
DeltaPool = reshape(Wd' * DeltaSoftmax,outputDim,outputDim,numFilters,numImages);DeltaUnpool = zeros(convDim,convDim,numFilters,numImages);for imNum = 1:numImages    for FilterNum = 1:numFilters        unpool = DeltaPool(:,:,FilterNum,imNum);        DeltaUnpool(:,:,FilterNum,imNum) = kron(unpool,ones(poolDim))./(poolDim ^ 2);    endend

(五)卷积层

卷积层的反向传播较为复杂,但是具体的推导细节已经在《卷积神经网络(三):反向传播过程》中解释清楚。
% 在求出误差关于池化层输入的导数后,再点乘激活函数的导数。DeltaConv = DeltaUnpool .* activations .* (1 - activations);% 卷积层偏置的代码bc_grad = zeros(size(bc));for filterNum = 1:numFilters    error = DeltaConv(:,:,filterNum,:);    bc_grad(filterNum) = (1./numImages) .* sum(error(:));end% 卷积层权重的代码Wc_grad = zeros(filterDim,filterDim,numFilters);% 旋转所有DealtaConv:下面的conv2在函数内部会自动旋转180度,% 所以在这里旋转是为了抵消conv2旋转的影响。for filterNum = 1:numFilters    for imNum = 1:numImages        error = DeltaConv(:,:,filterNum,imNum);        DeltaConv(:,:,filterNum,imNum) = rot90(error,2);    endendfor filterNum = 1:numFilters    for imNum = 1:numImages        Wc_grad(:,:,filterNum) = Wc_grad(:,:,filterNum) + conv2(images(:,:,imNum),DeltaConv(:,:,filterNum,imNum),'valid');    endendWc_grad = (1./numImages) .* Wc_grad + lambda*Wc;
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值