Monitor Deep Learning Training Progress

监控深度学习训练进度

在训练深度学习网络时,监控训练进度通常很有用。通过在训练过程中绘制各种指标,您可以了解训练的进度情况。例如,您可以确定网络准确度是否改善以及改善速度,还可以确定网络是否开始过拟合训练数据。

在 trainingOptions 中将 ‘training-progress’ 指定为 ‘Plots’ 值并开始网络训练时,trainNetwork 会创建一个图窗并在每次迭代时显示训练指标。 每次迭代都是对梯度的一次估计和对网络参数的一次更新。如果在 trainingOptions 中指定验证数据,则每次 trainNetwork 验证网络时,该图窗都会显示验证指标。该图窗绘制以下内容:

训练准确度 - 针对每个小批量的分类准确度。

经过平滑处理的训练准确度 - 经过平滑处理的训练准确度,通过将平滑算法应用于训练准确度来获得。它的噪声低于未平滑的准确度,因此更易于揭示趋势。

验证准确度 - 针对整个验证集的分类准确度(使用 trainingOptions 指定)。

训练损失、经过平滑处理的训练损失和验证损失 - 分别指每个小批量的损失、其经过平滑处理的版本以及验证集的损失。如果网络的最终层是一个 classificationLayer,则损失函数是交叉熵损失。有关分类和回归问题的损失函数的详细信息,请参阅输出层。

对于回归网络,该图窗绘制均方根误差 (RMSE) 而不是准确度。

图窗使用交替底色来标记每一轮训练。一轮训练是对整个训练数据集的一次完整遍历。

在训练过程中,您可以通过点击右上角的停止按钮停止训练并返回网络的当前状态。例如,您可能希望在网络准确度达到稳定水平并且准确度明显不再提高时停止训练。点击停止按钮后,可能需要一段时间才能完成训练。训练完成后,trainNetwork 将返回经过训练的网络。

训练结束后,查看结果,其中显示最终验证准确度和训练结束的原因。最终验证指标在绘图中标注为 Final。如果您的网络包含批量归一化层,则最终验证指标可以与训练过程中评估出的验证指标不同。这是因为在训练完成后,用于批量归一化的均值和方差统计量可能会有所不同。例如,如果 ‘BatchNormalizationStatisics’ 训练选项为 ‘population’,则在训练后,软件通过再次使训练数据穿过来完成批量归一化统计,并使用得到的均值和方差。如果 ‘BatchNormalizationStatisics’ 训练选项为 ‘moving’,则软件在训练过程中使用运行估计来逼近统计量,并使用统计量的最新值。

在训练过程中绘制训练进度
训练网络并在训练过程中绘制训练进度。
加载训练数据,其中包含 5000 个数字图像。留出 1000 个图像用于网络验证。

[XTrain,YTrain] = digitTrain4DArrayData;

idx = randperm(size(XTrain,4),1000);
XValidation = XTrain(:,:,:,idx);
XTrain(:,:,:,idx) = [];
YValidation = YTrain(idx);
YTrain(idx) = [];

构建网络以对数字图像数据进行分类。

layers = [
    imageInputLayer([28 28 1])
    
    convolution2dLayer(3,8,'Padding','same')
    batchNormalizationLayer
    reluLayer   
    
    maxPooling2dLayer(2,'Stride',2)
    
    convolution2dLayer(3,16,'Padding','same')
    batchNormalizationLayer
    reluLayer   
    
    maxPooling2dLayer(2,'Stride',2)
    
    convolution2dLayer(3,32,'Padding','same')
    batchNormalizationLayer
    reluLayer   
    
    fullyConnectedLayer(10)
    softmaxLayer
    classificationLayer];

指定网络训练的选项。要在训练过程中按固定时间间隔验证网络,请指定验证数据。选择 ‘ValidationFrequency’ 值,以使网络大致在每轮训练都被验证一次。要在训练过程中绘制训练进度,请将 ‘training-progress’ 指定为 ‘Plots’ 值。

options = trainingOptions('sgdm', ...
    'MaxEpochs',8, ...
    'ValidationData',{XValidation,YValidation}, ...
    'ValidationFrequency',30, ...
    'Verbose',false, ...
    'Plots','training-progress');

训练网络。

net = trainNetwork(XTrain,YTrain,layers,options);

在这里插入图片描述

  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

qq-120

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值