深度学习的模型是怎么训练/优化出来的

以典型的分类问题为例,来梳理模型的训练过程。训练的过程就是问题发现的过程,一次训练是为下一步迭代做好指引。

1.数据准备

准备:

  • 数据标注前的标签体系设定要合理
  • 用于标注的数据集需要无偏、全面、尽可能均衡
  • 标注过程要审核

整理数据集

  1. 将各个标签的数据放于不同的文件夹中,并统计各个标签的数目
    如:第一列是路径,最后一列是图片数目。
    0077idFnly1fyufdvsiexj30ht0430tm.jpg
    PS:可能会存在某些标签样本很少/多,记下来模型效果不好就怨它。
  2. 样本均衡,样本不会绝对均衡,差不多就行了
    如:控制最大类/最小类<\(\delta\)\(\delta=5\),最后一列为均衡的目标值。
    0077idFnly1fyufi7zpbpj30h80410tn.jpg
  3. 切分样本集
    如:90%用于训练,10%留着测试,比例自己定。训练集合,对于弱势类要重采样,最后的图片列表要shuffle;测试集合就不用重采样了。
    训练中要保证样本均衡,学习到弱势类的特征,测试过程要反应真实的数据集分布。
    第一列是图片路径,后面几列是标签(多任务)。
    0077idFnly1fyungcdj47j30lk02a3yh.jpg
    0077idFnly1fyungf6sjdj30lt05l76y.jpg

  4. 按需要的格式生成tfrecord
    按照train.list和validation.list生成需要的格式。生成和解析tfrecord的代码要根据具体情况编写。

2.训练

  • 预处理,根据自己的喜好,编写预处理策略。
    preprocessing的方法,变换方案诸如:随机裁剪、随机变换框、添加光照饱和度、修改压缩系数、各种缩放方案、多尺度等。进而,减均值除方差或归一化到[-1,1],将float类型的Tensor送入网络。
    这一步的目的是:让网络接受的训练样本尽可能多样,不要最后出现原图没问题,改改分辨率或宽高比就跪了的情况。
  • 网络设计,基础网络的选择和Loss的设计。
    基础网络的选择和问题的复杂程度息息相关,用ResNet18可以解决的没必要用101;还有一些SE、GN等模块加上去有没有提升也可以去尝试。
    Loss的设计,一般问题的抽象就是设计Loss数据公式的过程。比如多任务中的各个任务权重配比,centorLoss可以让特征分布更紧凑,SmoothL1Loss更平滑避免梯度爆炸等。
  • 优化算法
    一般来说,只要时间足够,Adam和SGD+Momentum可以达到的效果差异不大。用框架提供的理论上最好的优化策略就是了。
  • 训练过程
    finetune网络,我习惯分两步:首先训练fc层,迭代几个epoch后保存模型;然后基于得到的模型,训练整个网络,一般迭代40-60个epoch可以得到稳定的结果。
    0077idFnly1fyuo6jmd8pj30bt06ldg2.jpg
    total_loss会一直下降的,过程中可以评测下模型在测试集上的表现。真正的loss往往包括两部分。后面total_loss的下降主要是正则项的功劳了。
    0077idFnly1fyuo6mevu1j30kd06xdgi.jpg

3.评估模型

1.混淆矩阵必不可少
混淆矩阵可以发现哪些类是难区分的。基于混淆矩阵可以得到各类的准召,进而可以得到哪些类比较差。
如:列为真值,行为检测的值。

gt/pl靴子单鞋运动休闲棉鞋雪地靴帆布拖鞋凉鞋雨鞋
靴子4524453979125956020
单鞋514088154411591880436
运动3868172470218810
休闲53471718061781181512
棉鞋121105154245523211
雪地靴5365107362801321
帆布鞋52816158115151734
拖鞋6139112333182316606
凉鞋76936002556331
雨鞋2661301251499

进而可得:

label召回精度
靴子0.94466485696387560.947434554973822
单鞋0.91474602819422690.8996478873239436
运动0.71855760773966580.7614165890027959
休闲0.65105008077544420.5840579710144927
.........

PS:运动-休闲容易混淆。

2.抽样看测试数据
从测试数据中每类抽1000张,把它们的模型结果放在不同的文件夹下。对于分析问题还是很有效的,为什么它会分错,要拿出来看看!
有些确实是人工标错了。
0077idFnly1fyup6ysy8lj30di09jaao.jpg

3.CAM
通过CAM可以查看网络究竟学到了什么(是不是学错了)。对于细粒度问题就不用分析CAM了,一般7x7的特征图本来就很小了,根本就看不出细节学到了什么,只能粗略看看部位定位是否准确。
0077idFnly1fyup8uv56lj3068068gn0.jpg
也可以一定程度上帮助理解为什么网络会搞错,比如下面的单鞋被误判为了拖鞋。
0077idFnly1fyup93p1yvj30680680v3.jpg

  • 1
    点赞
  • 29
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值