MATLAB如何使用CIFAR-10数据集训练神经网络

       

目录

下载CIFAR-10数据集

对数据集进行处理

设计网络结构

基础识别网络的设计

基础网络的训练结果和改进

改进网络的结果分析

        关于CIFAR数据集网上已经有很多使用的教程,MATLAB官方也给出了一个示例,但是因为课程要求我们不能使用网上的示例,并且要自己分析网络结构对于训练结果的影响,所以我写了这篇文章,记录一下自己摸索着使用CIFAR-10数据集的过程。

下载CIFAR-10数据集

        点击下面这个链接,找到下图里蓝色字样的下载链接就可以下载CIFAR-10的数据集啦。按照需求下载即可,我需要在MATLAB里处理这些数据,因此下载的是MATLAB version。

        CIFAR-10 and CIFAR-100 datasets (toronto.edu)

1168c74f2f88463dbed0d1cc762bdc2c.png

对数据集进行处理

        由于我的训练网络是在MATLAB提供的示例基础上进行修改得来的,而该网络读取的是.jpg或,png格式的文件,但CIFAR-10提供的是.mat格式的文件,因此需要先对解压后的CIFAR-10文件进行处理。

        首先在MATLAB的根目录下创建名为“CIFAR-10-batches-mat”的文件夹,并在该文件夹下创建名为“test”和“train”的文件夹,在“test”和“train”文件夹里分别创建名为“0”-“9”的文件夹。等会我们要把数据集里的图片存到这里,训练集的十类数据按类放进“train”文件夹下的“0”-“9”的文件夹里,同样,测试集的十类数据放进“test”文件夹下的“0”-“9”的文件夹里。(选择其他位置也可以,只是放在MATLAB根目录下调用起来比较方便)

0e779fc835334ebf95424c85e240a377.png

bd44624663194dd1aa85e4f7dcdf660a.png

e09b15f8e8044308b478fa7013afdd1b.png

        之后就可以使用MATLAB把数据集里的图片按标签放在这些文件夹里啦,具体的代码如下:

%% 从test_batch中提取图片的代码
load(['E:\MATLAB\cifar-10-batches-mat\test_batch.mat'])%这里的路径需要换成自己的.mat文件的路径
for i=1:size(data,1)
    p=data(i,:);
    label=labels(i);
    fig=zeros(32,32,3);
    fig(:,:,1)=reshape(p(1:1024),32,32)';
    fig(:,:,2)=reshape(p(1025:2048),32,32)';
    fig(:,:,3)=reshape(p(2049:end),32,32)';
    imwrite(fig/256,['E:\MATLAB\cifar-10-batches-mat\test\',num2str(label), '\_label_' num2str(label) '_' num2str(i) ,'.png'])%这里的路径换成你需要保存的路径
end
%% 从data batch 1-5中提取图片的代码
for j =1:5
load(['E:\MATLAB\cifar-10-batches-mat\data_batch_',num2str(j),'.mat'])%这里改成data_batch所在的路径
    for i=1:size(data,1)
        p=data(i,:);
        label=labels(i);
        fig=zeros(32,32,3);
        fig(:,:,1)=reshape(p(1:1024),32,32)';
        fig(:,:,2)=reshape(p(1025:2048),32,32)';
        fig(:,:,3)=reshape(p(2049:end),32,32)';
        imwrite(fig/256,['E:\MATLAB\cifar-10-batches-mat\train\',num2str(label), '\_label_'  ,num2str(label) ,'_', num2str(j),'_', num2str(i) ,'.png'])%这里改成想要保存的路径
    end
end

        如果将数据集解压到MATLAB根目录下,并按照上文的要求在MATLAB根目录下创建了文件夹,那么代码中注释的部分就只需要修改前面MATLAB的路径即可。

        执行完代码后你会发现,“test”和“train”的每个子文件夹下图片都已经按照它的类别分好了。比如“0”文件夹,其中就只包含标签为“airplane”的图片。

df5f77af9235439ea9cd570a49a5cba8.png

设计网络结构

基础识别网络的设计

        首先向MATLAB中导入这些图片和对应的标签,代码如下:

digitDatasetPath = fullfile(matlabroot,'cifar-10-batches-mat','train');
imds = imageDatastore(digitDatasetPath, ...
    'IncludeSubfolders',true,'LabelSource','foldernames');
YTrain=imds.Labels;
digitTestDatasetPath = fullfile(matlabroot,'cifar-10-batches-mat','test');
test_imds = imageDatastore(digitDatasetPath, ...
    'IncludeSubfolders',true,'LabelSource','foldernames');

        其中“matlabroot”变量代表MATLAB的根目录位置。

        之后设置网络输入参数,因为CIFAR数据集提供的是eq?32%20%5Ctimes%2032%20%5Ctimes%203的rgb图片,因此设置图片输入层参数为32 32 3。同时为了防止网络对图片的过拟合,需要对图片进行强化,具体来说就是对图片设置随机反转、随机X轴或Y轴平移等。代码如下:

imageSize = [32 32 3];
pixelRange = [-4 4];
imageAugmenter = imageDataAugmenter( ...
    'RandXReflection',true, ...
    'RandXTranslation',pixelRange, ...
    'RandYTranslation',pixelRange);
augimdsTrain = augmentedImageDatastore(imageSize,XTrain,YTrain, ...
    'DataAugmentation',imageAugmenter, ...
    'OutputSizeMode','randcrop');

        网络的设置了三层卷积层,相邻卷积层之间通过池化层连接,最后连接一个全连接层。这样一个简单的网络就设计好啦。训练轮次MaxEpochs设置为20轮,验证集ValidationData设置为刚导入的测试集图片“test_imds”,验证频率ValidationFrequency设置为30,即可开始训练。完整代码如下

%% 导入训练集数据和验证集数据
digitDatasetPath = fullfile(matlabroot,'cifar-10-batches-mat','train');
imds = imageDatastore(digitDatasetPath, ...
    'IncludeSubfolders',true,'LabelSource','foldernames');
YTrain=imds.Labels;
digitTestDatasetPath = fullfile(matlabroot,'cifar-10-batches-mat','test');
test_imds = imageDatastore(digitDatasetPath, ...
    'IncludeSubfolders',true,'LabelSource','foldernames');

% figure;
% perm = randperm(10000,20);
% fprint("B7744F63DB749E48C05C06F7F3BD11D4")
% for i = 1:20
%     subplot(4,5,i);
%     imshow(imds.Files{perm(i)});
% end
labelCount = countEachLabel(imds)

img = readimage(imds,1);
for i =1:50000
    XTrain(:,:,:,i)=readimage(imds,i);
end
size(img)
%% 设置训练参数
imageSize = [32 32 3];
pixelRange = [-4 4];
imageAugmenter = imageDataAugmenter( ...
    'RandXReflection',true, ...
    'RandXTranslation',pixelRange, ...
    'RandYTranslation',pixelRange);
augimdsTrain = augmentedImageDatastore(imageSize,XTrain,YTrain, ...
    'DataAugmentation',imageAugmenter, ...
    'OutputSizeMode','randcrop');
%% 设置网络结构

layers = [
    imageInputLayer([32 32 3])
    
    convolution2dLayer(5,16,'Padding','same')
    batchNormalizationLayer
    reluLayer
    
    maxPooling2dLayer(2,'Stride',2)
    
    convolution2dLayer(5,32,'Padding','same')
    batchNormalizationLayer
    reluLayer
    
    maxPooling2dLayer(2,'Stride',2)
    
    convolution2dLayer(5,64,'Padding','same')
    batchNormalizationLayer
    reluLayer
    
    
    fullyConnectedLayer(10)
    softmaxLayer
    classificationLayer];
%% 设置训练选项
options = trainingOptions('sgdm', ...
    'InitialLearnRate',0.01, ...
    'MaxEpochs',20, ...
    'Shuffle','every-epoch', ...
    'ValidationData',test_imds, ...
    'ValidationFrequency',30, ...
    'Verbose',false, ...
    'Plots','training-progress','ExecutionEnvironment','gpu');
net = trainNetwork(augimdsTrain,layers,options);
%% 对有效识别图片进行分类并计算准确率
YPred = classify(net,test_imds);
YValidation = test_imds.Labels;
accuracy = sum(YPred == YValidation)/numel(YValidation)

        把上面的代码粘到MATLAB里直接运行即可。训练时间大概为30min(因硬件不同而异),训练之后就可以得到一个能够识别这十类图片的网络了。

基础网络的训练结果和改进

        根据上文的代码,即可训练出一个网络。但是我训练完这个网络之后发现一个问题,那就是该网络对“猫”和“狗”的识别效果很差。
 

93087cf271c944d0a0bccecd7f82518c.png

    

fc2c736a30e84e2e83222476ad5aca0e.png

        可以看到,在混淆矩阵中有接近一半的“猫”图片被识别为“狗”(3代表“猫”类,5代表“狗”类)。让神经网络识别在网上找的猫猫图片,同样会把猫识别成狗。说明网络对于猫和狗的细节部分还是没有把握住。因此可以考虑增加卷积层的数量,同时增加训练的迭代轮数,强化网络对猫和狗的不同特征的学习。代码如下:

%% 导入训练集数据和验证集数据
digitDatasetPath = fullfile(matlabroot,'cifar-10-batches-mat','train');
imds = imageDatastore(digitDatasetPath, ...
    'IncludeSubfolders',true,'LabelSource','foldernames');
YTrain=imds.Labels;
digitTestDatasetPath = fullfile(matlabroot,'cifar-10-batches-mat','test');
test_imds = imageDatastore(digitDatasetPath, ...
    'IncludeSubfolders',true,'LabelSource','foldernames');

% figure;
% perm = randperm(10000,20);
% fprint("B7744F63DB749E48C05C06F7F3BD11D4")
% for i = 1:20
%     subplot(4,5,i);
%     imshow(imds.Files{perm(i)});
% end
labelCount = countEachLabel(imds)

img = readimage(imds,1);
for i =1:50000
    XTrain(:,:,:,i)=readimage(imds,i);
end
size(img)
%% 设置训练参数
imageSize = [32 32 3];
pixelRange = [-4 4];
imageAugmenter = imageDataAugmenter( ...
    'RandXReflection',true, ...
    'RandXTranslation',pixelRange, ...
    'RandYTranslation',pixelRange);
augimdsTrain = augmentedImageDatastore(imageSize,XTrain,YTrain, ...
    'DataAugmentation',imageAugmenter, ...
    'OutputSizeMode','randcrop');
%% 设置网络结构

layers = [
    imageInputLayer([32 32 3])
    
    convolution2dLayer(5,32,'Padding','same')
    batchNormalizationLayer
    leakyReluLayer
    
    maxPooling2dLayer(2,'Stride',2)
    
    convolution2dLayer(5,64,'Padding','same')
    batchNormalizationLayer
    leakyReluLayer
    
    maxPooling2dLayer(2,'Stride',2)
    
    convolution2dLayer(5,128,'Padding','same')
    batchNormalizationLayer
    leakyReluLayer
    
    maxPooling2dLayer(2,'Stride',2)
    
    convolution2dLayer(5,256,'Padding','same')
    batchNormalizationLayer
    leakyReluLayer
    
    maxPooling2dLayer(2,'Stride',2)
    
    convolution2dLayer(5,512,'Padding','same')
    batchNormalizationLayer
    leakyReluLayer
    
    fullyConnectedLayer(10)
    softmaxLayer
    classificationLayer];
%% 设置训练选项
options = trainingOptions('sgdm', ...
    'InitialLearnRate',0.01, ...
    'MaxEpochs',50, ...
    'Shuffle','every-epoch', ...
    'ValidationData',test_imds, ...
    'ValidationFrequency',200, ...
    'Verbose',false, ...
    'Plots','training-progress','ExecutionEnvironment','gpu');
net = trainNetwork(augimdsTrain,layers,options);
%% 对有效识别图片进行分类并计算准确率
YPred = classify(net,test_imds);
YValidation = test_imds.Labels;
accuracy = sum(YPred == YValidation)/numel(YValidation)

        训练轮次调整为了50,并且网络又增加了2层卷积层,因此训练起来会更费时间,大概在50min左右。

改进网络的结果分析

        改进以后的网络虽然训练时间更长了,但是识别结果也更准了,对“猫”类的识别正确率由原来的57.5%一下子提高到92.8%,整体识别率为95.69%,提高了很多。原本识别错误的猫猫图片现在也能识别正确了。

0195dda36ac74e4b88727871130afc78.png

7205034df15249e980a1b8192f02433f.png

训练好的网络对图片进行分类的代码

myImage1=imread('图片位置');
figure,imshow(myImage1)
myImagex=imresize(myImage1,[32,32]);
YPred1=classify(net,myImagex);
str_YPred1=string(YPred1);
num_YPred=double(str_YPred1);  
switch num_YPred
    case 0
        fprintf('This is a airplane.\n')
        title('This is a airplane.')
    case 1
        fprintf('This is a automobile.\n')
        title('This is a automobile.')
    case 2
        fprintf('This is a bird.\n')
        title('This is a bird.')
    case 3
        fprintf('This is a cat.\n')
        title('This is a cat.')
    case 4
        fprintf('This is a deer.\n')
        title('This is a deer.')
    case 5
        fprintf('This is a dog.\n')
        title('This is a dog.')
    case 6
        fprintf('This is a frog.\n')
        title('This is a frog.')
    case 7
        fprintf('This is a horse.\n')
        title('This is a horse.')
    case 8
        fprintf('This is a ship.\n')
        title('This is a ship.')
    case 9
        fprintf('This is a truck.\n')
        title('This is a truck.')
end

        绘制混淆矩阵图的代码这里就不放啦,有需要可以私信我。

 

 

 

 

### 回答1: 要下载CIFAR-10数据集MATLAB中,可以按照以下步骤操作: 1. 首先,打开MATLAB,并确保已连接到互联网。 2. 在MATLAB命令窗口中输入以下命令: ```matlab websave('cifar-10-data.mat','https://www.cs.toronto.edu/~kriz/cifar-10-matlab.tar.gz') ``` 这个命令将使用MATLAB的`websave`函数从CIFAR-10数据集的官方网站下载压缩文件,并将其保存为`cifar-10-data.mat`文件。 3. 下载完成后,解压缩刚刚下载的文件。可以使用以下命令: ```matlab untar('cifar-10-data.mat') ``` 这个命令将解压缩刚刚下载的文件。 4. 解压缩后,可以在MATLAB使用加载函数`load`加载CIFAR-10数据集使用以下命令: ```matlab load('cifar-10-batches-mat/data_batch_1.mat') ``` 这个命令将加载CIFAR-10数据集的第一个批次,可以根据需要加载其他批次的数据。 5. 加载后的数据将被存储在一个MATLAB结构体变量中,可以根据需要访问不同的字段来获取图像和标签数据。 以上就是在MATLAB下载CIFAR-10数据集的步骤。下载完成后,你就可以使用这些数据来进行图像分类、目标识别等机器学习任务。 ### 回答2: 要下载CIFAR-10数据集,您可以按照以下步骤使用MATLAB进行操作。 首先,您需要访问CIFAR-10数据集的官方网站(https://www.cs.toronto.edu/~kriz/cifar.html)以获取数据集下载链接。 接下来,在MATLAB的命令行窗口中使用"web"函数打开CIFAR-10数据集的网页。例如,输入以下命令并按Enter键: web('https://www.cs.toronto.edu/~kriz/cifar.html','-browser') 然后,您将看到网页加载在MATLAB的浏览器中。 在网页中,您可以找到"CIFAR-10 binary version (suitable for C programs)"这个选项,该选项包含了CIFAR-10数据集下载链接。点击链接以下载数据集下载完成后,您可以将数据集解压缩到您选择的文件夹中。建议您将数据集保存在一个清晰和易于访问的位置。 在MATLAB中,您可以使用"load"函数加载下载数据集文件。例如,假设您将数据集保存为"CIFAR-10"文件夹,您可以使用以下命令读取数据集: load(fullfile('CIFAR-10', 'data_batch_1.mat')) 这将加载数据集中的第一个数据批次到MATLAB的工作空间中,您可以使用MATLAB的各种功能和工具来进一步处理和分析数据。 总结起来,要在MATLAB下载CIFAR-10数据集,请访问官方网站获取下载链接,使用MATLAB的"web"函数打开网页并下载数据集,然后使用"load"函数加载数据集文件到MATLAB。 ### 回答3: 要在MATLAB下载CIFAR-10数据集,可以按照以下步骤进行: 1. 首先需要在MATLAB中创建一个文件夹用于保存CIFAR-10数据集。可以使用以下代码创建一个名为"CIFAR-10"的文件夹: ```matlab mkdir('CIFAR-10'); ``` 2. 使用wget命令下载CIFAR-10数据集的压缩文件。可以使用以下代码在MATLAB命令窗口中运行wget命令: ```matlab !wget https://www.cs.toronto.edu/~kriz/cifar-10-matlab.tar.gz ``` 下载完成后,会在当前文件夹中生成一个名为"cifar-10-matlab.tar.gz"的压缩文件。 3. 使用untar命令解压缩下载的压缩文件。可以使用以下代码在MATLAB命令窗口中运行untar命令: ```matlab untar('cifar-10-matlab.tar.gz', 'CIFAR-10'); ``` 解压缩完成后,CIFAR-10数据集的.mat文件将会存储在"CIFAR-10"文件夹中。 4. 现在可以在MATLAB中加载CIFAR-10数据集并进行数据分析、处理和训练模型等操作。可以使用以下代码加载CIFAR-10数据集: ```matlab load('CIFAR-10/cifar-10-batches-mat/data_batch_1.mat'); ``` 加载数据集后,数据集的相关变量将会在MATLAB的工作空间中生成,可以使用这些变量进行进一步的数据处理和分析。 以上是在MATLAB下载CIFAR-10数据集的简单步骤。确保在下载和解压缩过程中的网络连接正常,并提前安装好wget和untar命令。
评论 22
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值