DeepLearnToolbox是一个简单理解CNN过程的工具箱,可以在github下载。为了理解卷积神经网络的过程,我特此对CNN部分源码进行了注释。公式的计算可以由上一篇blog推导得出。
注意:代码中没有的subsampling进行设置参数,将subsampling层的参数w就设置为了0.25,而偏置参数b设置为0。卷积层计算过程为上一层所有feature map的卷积的结果和,后再加一个偏置,再取sigmoid函数。而subsampling的计算过程为上一层对应的2*2的feature map的像素值求和再取平均,没有加上偏置和取sigmoid。最后一层隐藏层为4*4大小的12个feature map,因此最后能得到192维的特征,全连接就是192*10(分类数目)。。
此外net中一些参数进行说明:
net.fv: 最后一层隐藏层的特征矩阵,采用的是全连接方式
net.o: 最后输出的结果,每一列为一个样本结果
net.od: 最后一层输出层所对应的残差
net.fvd: 最后一层隐藏层所对应的误差(全连接的方式)
test_example_CNN.m
%function test_example_CNN
addpath D:\DeepLearning\DeepLearnToolbox-master\data\
addpath D:\DeepLearning\DeepLearnToolbox-master\CNN\
addpath D:\DeepLearning\DeepLearnToolbox-master\util\
load mnist_uint8;
train_x = double(reshape(train_x',28,28,60000))/255; % 训练集变成60000张28*28的图片大小 28*28*60000,像素点归一化到[0,1]
test_x = double(reshape(test_x',28,28,10000))/255; % 测试集 28*28*10000
train_y = double(train_y'); %10*6000 每列代表一个标签 softmax回归模型
test_y = double(test_y');
%% ex1 Train a 6c-2s-12c-2s Convolutional neural network
%will run 1 epoch in about 200 second and get around 11% error.
%With 100 epochs you'll get around 1.2% error
rand('state',0)
cnn.layers = { %%% 设置各层feature maps个数及卷积模板大小等属性
struct('type', 'i') %input layer
struct('type', 'c', 'outputmaps', 6, 'kernelsize', 5) %convolution layer
struct('type', 's', 'scale', 2) %sub sampling layer
struct('type', 'c', 'outputmaps', 12, 'kernelsize', 5) %convolution layer
struct('type', 's', 'scale', 2) %subsampling layer
};
opts.alpha = 0.01; %迭代下降的速率
opts.batchsize = 50; %每次选择50个样本进行更新 随机梯度下降,每次只选用50个样本进行更新
opts.numepochs = 50; %迭代次数
cnn = cnnsetup(cnn, train_x, train_y); %对各层参数进行初始化 包括权重和偏置
cnn = cnntrain(cnn, train_x, train_y, opts); %训练的过程,包括bp算法及迭代过程
[er, bad] = cnntest(cnn, test_x, test_y);
%plot mean squared error
figure; plot(cnn.rL);
% assert(er<0.12, 'Too big error');
cnnsetup.m
function net = cnnsetup(net, x, y)
% assert(~isOctave() || compare_versions(OCTAVE_VERSION, '3.8.0', '>='), ['Octave 3.8.0 or greater is required for CNNs as there is a bug in convolution in previous versions. See http://savannah.gnu.org/bugs/?39314. Your version is ' myOctaveVersion]);
inputmaps = 1; %输入图片数量 输入feature maps数量
mapsize = siz