matlab cnntest,MATLAB R2017b: Deep Learning with CNN

今天读了这篇文章后发现MATLAB的Deep learning原来可以这么简单,有点像Keras,封装的比较好。想当初刚接触tensor flow的时候真的有点头大。

知乎那篇文章中只介绍了CPU的版本,正好手头有块老旧的GPU,拿来试试。

首先去这里下载数据

解压

tar xzvf notMNIST_large.tar.gz

按照教程说的一步步做,最后trainingOption那里改成GPU就好

% Load data

ds = imageDatastore('notMNIST_large/','LabelSource','foldernames','IncludeSubfolders',true);

% Prepare data

[trainDigitData,valDigitData,testData]=ds.splitEachLabel(0.5,0.3,0.2,'Randomize');

% Define network layers

layers = [...

imageInputLayer([28,28,1]);

batchNormalizationLayer();

convolution2dLayer(5,20);

batchNormalizationLayer();

reluLayer()

maxPooling2dLayer(2,'Stride',2);

fullyConnectedLayer(10);

softmaxLayer();

classificationLayer(),...

];

% Customize training option

options = trainingOptions('sgdm',...

'ValidationData',valDigitData,...

'Plots','training-progress',...

'ExecutionEnvironment','gpu');

% Train

net = trainNetwork(trainDigitData,layers,options);

% Test

testLabel = classify(net,testData);

precision = sum(testLabel==testData.Labels)/numel(testLabel)

第一次实验,发现速度非常快

b30743359798

jianshunotMnist1.png

但是默认的validation频率太高了,导致很多时间都花费在了数据与GPU通讯上面,8分半跑了2000多个iteration(250循环/分钟)

于是减慢validation频率至每一千个循环验证一次

options = trainingOptions('sgdm',...

'ValidationData',valDigitData,...

'ValidationFrequency',1000,...

'Plots','training-progress',...

'ExecutionEnvironment','gpu');

b30743359798

jianshunotMnist2.png

可以看到13分钟跑了22000多次循环(1700循环/分钟),可谓效率大大提升。

下面问题就来了,该用这个做什么呢。。。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值