内容介绍
卷积神经网络(Convolutional Neural Network,CNN)
卷积神经网络相较于传统的图像处理算法的优点之一在于避免了对图像复杂的前期预处理过程,卷积神经网络可以直接输入原始图像进行一系列工作,从而提供了一个端到端的解决方案。
根据实际问题构造出网络结构,参数的确定则需要通过训练样本和学习算法来迭代找到最优参数组。
参数的优化过程旨在通过对神经网络中的可变参数进行调整,使得网络输出尽可能的接近期望输出。
算法原理
卷积神经网络结构一般是由输入层、多个交替的卷积层和池化层、全连接层,以及输出层组成。
卷积神经网络之训练算法:
- 确定网络模型 ;
- 初始化权重参数;
- 对于每个样例,执行以下步骤直到收敛:
- 计算模型输出:forward propagation(前向传播)
- 计算代价函数:比较模型输出与真实输出的差距
- 更新权重参数:back propagation(反向传播)
反向传播算法的核心是梯度下降算法。梯度下降算法会迭代式更新网络参数,不断沿着梯度的反方向让参数朝着总损失更小的方向更新使目标函数最小化 。
网络参数的优化分为两个过程,首先通过前向传播算法计算得到预测值,将预测值和真实值对比得到两者之间的差距。再通过反向传播算法计算目标函数对每个参数的梯度,根据梯度和学习率使用梯度下降算法更新每一个参数。
卷积神经网络的特点:局部感知、权值共享
- 局部连接是指特征层上的每个神经元的输入只与前一层的局部区域相连接。
- 权值共享是指在每次提取特征时,卷积核的参数是固定不变的。
上述这两个特征大大减少了网络参数的数目,降低了网络模型的复杂度。
实验环境
MATLABr2018b
实验步骤
- 准备数据集
- 定义网络结构
- 模型训练和测试
%%准备工作空间
clc
clear all
close all
%%导入数据
digitDatasetPath = fullfile('./','/HandWrittenDataset/');
imds = imageDatastore(digitDatasetPath,...
'IncludeSubfolders',true,'LabelSource','foldernames');%采用文件夹名称作为数据标记
%%数据集图个数
countEachLabel(imds)
numTrainFiles = 17;%每一个数字有22个样本,取17个样本作为训练数据
[imdsTrain,imdsValidation] = splitEachLabel(imds,numTrainFiles,'randomized');
%查看图片的大小
img = readimage(imds,1);
size(img)
%%定义卷积神经网络的结构
layers = [
%输入层
imageInputLayer([28 28 1])
%卷积层
convolution2dLayer(5,6,'Padding',2)
batchNormalizationLayer
reluLayer
maxPooling2dLayer(2,'stride',2)
convolution2dLayer(5,16)%卷积
batchNormalizationLayer%归一化
reluLayer%激活函数
maxPooling2dLayer(2,'stride',2)
convolution2dLayer(5,120)
batchNormalizationLayer
reluLayer
%最终层
fullyConnectedLayer(10)
softmaxLayer
classificationLayer];
%%训练神经网络
% 一、设置训练参数
options = trainingOptions('sgdm',...
'MaxEpochs',50,...
'ValidationData',imdsValidation,...
'ValidationFrequency',5,...
'Verbose',false,...
'Plots','training-progress');%显示训练的进度
%训练神经网络,保存网络
net = trainNetwork(imdsTrain,layers,options);
save 'CSNet.mat' net
%%标记数据(文件名称方式,自行构造)
mineSet = imageDatastore('./hw22/hw22/', 'FileExtensions','.jpg',...
'IncludeSubfolders',false);
mLabels = cell(size(mineSet.Files,1),1);
for i = 1:size(mineSet.Files,1)
[filepath,name,ext] = fileparts(char(mineSet.Files{i}));
mLabels{i,1} = char(name);
end
mLabels2 = categorical(mLabels);
mineSet.Labels = mLabels2;
%%%使用网络进行分类并计算准确性
%手写数据
YPred = classify(net,mineSet);
YValidation = mineSet.Labels;
%计算正确率
accuracy = sum(YPred == YValidation)/numel(YValidation)
%绘制预测结果
figure;
nSample = 10;
ind = randperm(size(YPred,1),nSample);
for i = 1:nSample
subplot(2,fix((nSample+1)/2),i)
imshow(char(mineSet.Files(ind(i))))
title(['预测:' char(YPred(ind(i)))])
if char(YPred(ind(i))) ==char(YValidation(ind(i)))
xlabel(['真实:' char(YValidation(ind(i)))],'Color','b')
else
xlabel(['真实:' char(YValidation(ind(i)))],'color','r')
end
end