【源码】MATLAB深度学习实战DEMO

深度学习DEMO提供了三个实现目标识别的卷积神经网络CNN示例。三个例子分别为:

  1. 从零开始学习如何建立CNN;
    
  2. 使用已经训练过的模型(迁移学习);
    
  3. 用于特征提取的神经网络训练。
    

每个DEMO都有对应的视频讲解,视频地址为:https://www.mathworks.com/videos/series/deep-learning-with-MATLAB.html

运行以上示例需要安装MATLAB自带的GPU和并行计算工具箱,DEMO 3还需要安装统计与机器学习工具箱。

下面简单介绍DEMO 1:从零开始学习如何建立CNN。

  1. 运行DownloadCIFAR10.m文件,下载DEMO运行所需要的数据。
    
  2. 执行以下代码将训练数据导入MATLAB;
    

%Please note: these are 4 of the 10 categories available

%Feel free to choose which ever you like best!

categories= {‘Deer’,‘Dog’,‘Frog’,‘Cat’};

rootFolder= ‘cifar10Train’;

imds= imageDatastore(fullfile(rootFolder, categories), …

'LabelSource', 'foldernames');
  1. 定义CNN的各层网络,这里可以根据自己的需要调整参数,下面的代码只是一个示例。
    

varSize= 32;

conv1= convolution2dLayer(5,varSize,‘Padding’,2,‘BiasLearnRateFactor’,2);

conv1.Weights= gpuArray(single(randn([5 5 3 varSize])*0.0001));

fc1= fullyConnectedLayer(64,‘BiasLearnRateFactor’,2);

fc1.Weights= gpuArray(single(randn([64 576])*0.1));

fc2= fullyConnectedLayer(4,‘BiasLearnRateFactor’,2);

fc2.Weights= gpuArray(single(randn([4 64])*0.1));

layers= [

imageInputLayer([varSize varSize 3]);

conv1;

maxPooling2dLayer(3,'Stride',2);

reluLayer();

convolution2dLayer(5,32,'Padding',2,'BiasLearnRateFactor',2);

reluLayer();

averagePooling2dLayer(3,'Stride',2);

convolution2dLayer(5,64,'Padding',2,'BiasLearnRateFactor',2);

reluLayer();

averagePooling2dLayer(3,'Stride',2);

fc1;

reluLayer();

fc2;

softmaxLayer()

classificationLayer()];
  1. 设置CNN的训练选项,这些参数设置会严重影响CNN的工作性能,在设置之前应当准确理解这些参数的物理意义。
    

opts= trainingOptions(‘sgdm’, …

'InitialLearnRate', 0.001, ...

'LearnRateSchedule', 'piecewise', ...

'LearnRateDropFactor', 0.1, ...

'LearnRateDropPeriod', 8, ...

'L2Regularization', 0.004, ...

'MaxEpochs', 10, ...

'MiniBatchSize', 100, ...

'Verbose', true);
  1. 开始训练CNN,训练时间长短与具体的硬件设备相关,一般会花费数分钟或以上。
    

[net, info] =trainNetwork(imds, layers, opts);

Training on singleGPU.

Initializing imagenormalization.

|=========================================================================================|

| Epoch | Iteration | Time Elapsed | Mini-batch | Mini-batch | Base Learning|

| | | (seconds) | Loss | Accuracy | Rate |

|=========================================================================================|

| 1 | 1 | 0.25 | 1.3862 | 24.00% | 0.0010 |

| 1 | 50 | 1.86 | 1.2571 | 39.00% | 0.0010 |

| 1 | 100 | 3.35 | 1.2376 | 39.00% | 0.0010 |

| 1 | 150 | 4.90 | 1.1451 | 50.00% | 0.0010 |

| 1 | 200 | 6.39 | 1.0797 | 59.00% | 0.0010 |

| 2 | 250 | 8.03 | 0.8069 | 69.00% | 0.0010 |

| 2 | 300 | 9.64 | 1.1253 | 51.00% | 0.0010 |

| 2 | 350 | 11.20 | 0.9872 | 59.00% | 0.0010 |

| 2 | 400 | 12.75 | 0.9490 | 59.00% | 0.0010 |

| 3 | 450 | 14.31 | 0.7405 | 70.00% | 0.0010 |

| 3 | 500 | 15.77 | 0.9592 | 59.00% | 0.0010 |

| 3 | 550 | 17.28 | 0.9337 | 61.00% | 0.0010 |

| 3 | 600 | 18.77 | 0.8383 | 65.00% | 0.0010 |

| 4 | 650 | 20.30 | 0.6693 | 71.00% | 0.0010 |

| 4 | 700 | 21.80 | 0.8787 | 63.00% | 0.0010 |

| 4 | 750 | 23.27 | 0.8892 | 63.00% | 0.0010 |

| 4 | 800 | 24.76 | 0.7295 | 69.00% | 0.0010 |

| 5 | 850 | 26.28 | 0.6321 | 72.00% | 0.0010 |

| 5 | 900 | 27.77 | 0.8034 | 71.00% | 0.0010 |

| 5 | 950 | 29.26 | 0.8285 | 68.00% | 0.0010 |

| 5 | 1000 | 30.75 | 0.6893 | 69.00% | 0.0010 |

| 6 | 1050 | 32.27 | 0.5741 | 76.00% | 0.0010 |

| 6 | 1100 | 33.74 | 0.7280 | 73.00% | 0.0010 |

| 6 | 1150 | 35.20 | 0.8312 | 68.00% | 0.0010 |

| 6 | 1200 | 36.69 | 0.5876 | 77.00% | 0.0010 |

| 7 | 1250 | 38.25 | 0.5598 | 75.00% | 0.0010 |

| 7 | 1300 | 39.80 | 0.6704 | 77.00% | 0.0010 |

| 7 | 1350 | 41.37 | 0.7792 | 68.00% | 0.0010 |

| 7 | 1400 | 42.87 | 0.5495 | 78.00% | 0.0010 |

| 8 | 1450 | 44.40 | 0.5561 | 79.00% | 0.0010 |

| 8 | 1500 | 45.89 | 0.6032 | 81.00% | 0.0010 |

| 8 | 1550 | 47.39 | 0.7548 | 68.00% | 0.0010 |

| 8 | 1600 | 48.90 | 0.5371 | 78.00% | 0.0010 |

| 9 | 1650 | 50.49 | 0.5247 | 80.00% | 0.0001 |

| 9 | 1700 | 52.02 | 0.5989 | 79.00% | 0.0001 |

| 9 | 1750 | 53.60 | 0.6982 | 72.00% | 0.0001 |

| 9 | 1800 | 55.17 | 0.4448 | 78.00% | 0.0001 |

| 10 | 1850 | 56.71 | 0.4927 | 79.00% | 0.0001 |

| 10 | 1900 | 58.23 | 0.5630 | 80.00% | 0.0001 |

| 10 | 1950 | 59.71 | 0.6843 | 73.00% | 0.0001 |

| 10 | 2000 | 61.18 | 0.4486 | 79.00% | 0.0001 |

|=========================================================================================|

  1. 将测试验证数据导入MATLAB。
    

rootFolder= ‘cifar10Test’;

imds_test= imageDatastore(fullfile(rootFolder, categories), …

'LabelSource', 'foldernames');
  1. 测试结果输出,通过随机读取一幅图片进行分类测试,如果图片的标题为绿色,则预测结果正确;如果为红色,则预测结果错误。
    

labels= classify(net, imds_test);

ii= randi(4000);

im= imread(imds_test.Files{ii});

imshow(im);

iflabels(ii) ==imds_test.Labels(ii)

colorText = ‘g’;

else

colorText = 'r';

end

title(char(labels(ii)),‘Color’,colorText);

DEMO下载地址:

http://page2.dfpan.com/fs/9lc2j2821f29b1676d7/

更多精彩文章请关注微信号:在这里插入图片描述

  • 2
    点赞
  • 29
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值