[深度学习] - 通用网络模型训练过程的 loss 变化分析方法 (loss / val_loss / test_loss)

目录

一、train set 和 test set 基础知识

二、分析 loss 和 val_loss (test_loss) 变化情况


一、train set 和 test set 基础知识

  • train set:训练集是用来训练网络模型的数据集。
  • test set:测试集用来评估网络性能的数据集。默认测试集是不参与网络训练,仅用来测试网络性能。
  • 附:development set(也可称为 validation set),主要用来二次微调整、选择特征以及对学习算法作出其他优化的数据集。常规的训练集和测试集的比例为 0.7 : 0.3,引入验证集后常采用 0.6 : 0.2 : 0.2(现在使用较少)。在保证算法稳定的情况下,训练集和测试集的选取会对输出结果的指标造成不同层次的影响(主要还是数据集的分布特性影响)。

二、分析 loss 和 val_loss (test_loss) 变化情况

  • 通常回调显示的 loss 有很多种,如一个总 total_loss 多个子 sub_loss 。但本文主要分析最基础的训练情况(只有一个训练 loss,和一个验证 loss)。下文用 loss 代表训练集的损失值(墨守成规不写成 train_loss);val_loss 代表验证集的损失值(也写成 test_loss)。
  • 一般训练规律:
loss

val_loss

网络情况
下降下降
  • 网络训练正常,最理想情况情况。
下降稳定/上升
  • 网络过拟合。解决办法:
  • ①降低网络性能:在数据集没问题的前提下,向网络某些层的位置添加 Dropout 层(通常会选择较深的层,如一共 100 层,选择在 75 层;或者选择特征最多的层,如 Unet 的最底层等等);或者逐渐减少网络的深度(靠经验删除部分模块)。
  • ②修改数据集:数据集有明显错误则需要重做,问题小可尝试混洗数据集并重新分配,通常开源数据集不容易出现这种情况。
稳定下降
  • 数据集有严重问题,建议重新选择。一般不会出现这种情况。
快速稳定快速稳定
  • 如果数据集规模不小的情况下,代表学习过程遇到瓶颈,需要减小学习率(自适应动量优化器小范围修改的效果不明显)。
  • 其次考虑修改 batchsize 大小。
  • 如果数据集很规模很小的话代表训练稳定。
上升上升
  • 可能是网络结构设计问题、训练超参数设置不当、数据集需要清洗等问题。
  • 这种情况属于训练过程中最差情况,得一个一个排除问题。
注意:上面提到的“下降”、“稳定”和“上升”是指整体训练趋势。

第一次补充:

  • 一系列的损失值能反应网络在训练过程的“健康程度”(当某一刻损失发生突变,并且一直没恢复,那就说明不健康了)。
  • 权重文件可以每轮都保存,但是个人推荐选择最优评价指标保存。
  • 上面总结仅是广义规律,是用于辅助学习而总结,不保证适用于所有网络。
  • 当网络训练跑偏了后,先检查自己写的网络是否存在小错误。然后检查数据集是否有问题(如果是开源数据集,可以先分个小样本训练来检测网络是否设计问题。如果是自定义数据集,可用一些能处理通用任务的网络(如 Unet 或 ResNet)跑一下来检测)

第二次补充:

  • 为什么说现在有些网络不怎么使用“训练集(0.6)、验证集(0.2)、测试集(0.2)”的划分:其实在大规模的数据集中尤其是开源数据集,很多图像相关特征或者语义分布是集中的。理论上,验证集能一定程度的提高网络学习的全局最优,用来调整网络的学习率优化。但是当训练集和验证的特征分布充分接近时,验证的调整效果就会越来越不明显,所以甚至大家会在开源代码中发现还有使用“训练集(0.9)、测试集(0.1)”的划分。当然不是说验证集就没什么用,个人经验表示:可以向验证集添加多种 data augmentation 来优化训练(如:训练的 size:256x256,验证的 size:512*512。亦或者训练集的 data augmentation 和验证集不一样)。当然这些东西还需要大家去“炼丹”对比。

第三次补充:

  • 很多朋友都在问各种 loss 不正常的原由,其实在正文里面就写了:loss的变化情况只能得出一般规律。光凭 loss 变化没法得出一个确切的问题结论。文中提出的规律方便大家网络调参。一旦网络跑飞了需要结合 loss、评价指标、输出信息(或图像)等一系列结果综合分析。而本文不是提供错误分析的,对应的错误分析需要大家在互联网上搜索、找相关领域的专业人员提问。同时第二次补充的策略虽然有提高网络泛化的作用,但是在高性能网络中需要合理调参,要避免过拟合现象(吐槽:当网络规模远大于数据集规模时,是很容易过拟合的)。
在PyTorch中,`validate()` 函数通常用于评估模型在验证集上的性能。这里给出了`validate()`函数可能的调用方式以及涉及的关键参数: 1. `test_acc`: 测试准确率[^1]。这通常是通过计算预测类别与实际标签相符的情况来确定的。在每个批次的验证数据上,模型会应用`model(x)`得到预测结果,然后与`criterion`计算损失。最后,根据正确的分类数量除以总样本数,得出整个验证集的准确率。 ```python with torch.no_grad(): # 关闭梯度计算以节省内存 correct = 0 total = 0 for images, labels in val_loader: outputs = model(images) _, predicted = torch.max(outputs.data, 1) total += labels.size(0) correct += (predicted == labels).sum().item() test_acc = correct / total ``` 2. `test_acc_top5`: 测试前五名准确率。这表示对于每个样本,如果模型前五个预测中最少有一个与真实标签相同,则认为该样本被正确预测。实现方法可能包括对输出概率最高的前五个类进行判断。 ```python top5_correct = 0 for images, labels in val_loader: ... top5_pred = torch.topk(outputs, 5, dim=1)[1] top5_correct += (top5_pred[:, :labels.size(1)] == labels.unsqueeze(1)).any(dim=1).sum().item() test_acc_top5 = top5_correct / total ``` 3. `test_loss`: 测试损失。这是通过`criterion`函数计算模型输出与真实标签之间的差异,如交叉熵损失。在每个批次之后累加损失,最后除以总批次数得到平均损失。 ```python test_loss = 0. for images, labels in val_loader: ... loss = criterion(outputs, labels) test_loss += loss.item() test_loss /= len(val_loader) ``` 完整代码示例: ```python def validate(val_loader, model, criterion, opt=None): test_loss = 0. correct = 0 top5_correct = 0 with torch.no_grad(): for images, labels in val_loader: outputs = model(images) loss = criterion(outputs, labels) _, predicted = torch.max(outputs.data, 1) correct += (predicted == labels).sum().item() top5_pred = torch.topk(outputs, 5, dim=1)[1] top5_correct += (top5_pred[:, :labels.size(1)] == labels.unsqueeze(1)).any(dim=1).sum().item() test_loss += loss.item() test_loss /= len(val_loader) test_acc = correct / total test_acc_top5 = top5_correct / total if opt and opt.verbose: print(f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc * 100:.2f}%, Top5 Acc: {test_acc_top5 * 100:.2f}%") return test_loss, test_acc, test_acc_top5 ```
评论 29
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值