超详细深度学习debug指南,国外小哥手把手教你如何调试模型 | 附PPT

晓查 发自 凹非寺
量子位 出品 | 公众号 QbitAI

已经学会深度学习,但你搭建的模型为什么还跑不动,到底哪里出了问题?

看懂了教材,一到编程调试就跪,为了寻找bug的你是否曾经手足无措?

虽然网络上深度学习的教材很多,但是手把手教你调试的技巧却不常见。

最近有人雪中送炭啦!一位来自伯克利的小哥Josh Robin分享了他的深度学习debug心得,从最简单模型开始一步步深入到复杂模型,希望能给刚上手的你一点帮助。

Josh在读博期间曾被debug折磨得很痛苦,他说自己花了大部分时间调试而不是在“有趣”的事情上。有一次,仅仅因为标签错误,Josh就整整花了一天才排查出来。

构思和写代码可能只花费10%~20%的时间,而debug和调试要消耗掉80%~90%的时间!所以这份秘籍可能比教材更常伴随着你。

为什么你的模型效果这么差?

由于深度学习模型的复杂性,按照书本知识来搭建模型,往往“理想很丰满,现实很骨感”。别人的模型都能快速达到较低的错误率,而你的模型错误率却居高不下。

640?wx_fmt=jpeg

 别人的曲线vs你的曲线

造出这种现象的原因可以分为4大类:

1、模型实现中的bug:比如前面说过的标签错误的问题。

2、超参数选择不合适:模型对超参数很敏感,学习率太高或太低都不行。

640?wx_fmt=png

 合适的学习率才能保证较低的错误率

3、数据模型不适配:比如你要训练一个自动驾驶图像识别的模型,用ImageNet数据集来训练就不合适。

640?wx_fmt=jpeg

 ImageNet中的图片vs自动驾驶汽车拍摄的图片

4、数据集的构造问题:没有足够数据、分类不均衡、有噪声的标签、训练集合测试集分布不同。

深度学习debug的流程策略

针对上面的问题,小哥总结出调试深度学习模型的第一要义——悲观主义

既然消除模型中的错误很难,我们不如先从简单模型入手,然后逐渐增加模型的复杂度。

他把这个过程分为5个步骤:

  1. 从最简单模型入手;

  2. 成功搭建模型,重现结果;

  3. 分解偏差各项,逐步拟合数据;

  4. 用由粗到细随机搜索优化超参数;

  5. 如果欠拟合,就增大模型;如果过拟合,就添加数据或调整。

从简单模型开始

在这一步之前,Josh假定你已经有了初始的测试集、需要改进的单一指标、基于某种标准的模型目标性能。

首先,选择一个简单的架构。比如,你的输入是图片就选择类似LeNet的架构,输入是语言序列就选择有一个隐藏层的LSTM。

640?wx_fmt=jpeg

 多输入模型

为了简化问题,我们从一个只有1万样本的数据集开始训练,数据的特点包括:固定数量的目标、分类、更小的图片尺寸。由此创建一个简单的合成训练集。

开始搭建深度学习模型

在搭建模型之前,Josh总结了实现(Implement)的5种最常见的bug:

错误的张量形状;预处理输入错误;损失函数错误输入;忘记设置正确的训练模型;错误的数据类型。

为了防止这些错误发生,Josh给出的建议是:尽可能减少代码的行数,使用现成的组件,然后再构建复杂的数据pipeline。

运行模型后,你可能会遇到形状不匹配、数据类型错误、内存不足等等问题。

对于第一个问题,可以在调试器中逐步完成模型创建和推理。数据类型错误是由于没有把其他类型数据转化成float32,内存不足是因为张量或者数据集太大。

评估

下面我们开始用错误率评估模型的性能。

测试集错误率 = 错误率下限 + 偏移 + 方差 + 分布偏差 + 验证集过拟合

为了处理训练集和测试集分布的偏差,我们使用两个验证数据集,一个样本来自训练集,一个样本来自测试集。

640?wx_fmt=png

改进模型和数据

上一步中粗略搭建的模型错误率仍然相当高,我们应该如何改进?

让我们先用以下方法解决欠拟合的问题:

让模型更大(比如加入更多的层,每层中使用更多的单元);减少正规化错误分析选择另一种性能更好的模型架构调节超参数加入更多特征

640?wx_fmt=jpeg

 解决欠拟合问题

首先,我们给模型加入更多的层,转换到ResNet-101,调节学习率,使训练集错误率降低到0.8%。

640?wx_fmt=jpeg

 把训练集错误率降低到目标值以内

在出现过拟合后,我们可以增加训练集的样本量解决这个问题,把图片数量扩大到25万张。

640?wx_fmt=jpeg

 解决过拟合问题

经历过优化参数、权重衰减、数据增强等一系列操作后,我们终于把测试错误率降低到目标值。

640?wx_fmt=jpeg

 目标错误率

接下来我们着手解决训练集和测试集的分布偏差问题。

分析测试验证集错误率,收集或者合成更多训练数据弥补二者的偏差。比如下面的自动驾驶目标识别模型,训练完成后,让它判断图片里有没有人,常常发生错误。

640?wx_fmt=jpeg

 分析自动驾驶数据集的分布偏差

经过分析得出,训练集缺乏夜晚场景、反光等情况。后续将在训练集中加入此类数据纠正偏差。

另一种修正错误率的方法称为领域适配,这是一种使用未标记或有限标记数据进行训练的技术。它能在源分布上进行训练,并将其推广到另一个“目标”。

超参数优化

这是调试的最后一步,我们需要选取那些更敏感的超参数,下图是模型对不同超参数的敏感性:

640?wx_fmt=png

 模型对不同超参数的敏感性

常用的超参数优化方法有:手动优化、网格搜索、随机搜索、由粗到细、贝叶斯优化。

640?wx_fmt=png

 由粗到细的随机搜索

你可以手动优化超参数,但是耗时而且需要理解算法的细节。Josh推荐的方法是由粗到细的随机搜索贝叶斯优化

由粗到细的随机搜索可以缩小超高性能参数的范围,缺点是由一些手动的操作。贝叶斯优化是优化超参数最有效一种无需手动干涉的方式,具体操作请参考:

https://towardsdatascience.com/a-conceptual-explanation-of-bayesian-model-based-hyperparameter-optimization-for-machine-learning-b8172278050f

最后附上本文提到的所有资源。

资源汇总

下载地址(需科学前往):

http://josh-tobin.com/assets/pdf/troubleshooting-deep-neural-networks-01-19.pdf

或者在我们的公众号中回复debug获取。

Josh在教程最后推荐了吴恩达的《Machine Learning Yearning》,这本书能帮你诊断机器学习系统中的错误:

http://www.mlyearning.org

另外杨百翰大学也有一篇搭建和调试深度学习模型的博客:

https://pcc.cs.byu.edu/2017/10/02/practical-advice-for-building-deep-neural-networks/


2018中国人工智能最受尊敬投资机构

640?wx_fmt=jpeg

加入社群

为给AI从业者提供更好的交流平台,量子位现开放「AI+教育」行业社群,欢迎小伙伴入群交流。


面向人群:AI+教育相关从业者,技术、产品等人员;


入群方式:请添加小助手7,微信号:qbitbot7,并发送‘教育群+您的姓名+公司+职位+工作内容简介’。


Ps.为保证社群价值,小助手会对申请入群的朋友进行审核,请大家理解!

诚挚招聘

量子位正在招募编辑/记者,工作地点在北京中关村。期待有才气、有热情的同学加入我们!相关细节,请在量子位公众号(QbitAI)对话界面,回复“招聘”两个字。

640?wx_fmt=jpeg

量子位 QbitAI · 头条号签约作者

վ'ᴗ' ի 追踪AI技术和产品新动态

喜欢就点「好看」吧 !



  • 7
    点赞
  • 33
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
要在Flask中部署深度学习API接口,你需要先训练好你的模型并将其保存为.h5或.pb等格式,然后将其加载到Flask应用程序中。以下是一些基本的步骤: 1. 安装 Flask 和需要的深度学习库,如 TensorFlow 或 PyTorch。 2. 在Flask应用程序中创建一个API端点,用于接收请求并返回预测结果。 3. 在API端点中加载你的模型,并准备输入数据。你可以使用 Flask 的 request 对象来获取请求数据。 4. 对输入数据进行必要的预处理,如缩放或标准化。 5. 使用加载的模型进行预测,并将结果返回给客户端。 6. 在Flask应用程序中设置必要的路由和视图函数,以便客户端可以访问API端点。 以下是一个简单的示例代码: ```python from flask import Flask, request, jsonify import tensorflow as tf import numpy as np # 加载模型 model = tf.keras.models.load_model('my_model.h5') # 创建 Flask 应用程序 app = Flask(__name__) # 定义 API 端点 @app.route('/predict', methods=['POST']) def predict(): # 获取请求数据 data = request.get_json(force=True) x = np.array(data['input']) # 预处理输入数据 x = x / 255.0 # 进行预测 y = model.predict(x) # 返回预测结果 return jsonify({'output': y.tolist()}) # 运行应用程序 if __name__ == '__main__': app.run(debug=True) ``` 在上面的示例中,我们创建了一个名为 `/predict` 的 API 端点,用于接收包含输入数据的 POST 请求,并返回预测结果。我们加载了一个名为 `my_model.h5` 的模型,并将其用于预测。预测完成后,我们将结果转换为 JSON 格式并返回给客户端。 请注意,这只是一个简单的示例,实际上你需要根据你的应用程序和模型进行相应的调整。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值