TowardsDataScience 2023 博客中文翻译(三百七十五)

原文:TowardsDataScience

协议:CC BY-NC-SA 4.0

XGBoost:理论与超参数调优

原文:towardsdatascience.com/xgboost-theory-and-hyperparameter-tuning-bc4068aba95e

一个包含 Python 示例的完整指南

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Jorge Martín Lasaosa

·发表于 Towards Data Science ·17 分钟阅读·2023 年 2 月 16 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片来源:Joanne FrancisUnsplash

介绍

几个月后,我将从事数据科学工作满 3 年。我知道这还不算长的职业生涯,但结合我的学术经验,我已经能够参与多个不同行业的机器学习项目(能源、客户体验等)。所有这些项目都使用了 表格数据,即结构化数据(按行和列组织)。相比之下,使用图像或文本等非结构化数据的项目则更多地与计算机视觉或自然语言处理(NLP)等机器学习领域相关。

根据我的经验,XGBoost 通常表格数据 项目中表现良好。尽管“无免费午餐定理” [1] 表明当将两种算法的表现平均到所有可能的问题上时,它们是等效的,但在 Bojan Tunguz 的 Twitter [2] 上,你可以阅读到与其他专业人士的频繁讨论,关于为何基于树的模型(尤其是 XGBoost)通常是解决 表格数据 项目最佳候选者,即便随着对深度学习技术应用于这种数据的研究不断增加。 [3]

而且,看到一位 Kaggle 大师 [4] 玩笑说自己是 XGBoost 的宣传者也挺有趣的。

Bojan Tunguz 的置顶推文。

尽管 XGBoost 取得了巨大的成功,但在过去当我想从模型中获得最佳性能时,我没有找到集中所有必要知识的完整指南。虽然有很多理论和实践上的解释,我将在阅读中参考这些内容,但我没有找到任何完整的指南提供整体视角。这就是我决定写这篇文章的原因。此外,我将把我在这里收集到的内容应用到一个著名的 Kaggle 比赛中,房价预测 — 高级回归技术

文章的其余部分分为两个部分:

  • 理论: 在简短的介绍后,我们将深入原始论文,以理解这个伟大模型背后的理论。接着,简要的视觉解释将帮助我们更好地理解理论。

  • 实践: 在概述 XGBoost 参数后,我将提供一个逐步指南来调整超参数。

除非另有说明,否则所有图片均由作者提供。

理论

XGBoost 代表 eXtreme Gradient Boosting,由 Tianqi Chen 和 Carlos Guestrin 于 2016 年正式发布 [5]。在发布之前,它已经被确立为 Kaggle 比赛中最优秀的算法之一。

尽管深度学习在计算机视觉和自然语言处理等领域取得了巨大成功,XGBoost 和其他基于树的模型(CatBoost [6] 或 LightGBM [7])仍然是预测表格数据 [8] 的最佳选项之一。所有这些基于树的算法都基于梯度提升,如果你想了解这种技术是如何工作的,我建议你查看我关于树集成的文章。[9]

## Tree Ensembles: Bagging, Boosting and Gradient Boosting

理论与实践详细解释

towardsdatascience.com

你想知道 XGBoost 有什么特别之处吗?我将通过两个不同的子部分来解释:原始论文视觉解释。让我们开始吧!

原始论文

如前所述,XGBoost 基于梯度提升,因此多个树是依次在前一棵树的残差上进行训练的。然而,有一些小的改进使得 XGBoost 能够通过防止过拟合超越现有的树集成模型:

  • 正则化学习目标(类似于RGF [10])。在梯度提升中,每棵树都以最佳方式训练,以实现学习目标:减少预测与目标之间的差异。在 XGBoost 中,这个学习目标被一个正则化学习目标取代,该目标在差异计算中添加了一个正则化项。用简单的话说,这个项增加了每棵树学习时的噪音,并旨在减少预测对单个观察值的敏感性。如果将正则化项设置为零,则目标会回到传统的梯度提升。

  • 收缩(借鉴自随机森林 [11])。一种限制每棵训练树在最终预测中权重的技术。因此,每棵树的影响被减少,未来的树有更多空间来改善预测。这类似于学习率(在参数部分也如此指定)。

  • 列子采样(借鉴自随机森林 [11])。它允许为每棵树、树级别和/或树节点随机选择一个特征子样本。因此,对第一棵树/级别/节点非常重要的特征,可能在第二棵树中不可用,从而推动使用其他特征。

树学习中的最大问题之一是找到最佳分裂。如果有一个从 0 到 100 变化的连续特征,那么应该使用什么值来进行分裂?20?32.5?70?为了找到最佳分裂,XGBoost 可以应用不同的算法:

  • 精确贪心算法。这是旧版树提升实现中最常用的算法,它包括测试所有特征的所有可能分裂。对连续特征进行这种操作计算量大,因此随着数据样本的增加,需要更多时间。

  • 加权分位数草图算法:根据特征分布的百分位数提出候选分裂点。该算法将连续特征映射到这些候选点划分的桶中,聚合统计数据,并根据聚合统计数据在提议中找到最佳解决方案。此外,它可以处理加权数据,并在每个树节点中包含一个默认方向,使算法能够识别数据中的稀疏模式。稀疏性可能由于缺失值的存在、频繁的零条目或特征工程的后果(例如使用独热编码)造成。

此外,该算法设计用于高效地与系统交互。无需详细说明,我可以指出数据存储在内存中的称为块的单元中,以帮助算法对数据进行排序。这种技术允许并行化。此外,在大数据集的情况下,定义了缓存感知访问,以避免在某些计算无法适应 CPU 缓存时搜索分裂的速度变慢。最后,它使用块压缩和块碎片化来处理不适合主内存的数据。

视觉解释

本节高度受到Josh Starmer 的 YouTube 频道和 Shreya Rao 的文章的启发。我的目标是通过一个小示例解释 XGBoost 的工作原理,并将其与之前看到的论文理论联系起来。为此,让我们按照以下步骤进行。

步骤 1:创建合成数据

使用了一个关于房价的小型合成数据集。它有两个特征(平方米数是否有车库?)和一个目标(价格)。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

用于视觉解释的关于房价的合成数据。

步骤 2:计算残差

如你所见,残差已经在上表中计算出来。计算方法很简单,你只需将前一个树的预测价格从实际价格中减去(记住,XGBoost 会顺序训练多个树)。然而,这还是第一棵树,所以我们没有预测价格。在这种情况下,计算了价格的平均值。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

计算第一个残差。

步骤 3:构建 XGBoost 的第一棵树

第一棵树将用所有的残差作为目标进行训练。因此,首先需要计算所有残差的相似性得分。这是树分裂旨在增加的得分。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

计算第一个树的相似性得分。

你看到那个 lambda (λ) 字符了吗?如论文理论部分所述,这是一个正则化项,通过添加噪声来帮助防止过拟合。因此,XGBoost 的学习目标(增强相似性得分)实际上是一个正则化学习目标。XGBoost 中正则化项的默认值是 1。现在,让我们看看特征*是否有车库?*对相似性得分和增益的影响。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

计算了带有车库特征的分裂的相似性得分。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

车库分裂的最终增益。

一旦计算了相似性得分,*是否有车库?*的特征增益为 7448.58。如果最终增益为正,则这是一个好的分裂,否则不是。由于它是正的,我们可以得出这是一个好的分裂。

对于像平方米这样的连续特征,过程略有不同。首先,我们需要找到连续特征的最佳分裂。为此,将使用论文理论部分中解释的精确贪婪算法。也就是说,将测试每一个可能的分裂。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

对特征进行排序并找到所有可能的分裂(贪婪)。

如上图所示的排序在数据集很大的时候可能会计算成本高。正如论文理论部分评论的那样,XGBoost 使用块单元来允许并行化,并帮助解决这个问题。同时,请记住,XGBoost 可以使用加权分位数草图算法根据特征分布的百分位数来提议候选分裂点。这里不会详细解释,但这是 XGBoost 的主要优势之一。话虽如此,我们将计算每个可能分裂的相似度得分和增益。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

每个平方米分裂的增益计算。

平方米特征通过 175 的分裂具有最大的增益(甚至比是否有车库?的分裂还要大),因此它应该是第一个分裂。右叶子中只剩下一个残差(156)。所以我们来关注左叶子,寻找另一个分裂。经过所有计算,最佳的分裂再次使用平方米特征。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

最佳第二次分裂在车库特征上。

经过新计算,右叶子可以再次分裂。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

完成所有可能分裂的最终树。

步骤 4:剪枝树

为了避免过拟合,还有一个叫做树剪枝的过程是 XGBoost 执行的。从下往上验证每个增益。怎么做?如果一个叫做gamma(γ)的 XGBoost 参数大于增益,则会移除分裂。XGBoost 中γ的默认值是 0,但设置为 200 以表示移除分裂的情况。为什么?因为在这种情况下,200(γ)大于最后的增益(198)。因此,树被剪枝,最终结果是:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

剪枝后的第一棵树的最终版本。

步骤 5:使用树进行预测

为了进行预测,第一步是每个最终叶子中都有一个单一的值(输出值)。为此,我们使用以下公式:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

计算输出值(每个叶子的值)。

结果是:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

每个叶子都有一个输出值的第一棵树。

最终,最终叶子的输出值可以用下面显示的公式进行新预测。i 是我们想要预测的观察值,prediction_t0 是第一次预测(观察价格的均值),ɛ 是学习率,leaf_i 是按照树的规则得到的观察值 i 的值。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

请注意,学习率(ɛ),其默认值为 0.3,是收缩理论部分的解释。话虽如此,让我们预测所有观察值:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

每个值的预测。

现在,我们使用预测来计算新的残差(Residuals_1):

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

第一棵树的残差。

正如我们在上面看到的,每个残差(除了第一个)都更接近零。这意味着第一棵树提供的信息改善了第一次预测(均值)。下一步是创建另一棵树,使残差减少的方向再进一步。

步骤 6:训练新树

如果我们想创建一棵新树,唯一需要做的就是用新的残差(Residuals_1)作为目标重新执行步骤 3-5。然后,最终预测将是:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

使用两棵树计算最终预测。

XGBoost 可以通过这些步骤顺序训练树。就我个人而言,我发现视觉解释是理解模型及其理论的有效方法。我希望这对你也有帮助。不过,现在是时候进入实际部分了。

实践

首先,我想强调的是,尽管我不是 XGBoost 实践方面的专家,但我使用了一些证明成功的技术。因此,我认为我可以提供一个简要指南,帮助你调整超参数。

为了实现这一目标,必须将实践分为两个子部分。第一个部分将专注于定义算法的超参数并将其编译成一个简明的备忘单。第二个子部分称为实际操作,将提供一个逐步指南,带你完成 XGBoost 的训练过程。

超参数

所有在这里解释的信息都可以在官方 XGBoost 文档 [12] 中找到。一个通用的参数被称为booster。正如我在之前的文章 [9] 中解释的那样,梯度提升是一种不限于使用决策树的技术,它可以应用于任何模型。这就是为什么 XGBoost 接受booster参数的三种值:

  • gbtree:使用决策树的梯度提升(默认值)

  • dart:一种使用 Vinayak 和 Gilad-Bachrach(2015)[13] 提出的将深度神经网络社区中的 dropout 技术添加到提升树中的方法的决策树梯度提升。

  • gblinear:使用线性函数的梯度提升。

虽然gblinear是捕捉预测变量和结果之间线性关系的最佳选择,但基于决策树的提升器(gbtreedart)在捕捉非线性关系方面要好得多。由于gbtree是最常用的值,本文余下部分将使用它。如果你有兴趣使用其他方法,请查看文档[12],因为根据你的选择,参数会有所不同。

一旦我们将gbtree设置为booster,我们可以调整多个参数以获得最佳模型。这里解释了最重要的几个:

  • eta(又名学习率):在视觉解释部分显示为ɛ,它限制每棵训练树在最终预测中的权重,以使提升过程更具保守性。

  • gamma: 在视觉解释部分显示为γ,它标记了在树的叶节点上进行进一步划分所需的最小增益。

  • max_depth: 设置树的最大深度。

  • n_estimators: 训练的树的数量。

  • min_child_weight: 设置进行拆分时子节点所需的最小实例权重(残差之和)。

  • subsample: 设置在训练每棵树之前获得的样本百分比(随机)。

  • colsample_by[]: 一组用于列抽样的参数。抽样可以在每棵树(colsample_bytree)、每棵树中达到的深度级别(colsample_bylevel)或每次评估新拆分时(colsample_bynode)进行。注意,这些参数可以同时工作:如果每个参数的值为 0.5,并且你有 32 列,那么每个拆分将使用 4 列(32/ 2³)。

  • lambda(L2 正则化):在视觉解释中显示为λ。它通过增加分母平滑地减少输出值(视觉解释中的第 5 步)。[14]

  • alpha(L1 正则化):它通过强制输出值为 0 来减少输出值并促进稀疏性。[14]

  • tree_method: 设置树用于查找拆分的构建算法。正如论文理论部分所讨论的,可以使用精确或近似算法。在实践中,这个参数有 5 个可能的值:auto 让启发式算法从下列选项中选择最快的选项,exact 应用枚举所有拆分候选项的精确贪婪算法,approx 应用使用分位数草图和梯度直方图的近似贪婪算法,hist 使用近似算法的直方图优化版本,gpu_hist 使用hist的 GPU 实现。

  • scale_pos_weight: 它控制正负权重的平衡,这对于类别不平衡的情况非常有用。一个典型的值是sum(negative instances) / sum(positive instances)

  • max_leaves: 设置仅在未选择tree_method参数的exact值时要添加的最大节点数。

还有更多参数需要调整:updaterrefresh_leafprocess_typegrow_policymax_binpredictornum_parallel_treemonotone_constraintsinteraction_constraints。然而,前面的参数已经足以充分发挥 XGBoost 模型的作用,因此这些参数得到了说明。

话虽如此,调整解释的特征如何影响模型?下面显示的备忘单帮助我们理解以不同方式调整每个参数的效果。此外,它展示了特定调整如何影响结果,无论是理解方差和偏差如何改善还是恶化。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

主要参数及其效果的备忘单。

免责声明:减少或增加值意味着从特定特征的完美偏差-方差平衡出发,而不是从默认值开始。此外,参数之间有很多相互作用,因此减少某些值并增加其他值可能会导致与表中解释的不同结果。

实践实施

对于实施,我们将使用两个资源。首先,我们将使用 Kaggle 竞赛中数据,该竞赛名为房价[15]。其次,由于本文的重点是 XGBoost,而不是机器学习项目中的其他任务,我们将借用 Nanashi 的笔记本[16]中的代码来进行数据处理

尽管你将在阅读过程中看到最相关的代码片段,但所有代码可以在我专门为 Medium 文章创建的GitHub 仓库中查看。

如前所述,我并不是 XGBoost 训练方面的专家,所以我鼓励你在评论区分享你的技巧和方法,甚至批评这种训练方式,如果你真的认为它有问题的话。然而,根据我的经验,遵循以下步骤是迭代改进模型的好方法。

步骤 0:数据读取和处理

如前所述,数据已经读取和处理。它被分为两部分:

  • 训练集(数据的 85%),包含 X_train 和 y_train。步骤 1–4 将使用此数据集来构建最佳的 XGBoost 模型。在构建过程中,将讨论平均测试得分,需要说明的是,这个“测试”得分实际上并不对应于测试集,而是对应于在交叉验证过程中生成的验证测试。

  • 测试集(数据的 15%),包含 X_test 和 y_test。它通常被称为保留集。当我们得到最佳的 XGBoost 模型时,我们应该检查模型在测试集上的表现是否与在训练集上的表现一致。

读取和处理来自房价 Kaggle 竞赛的数据的代码。

如果你想查看处理细节,请查看 GitHub 仓库。

步骤 1:创建基准

第一步是创建基准。为此,创建一个简单的 XGBoost 模型而不调整任何参数。为了更好地比较结果,使用了 5 折交叉验证[17]。所用的度量标准是平均绝对误差(MAE)。

创建 XGBoost 基准的代码

这给我们提供了 0.0961 的平均测试分数和 0.0005 的主要训练分数。看起来偏差较低但方差较高,这可以解释为存在过拟合。然而,基准只是一个参考,让我们继续进行,看看第二步接下来会有什么。

第 2 步:使用 GridSearchCV 改进基准

现在,我们将使用 GridSearchCV [18] 来搜索参数网格中的良好参数组合。在第一次尝试中,我们将使用接近 XGBoost 默认值的参数值:

使用 GridSearchCV 和 XGBoost 的代码

最佳组合是:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

GridSearchCV 找到的最佳参数组合。

获得的最佳测试平均值为 0.0840。这比第一个模型获得的 0.0961 稍好,因此现在将其作为改进的指标。

第 3 步:单独调整参数

一旦我们超越了之前的基准分数,我们可以尝试了解如何单独调整参数影响结果。使用提供的代码,可以一次测试多个参数。例如,我们来看看这些参数值如何影响结果:

  • n_estimators: [125, 150, 175, 200, 225]

每个可能的值都用于交叉验证训练。返回的结果被编译并显示在下面的图中。

单独调整参数和绘制分数的代码。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

绘制估计器数量变化如何影响结果。

图表显示,如果 n_estimators 较高,平均测试分数可能会降低。然而,平均训练分数接近零,所以过拟合即将到来。

第 4 步:重复第 2 步和第 3 步。

利用第 3 步中收集的信息,我们可以重新定义 GridSearchCV 中的参数网格,尝试获得更好的模型。在这种情况下,测试了更高的 n_estimators。这意味着模型的复杂度将更高,因此在参数网格中,我们还包括了一些可以帮助避免过拟合的参数值(较低的 learning_rate、更高的 lambda、更高的 gamma……)。超参数部分定义的备忘单在这里可能会非常有用。

第二次 GridSearchCV 与 XGBoost 的代码。

使用这段代码,我们获得了新的最佳平均测试分数 0.0814(之前为 0.0840)。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

第二次 GridSearchCV 找到的最佳参数组合。

第 2 步至第 3 步应根据需要重复进行,以改进模型。

第 5 步:使用测试集验证模型。

最后,我们需要使用步骤 0 中的测试集(留出测试,无验证测试)来验证所构建的模型。

计算留出预测 MAE 的代码。

获得的 MAE 为 0.0808。由于低于我们最好的 MAE(0.0814),我们可以说我们的模型泛化良好,并且已经得到了很好的训练。

结论

在本文中,我旨在提供一个使用 XGBoost 的全面指南。经过数小时对论文、指南、帖子和文章的研究,我相信我已完成一篇完整的文章,可以帮助全面理解和使用 XGBoost。

希望你觉得阅读有用且愉快。最重要的是,我很高兴收到任何形式的反馈。请随时分享你的想法!

参考文献

[1] Wolpert, D. H., & Macready, W. G. «优化的无免费午餐定理». IEEE Transactions on Evolutionary Computation (1997). ieeexplore.ieee.org/abstract/document/585893

[2] Bojan Tunguz Twitter 账户。 twitter.com/tunguz

[3] Raschka, Sebastian (2022). 针对表格数据的深度学习。 sebastianraschka.com/blog/2022/deep-learning-for-tabular-data.html

[4] Bojan Tunguz Kaggle 个人资料。 www.kaggle.com/tunguz

[5] Chen, T., & Guestrin, C. (2016 年 8 月). Xgboost: 一个可扩展的树提升系统。在 第 22 届 ACM SIGKDD 国际知识发现与数据挖掘大会论文集(第 785–794 页)。

[6] Dorogush, A.V.; Ershov, V.; Gulin, A. CatBoost: 支持类别特征的梯度提升». ArXiv:1810.11363, 24 (2018). arxiv.org/abs/1810.11363.

[7] Ke, G.; Meng, Q.; Finley, T; Wang, T; Chen, W; Ma, W; Ye, Q; Liu, T. «LightGBM: 一种高效的梯度提升决策树». 神经信息处理系统进展, 20 (2017). proceedings.neurips.cc/paper/2017/hash/6449f44a102fde848669bdd9eb6b76fa-Abstract.html.

[8] Tunguz, Bojan. twitter.com/tunguz/status/1620048813686923266?s=20&t=BxzKnvn7G0ieo1I7SfDnSQ

[9] Martín Lasaosa, Jorge. «树集成:Bagging、Boosting 和梯度提升». 在 Towards Data Science (Medium)。 (2022) medium.com/r/?url=https%3A%2F%2Ftowardsdatascience.com%2Ftree-ensembles-theory-and-practice-1cf9eb27781

[10] T. Zhang 和 R. Johnson. 使用正则化贪婪森林学习非线性函数。IEEE 计算机学会模式分析与机器智能汇刊,36(5),2014. ieeexplore.ieee.org/abstract/document/6583153

[11] Breiman, L. «随机森林». 机器学习 45, (2001): 5–32. doi.org/10.1023/A:1010933404324.

[12] XGBoost 文档(参数部分) xgboost.readthedocs.io/en/stable/parameter.html

[13] Vinayak, R. K.; Gilad-Bachrach, R. «Dart: Dropouts 遇见多重加法回归树». 收录于 人工智能与统计. PMLR. (2015) proceedings.mlr.press/v38/korlakaivinayak15.html

[14] Um, Albert. «XGBoost 回归中的 L1、L2 正则化» 收录于 Medium (2021) albertum.medium.com/l1-l2-regularization-in-xgboost-regression-7b2db08a59e0

[15] 房价 Kaggle 竞赛。 www.kaggle.com/competitions/house-prices-advanced-regression-techniques

[16] Nanashi 在 Kaggle 上的笔记本 www.kaggle.com/code/jesucristo/1-house-prices-solution-top-1

[17] Yiu, Tony. «理解交叉验证» 收录于 Towards Data Science (Medium) (2020) towardsdatascience.com/understanding-cross-validation-419dbd47e9bd

[18] «使用 GridSearchCV 调优超参数» 收录于 Great Learning (2022) www.mygreatlearning.com/blog/gridsearchcv/

YOLO-NAS:如何在目标检测任务中实现最佳性能

原文:towardsdatascience.com/yolo-nas-how-to-achieve-the-best-performance-on-object-detection-tasks-6b95347908d4

通过神经网络结构搜索、创新的量化块和强大的预训练范式生成的基础模型

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Thomas A Dorfer

·发布于数据科学前沿 ·阅读时间 7 分钟·2023 年 5 月 19 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由Anubhav Saxena提供,Unsplash上可见。由作者使用 YOLO-NAS-L 处理。

在目标检测领域,YOLO(只需一次)已成为家喻户晓的名字。自 2015 年发布第一个模型以来,YOLO 家族一直在稳步增长,每个新模型在平均精度(mAP)和推理延迟上均超越了其前身。

两周前,YOLO 家族迎来了新成员:YOLO-NAS,这是由深度学习公司Deci开发的一种新型基础模型。

在本文中,我们将探讨它相对于之前 YOLO 模型的优势,并展示如何将其用于你自己的目标检测任务。

YOLO-NAS:新变化是什么?

尽管之前的 YOLO 模型在目标检测方面在创新和性能上领先,但它们也存在一些局限性。主要问题之一是缺乏适当的量化支持,这旨在减少模型的内存和计算需求。另一个问题是精度和延迟之间的权衡不足,其中一个的改进往往会导致另一个的显著下降。

通过利用一种称为**神经网络结构搜索(NAS)**的概念,Deci 的研究人员直接解决了这些局限性。实质上,NAS 的概念可以被视为对训练好的深度学习模型的一次改造。

传统上,神经网络架构由人类专家根据经验和直觉进行手动设计。然而,这一过程涉及探索可能架构的广阔设计空间,始终非常耗时且繁琐。

NAS,另一方面,自动重新设计模型架构,以提升其在速度、内存使用和吞吐量等方面的性能。它通常涉及一个定义可能架构选择的搜索空间,例如层数、层类型、卷积核大小和连接模式。搜索算法然后通过在给定任务和数据集上训练和评估不同架构来评估它们。基于这些评估,算法迭代地探索和完善架构空间,最终返回最佳性能的架构。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片来源:Google DeepMindUnsplash

为了执行 NAS,Deci 利用了其专有的AutoNAC技术,这是一种优化引擎,可以重新设计模型的架构,以最大限度地提高特定硬件的推理性能,同时保持准确性。

除了 NAS,这个新的 YOLO 成员的另一个重大改进涉及量化的使用。在这种情况下,量化指的是将神经网络的权重、偏置和激活从浮点值转换为整数值(INT-8),从而使模型更加高效。

这项工作有两个方面:(1) 该模型使用了适合量化的块,这些块结合了重新参数化和 INT-8 量化的优点。这些块采用了Chu 等人(2022)提出的方法,重新设计这些块,使其生成的权重和激活分布有利于量化。(2) 作者采用了一种混合量化方法,有选择地量化模型的特定层,从而最小化信息丢失,在延迟和准确性之间取得平衡。

这种新方法的结果不言而喻。如下图所示,量化后的中型模型YOLO-NAS-INT8-M在推理延迟方面提高了 50%,同时与最新的最先进模型相比,准确性提高了 1 mAP。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

来源: Deci-AI。许可证: Apache License 2.0

在撰写时,已经发布了 YOLO-NAS 的三种模型:小型、中型和大型,每种模型都有一个量化的 INT-8 版本。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

来源: Deci-AI。许可证: Apache License 2.0

毫不奇怪,量化版本的精度略有下降。然而,由于采用了这些新颖的量化友好型块以及选择性量化,这种精度下降仍然相对较小。此外,这里的优势远大于劣势,推理延迟显著改善。

YOLO-NAS 还在 COCO、Objects365 和 Roboflow 100 数据集上进行预训练,这使得它非常适合下游的物体检测任务。

预训练过程利用了知识蒸馏的概念,这使得模型可以从自身的预测中学习,而不仅仅依赖于外部标记数据,从而提高性能。在这种范式中,教师模型对训练数据生成预测,这些预测作为学生模型的指导(或软目标)。学生模型使用原始标记数据和教师模型生成的软目标进行训练。它基本上尝试模仿教师模型的预测,同时调整其参数以匹配原始标记数据。总体而言,这种方法使模型能够更好地泛化,减少过拟合,并实现更高的准确性,特别是当标记数据不丰富时。

训练过程通过加入**分布焦点损失(DFL)**进一步增强。DFL 是一种损失函数,扩展了焦点损失的概念,通过对难以分类的样本分配更高的权重来解决类别不平衡的问题。在物体检测的背景下,DFL 在训练过程中将框回归作为分类任务进行学习。它将边界框预测离散化为有限的选项,并对这些选项进行分布预测。最终的预测通过加权求和将这些分布结合起来。通过考虑类别分布并相应地调整损失函数,模型能够提高对欠代表类别的检测准确性。

最终,YOLO-NAS 已经以开源许可证发布,并且在 Deci 的基于 PyTorch 的计算机视觉库 SuperGradients 上提供了预训练权重供研究使用。

如何使用 YOLO-NAS

为了使用 YOLO-NAS 进行推理,我们需要首先安装 super_gradients 包:

pip install super-gradients

为了准备推理任务,让我们取一张样本图像,我们称之为 image.jpg

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

krakenimages 的照片,来源于 Unsplash

为了进行推断,我们可以使用以下代码片段:

import torch
from super_gradients.training import models

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
yolo_nas_s = models.get("yolo_nas_s", pretrained_weights="coco").to(device)
out = yolo_nas.predict("image.jpg")
out.save("image_yolo.jpg")

首先,我们需要从 super_gradients 库中导入 torchmodels。然后,我们声明一个变量 device,设置为使用第一个可用的 GPU(如果有的话),否则设置为使用 CPU。

随后,我们指定使用模型的小型版本 YOLO-NAS-S,并使用 COCO 数据集的预训练权重。此外,我们将检测到的物体保存到 image_yolo.jpg 中。

我们的输出图像如下所示:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

krakenimages 的照片,来源于 Unsplash。由作者使用 YOLO-NAS-S 处理。

我们可以看到各种物体在不同的置信度水平下被检测到。对于聚焦中的物体(如两个人、杯子和笔记本电脑),模型的置信度较高。然而,我们也观察到一些误分类,可能是由于物体失焦。这包括一个笔筒被误分类为盆栽植物,以及一支笔被误分类为牙刷。令人惊讶的是,我们还可以看到模型准确地检测到了仅部分可见的物体,如人物坐的椅子,只有靠背部分可见。

最后,值得一提的是,通过简单地将 predict() 调用的输入参数更改为相应的视频文件,可以以完全相同的方式进行视频目标检测。

结论

YOLO 系列又迎来了一员新成员——YOLO-NAS,它自豪地超越了 YOLOv6、YOLOv7 和 YOLOv8 等年轻兄弟。

通过神经架构搜索、量化支持以及包含知识蒸馏和分布式焦点损失的强大预训练程序的创新组合,YOLO-NAS 实现了精度和推理延迟之间的显著权衡。

考虑到计算机视觉和目标检测领域的快速发展,另一款 YOLO 模型很可能很快会面世。

更多资源

喜欢这篇文章吗?

让我们联系吧!你可以在 TwitterLinkedInSubstack 找到我。

如果你想支持我的写作,可以通过Medium 会员来做到,这将使你可以访问我所有的故事以及 Medium 上成千上万其他作家的故事。

[## 通过我的推荐链接加入 Medium - 托马斯·A·多费尔

阅读托马斯·A·多费尔(Thomas A Dorfer)以及成千上万其他作者在 Medium 上的每个故事。你的会员费直接支持…

medium.com

Raspberry Pi 上的 YOLO 目标检测

原文:towardsdatascience.com/yolo-object-detection-on-the-raspberry-pi-6de3629256fa

在低功耗设备上运行目标检测模型

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Dmitrii Eliuseev

·发布于 Towards Data Science ·阅读时长 9 分钟·2023 年 7 月 11 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

YOLO 目标检测结果,作者提供的图片

在这篇文章的第一部分中,我测试了 YOLO(You Only Look Once)的“复古”版本,这是一个流行的目标检测库。仅使用 OpenCV 运行深度学习模型,而不使用 PyTorch 或 Keras 等“重型”框架,对低功耗设备来说是很有前景的,我决定深入探讨这个话题,看看最新的 YOLO v8 模型在 Raspberry Pi 上的表现如何。

让我们深入了解一下。

硬件

在云端运行任何模型通常没有问题,因为资源几乎是无限的。但对于“现场”的硬件,限制就多得多。有限的 RAM、CPU 功率,甚至不同的 CPU 架构,较旧或不兼容的软件版本,缺乏高速互联网连接等等。云基础设施的另一个大问题是其成本。假设我们正在制作一个智能门铃,并且我们想添加人脸检测功能。我们可以在云端运行模型,但每次 API 调用都需要付费,那么谁来支付呢?并不是每个客户都愿意为门铃或任何类似的“智能”设备支付月费,因此在本地运行模型可能是必要的,即使结果可能不是最佳的。

在这个测试中,我将会在 Raspberry Pi 上运行 YOLO v8 模型:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

Raspberry Pi 4,图片来源 en.wikipedia.org/wiki/Raspberry_Pi

Raspberry Pi 是一种便宜的信用卡大小的单板计算机,运行 Raspbian 或 Ubuntu Linux。我将测试两个不同的版本:

  • 2015 年制造的 Raspberry Pi 3 Model B。它有一个 1.2 GHz Cortex-A53 ARM CPU 和 1 GB 的 RAM。

  • Raspberry Pi 4,制造于 2019 年。它有一个 1.8 GHz Cortex-A72 ARM CPU 和 1、4 或 8 GB 的 RAM。

目前,Raspberry Pi 计算机被广泛使用,不仅用于爱好和 DIY 项目,还用于嵌入式工业应用(Raspberry Pi Compute Module 专门为此设计)。因此,了解这些板子如何处理诸如物体检测这样的计算密集型操作是很有趣的。在所有进一步的测试中,我将使用这张图像:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

测试图像,由作者制作

现在,让我们看看它是如何工作的。

“标准” YOLO v8 版本

作为热身,让我们尝试标准版本,如在官方 GitHub 页面上描述的

from ultralytics import YOLO
import cv2
import time

model = YOLO('yolov8n.pt')

img = cv2.imread('test.jpg')

# First run to 'warm-up' the model
model.predict(source=img, save=False, save_txt=False, conf=0.5, verbose=False)

# Second run
t_start = time.monotonic()
results = model.predict(source=img, save=False, save_txt=False, conf=0.5, verbose=False)
dt = time.monotonic() - t_start
print("dT:", dt)

# Show results
boxes = results[0].boxes
names = model.names
confidence, class_ids = boxes.conf, boxes.cls.int()
rects = boxes.xyxy.int()
for ind in range(boxes.shape[0]):
    print("Rect:", names[class_ids[ind].item()], confidence[ind].item(), rects[ind].tolist())

在“生产”系统中,可以通过摄像头获取图像;对于我们的测试,我使用了之前描述的“test.jpg”文件。我还执行了“predict”方法两次,以使时间估计更准确(第一次运行通常需要更多时间来“热身”和分配所有所需的内存)。Raspberry Pi 在没有显示器的“无头”模式下工作,因此我使用控制台作为输出;这是大多数嵌入式系统的标准工作方式。

Raspberry Pi 3 上运行 32 位操作系统时,此版本无法使用:pip 无法安装“ultralytics”模块,原因如下错误:

ERROR: Cannot install ultralytics

The conflict is caused by:
    ultralytics 8.0.124 depends on torch>=1.7.0

结果发现 PyTorch 仅适用于 ARM 64 位操作系统。

Raspberry Pi 4 上运行 64 位操作系统时,代码确实可以运行,计算时间约为 0.9 秒。

控制台输出如下:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

我还在桌面 PC 上进行了相同的实验以可视化结果。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

YOLO v8 Nano 检测结果,由作者提供的图像

正如我们所见,即使是“nano”尺寸的模型,结果也相当不错。

Python ONNX 版本

ONNX (开放神经网络交换) 是一种用于表示机器学习模型的开放格式。它也得到 OpenCV 的支持,因此我们可以轻松地以这种方式运行我们的模型。YOLO 开发人员已经提供了一个 命令行工具 来进行这种转换:

yolo export model=yolov8n.pt imgsz=640 format=onnx opset=12

这里,“yolov8n.pt”是一个 PyTorch 模型文件,将被转换。文件名中的最后一个字母“n”表示“nano”。提供了不同的模型(“n”——nano,“s”——small,“m”——medium,“l”——large),显然,对于 Raspberry Pi,我将使用最小且最快的一个。

转换可以在桌面 PC 上完成,并且可以使用“scp”命令将模型复制到 Raspberry Pi 上:

scp yolov8n.onnx pi@raspberrypi:/home/pi/Documents/YOLO

现在我们准备好准备源代码了。我使用了 Ultralytics 仓库 中的一个示例,并对其进行了稍微的修改以在 Raspberry Pi 上运行:

import cv2
import time

model: cv2.dnn.Net = cv2.dnn.readNetFromONNX("yolov8n.onnx")
names = "person;bicycle;car;motorbike;aeroplane;bus;train;truck;boat;traffic light;fire hydrant;stop sign;parking meter;bench;bird;" \
        "cat;dog;horse;sheep;cow;elephant;bear;zebra;giraffe;backpack;umbrella;handbag;tie;suitcase;frisbee;skis;snowboard;sports ball;kite;" \
        "baseball bat;baseball glove;skateboard;surfboard;tennis racket;bottle;wine glass;cup;fork;knife;spoon;bowl;banana;apple;sandwich;" \
        "orange;broccoli;carrot;hot dog;pizza;donut;cake;chair;sofa;pottedplant;bed;diningtable;toilet;tvmonitor;laptop;mouse;remote;keyboard;" \
        "cell phone;microwave;oven;toaster;sink;refrigerator;book;clock;vase;scissors;teddy bear;hair dryer;toothbrush".split(";")

img = cv2.imread('test.jpg')
height, width, _ = img.shape
length = max((height, width))
image = np.zeros((length, length, 3), np.uint8)
image[0:height, 0:width] = img
scale = length / 640

# First run to 'warm-up' the model
blob = cv2.dnn.blobFromImage(image, scalefactor=1 / 255, size=(640, 640), swapRB=True)
model.setInput(blob)
model.forward()

# Second run
t1 = time.monotonic()
blob = cv2.dnn.blobFromImage(image, scalefactor=1 / 255, size=(640, 640), swapRB=True)
model.setInput(blob)
outputs = model.forward()
print("dT:", time.monotonic() - t1)

# Show results
outputs = np.array([cv2.transpose(outputs[0])])
rows = outputs.shape[1]

boxes = []
scores = []
class_ids = []
output = outputs[0]
for i in range(rows):
    classes_scores = output[i][4:]
    minScore, maxScore, minClassLoc, (x, maxClassIndex) = cv2.minMaxLoc(classes_scores)
    if maxScore >= 0.25:
        box = [output[i][0] - 0.5 * output[i][2], output[i][1] - 0.5 * output[i][3],
               output[i][2], output[i][3]]
        boxes.append(box)
        scores.append(maxScore)
        class_ids.append(maxClassIndex)

result_boxes = cv2.dnn.NMSBoxes(boxes, scores, 0.25, 0.45, 0.5)
for index in result_boxes:
    box = boxes[index]
    box_out = [round(box[0]*scale), round(box[1]*scale),
               round((box[0] + box[2])*scale), round((box[1] + box[3])*scale)]
    print("Rect:", names[class_ids[index]], scores[index], box_out)

正如我们所见,我们不再使用 PyTorch 和原始 Ultralytics 库了,但所需的代码量增加了。我们需要将图像转换为 blob,这对于 YOLO 模型是必需的。在打印结果之前,我们还需要将输出的矩形转换回原始坐标。但作为一个优点,这段代码在“纯” OpenCV 上运行,不依赖任何额外的库。

Raspberry Pi 3上,计算时间为 28 秒。为了好玩,我还加载了“medium”模型(这是一个 101 MB 的 ONNX 文件!)看看会发生什么。令人惊讶的是,应用程序没有崩溃,但计算时间为 224 秒(几乎 4 分钟)。显然,2015 年的硬件不适合运行 2023 年的 SOTA 模型,但看到它如何工作的过程还是很有趣的。

Raspberry Pi 4上,计算时间为 1.08 秒。

C++ ONNX 版本

最后,让我们尝试使用工具集中最强大的武器,并用 C++ 编写相同的代码。但在此之前,我们需要为 C++ 安装 OpenCV 库和头文件。最简单的方法是运行类似“sudo apt install libopencv-dev”的命令。但至少对于 Raspbian,这种方法不起作用。通过“apt”获得的最新版本是 4.2,而加载 YOLO 模型的最低 OpenCV 要求是 4.5。因此,我们需要从源代码构建 OpenCV。

我将使用 OpenCV 4.7,这是我在 Python 测试中使用的版本:

sudo apt update
sudo apt install g++ cmake libavcodec-dev libavformat-dev libswscale-dev libgstreamer-plugins-base1.0-dev libgstreamer1.0-dev 
sudo apt install libgtk2.0-dev libcanberra-gtk* libgtk-3-dev libpng-dev libjpeg-dev libtiff-dev
sudo apt install libxvidcore-dev libx264-dev libgtk-3-dev libgstreamer1.0-dev gstreamer1.0-gtk3

wget https://github.com/opencv/opencv/archive/refs/tags/4.7.0.tar.gz
tar -xvzf 4.7.0.tar.gz
rm 4.7.0.tar.gz
cd opencv-4.7.0
mkdir build && cd build

cmake -D WITH_QT=OFF -D WITH_VTK=OFF -D CMAKE_BUILD_TYPE=RELEASE -D CMAKE_INSTALL_PREFIX=/usr/local -D WITH_FFMPEG=ON -D PYTHON3_PACKAGES_PATH=/usr/lib/python3/dist-packages -D BUILD_EXAMPLES=OFF ..
make -j2 && sudo make install && sudo ldconfig

Raspberry Pi 不是世界上最快的 Linux 计算机,编译过程大约需要 2 小时。对于拥有 1 GB RAM 的 Raspberry Pi 3,交换文件大小应该增加到至少 512 MB;否则,编译将失败。

C++ 代码本身很简短:

#include <opencv2/opencv.hpp>
#include <vector>
#include <ctime>
#include "inference.h"

int main(int argc, char **argv) {
    Inference inf("yolov8n.onnx", cv::Size(640, 640), "", false);

    cv::Mat frame = cv::imread("test.jpg");

    // First run to 'warm-up' the model
    inf.runInference(frame);

    // Second run
    const clock_t begin_time = clock();

    std::vector<Detection> output = inf.runInference(frame);

    printf("dT: %f\n",  float(clock() - begin_time)/CLOCKS_PER_SEC);

    // Show results
    for (auto &detection : output) {
        cv::Rect box = detection.box;

        printf("Rect: %s %f: %d %d %d %d\n", detection.className.c_str(), detection.confidence,
                                             box.x, box.y, box.width, box.height);        
    }

    return 0;
}

在这段代码中,我使用了 Ultralitics GitHub 仓库中的“inference.h”和“inference.cpp”文件,这些文件应该放在同一个文件夹中。我还像以前的测试一样执行了“runInference”方法两次。我们现在可以使用以下命令编译源代码:

c++ yolo1.cpp inference.cpp -I/usr/local/include/opencv4 -L/usr/local/lib -lopencv_core -lopencv_dnn -lopencv_imgcodecs -lopencv_imgproc -O3 -o yolo1

结果令人惊讶。C++ 版本的速度比以前的版本明显 !在Raspberry Pi 3上,执行时间为 110 秒,比 Python 版本长了 3 倍以上。在Raspberry Pi 4上,计算时间为 1.79 秒,比 Python 版本长了约 1.5 倍。总的来说,很难说清楚原因。Python 的 OpenCV 库是通过 pip 安装的,而 C++ 的 OpenCV 是从源代码构建的,也许某些 ARM CPU 优化没有启用。如果有读者知道原因,请在下方评论中写明。无论如何,看到这样的效果是很有趣的。

结论

我可以“有根据地猜测”大多数数据科学家和数据工程师都在云端或至少在高端设备上使用他们的模型,并且从未尝试过在嵌入式硬件上“实地”运行代码。本文的目的是为读者提供一些关于它是如何工作的见解。在这篇文章中,我们尝试在不同版本的 Raspberry Pi 上运行 YOLO v8 模型,结果相当有趣。

  • 在低功耗设备上运行深度学习模型可能是一个挑战。即使是目前最好的基于 Raspbian 的模型——Raspberry Pi 4,在使用 YOLO v8 Tiny 模型时也只能提供大约~1 FPS。当然,还有改进的空间。一些优化可能是可行的,例如将模型转换为 FP16(精度较低的浮点格式)或甚至 INT8 格式。最后,也可以使用一个在有限数据集上训练的更简单的模型。最后但同样重要的是,如果仍然需要更多的计算能力,可以在像 NVIDIA Jetson Nano 这样的特殊单板计算机上运行代码,它支持 CUDA,并且可以更快。

  • 在本文开头,我写道“仅使用 OpenCV 运行深度学习模型,而不依赖像 PyTorch 或 Keras 这样沉重的框架,对低功耗设备是有前景的”。实际上,结果是 PyTorch 是一个有效且高度优化的框架。基于 PyTorch 的原始 YOLO 版本是最快的,而 OpenCV ONNX 代码慢了 10–20%。但在撰写本文时,PyTorch 在 32 位 ARM CPU 上不可用,因此在某些平台上可能没有其他选择。

  • C++版本的结果则更为有趣。正如我们所见,开启适当的优化可能是一个挑战,特别是对于嵌入式架构而言。在不深入探讨这些细节的情况下,定制的 OpenCV C++代码可能比板制造商提供的 Python 版本运行得更慢。

感谢阅读。如果有人有兴趣在相同硬件或 NVIDIA Jetson Nano 板上测试 FP16 或 INT8 YOLO 模型,请在评论中留言,我会在下一部分文章中讨论这一点。

如果你喜欢这个故事,欢迎 订阅 Medium,这样你会在我的新文章发布时收到通知,并且可以全面访问其他作者的成千上万的故事。

你不能踏入同一条河流两次

原文:towardsdatascience.com/you-cant-step-in-the-same-river-twice-cfacf7cee305

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片来源:弗拉基斯拉夫·巴比恩科Unsplash

《为何的书》第 7 和 8 章,阅读系列

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 朱子景博士

·发表于Towards Data Science ·18 分钟阅读·2023 年 12 月 20 日

在我的之前的文章中,我们了解了观察数据中的混杂因素和碰撞器,这些因素阻碍了可靠因果关系的建立。珀尔提供的解决方案是绘制因果图,并使用后门准则找到需要阻断的混杂因素集合,留下碰撞器和中介变量。

然而,当处理那些无法观察或测量的混杂变量时,从观察数据中估计因果关系变得困难。为应对这个问题,在《为何的书》第七章中,朱迪亚·珀尔介绍了do-calculus 规则。这些规则对于前门准则工具变量特别有用。即使存在不可观察的混杂变量,它们也可以用来建立因果关系。

在第八章中,我们将探索反事实的奇妙世界。以诗人罗伯特·弗罗斯特的名句开篇:

“而且对不起,我不能两者兼得

“而且做一个旅行者,我站了很久……”

珀尔表示,尽管不可能走两条路径或踏入同一条河流两次,我们的大脑可以想象如果我们选择了另一条路径会发生什么。为了明确并传递智慧给机器人,珀尔介绍了必要原因充分原因的区别,以及如何利用结构因果模型系统地进行反事实分析。

随着章节的深入,内容变得更加技术化和信息密集。在接下来的部分,我将首先讨论如何处理未观察到的混杂因素,不幸的是,涉及一些数学内容,针对第二级干预。然后,我将讨论反事实,作为第三级应用。

前门标准

从因果图开始,我们用来理解 X 对 Y 的因果影响:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者基于《为什么的书》第七章制作的图像

在这里,X 通过中介 M 影响 Y。然而,我们无法直接从数据中估计因果关系而不控制混杂因素 U。“U -> X” 的后门路径会产生 X 与 Y 之间的虚假相关性。后门标准告诉我们要控制 U,但如果 U 是不可观测的呢?

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

Tobias Tullius 提供的照片,来源于 Unsplash

例如,在分析烟雾(X)与癌症(Y)之间的因果关系时,我们可能会看到这样的因果图:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者基于《为什么的书》第七章制作的图像

在这里,吸烟通过焦油积累引发癌症,还有一个混杂因素——吸烟基因,正如一些研究人员所言,它可能会影响一个人的吸烟行为和患肺癌的几率。我们无法收集这些基因的数据,因为我们不知道它们是否存在。因此,后门调整在这种情况下无法起作用。

为了获得因果影响,我们可以使用前门标准。这里的前门是中介过程,吸烟增加焦油沉积,进而增加患癌症的几率。如果我们不能直接估计烟雾对癌症的影响,我们能否通过焦油对癌症的影响来估计烟雾对焦油的因果影响?步骤如下:

步骤 1: 吸烟 -> 焦油

吸烟到焦油的唯一后门路径:

吸烟 <- 吸烟基因 -> 癌症 <- 焦油

它被 Collider 癌症阻挡了。因此,我们可以通过计算条件概率直接从数据中估计吸烟对焦油的因果影响:

The causal impact of smoking on tar is:
P(tar|smoking) - P(tar|no smoking)

步骤 2: 焦油 -> 癌症

焦油到癌症的一个后门路径:

焦油 <- 吸烟 <- 吸烟基因 -> 癌症

在这里,吸烟和吸烟基因都是混杂因素,但我们可以控制其中一个来阻断路径。由于我们没有吸烟基因的数据,我们可以控制吸烟:

The causal impact of tar on cancer is:
P(cancer|do(tar)) - P(cancer|do(no tar))

where,

P(cancer|do(tar)) = P(cancer|tar,smoking) * P(smoking) + 
                    P(cancer|tar,no smoking) * P(no smoking)

P(cancer|do(no tar)) = P(cancer|no tar,smoking) * P(smoking) + 
                       P(cancer|no tar,no smoking) * P(no smoking)

为了估计焦油对癌症的因果影响,我们从数据中测量以下四个概率:

  • 吸烟人群中通过积累焦油患癌症的概率:

P(癌症|焦油, 吸烟) * P(吸烟)

  • 在非吸烟人群中,通过积累焦油患癌症的概率:

P(癌症|焦油, 不吸烟) * P(不吸烟)

  • 在吸烟人群中,没有足够焦油的癌症发生概率:

P(cancer|no tar, smoking) * P(smoking)

  • 在非吸烟人群中,没有足够焦油的癌症发生概率:

P(cancer|no tar, no smoking) * P(no smoking)

步骤 3:吸烟 -> 焦油 -> 癌症

一旦我们知道吸烟对焦油的因果影响以及焦油对癌症的因果影响,我们可以通过前门调整推导出吸烟对癌症的无偏因果影响:

The causal impact of smoking on cancer is:
**P(cancer|do(smoking)) = P(cancer|do(tar)) * P(tar|do(smoking)) +
                        P(cancer|do(no tar)) * P(no tar|do(smoking))**

######**Math Alert**########################
Since no backdoor between smoking and tar:

**P(tar|do(smoking)) = P(tar|smoking) 
& 
P(no tar|do(smoking)) = P(no Tar|smoking)** 
And from back door adjustment for tar and cancer:

**P(cancer|do(tar)) = P(cancer|tar,smoking) * P(smoking) + 
                    P(cancer|tar,no smoking) * P(no smoking) 
&
P(cancer|do(no tar)) = P(cancer|no tar,smoking) * P(smoking) + 
                       P(cancer|no tar,no smoking) * P(no smoking)** 
Finally,
**P(cancer|do(smoking)) = (P(cancer|tar,smoking) * P(smoking) + 
                        P(cancer|tar,no smoking) * P(no smoking))
                        * P(tar|smoking) 
                        +
                        (P(cancer|no tar,smoking) * P(smoking) + 
                        P(cancer|no tar,no smoking) * P(no smoking))
                        * P(no tar|smoking)**

或者,更一般地,对于任何类似于之前的因果图的图:

  • X 仅通过中介变量 M 影响 Y,而中介变量 M 也是前门;

  • 存在一个不可观测的混杂因素 U,它与 X 和 Y 相关,但与中介变量 M 无关;

使用前门调整的以下公式有效:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

前门调整

与后门调整相比:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

后门调整

这两个公式都估计 X 对 Y 的因果影响,它们都成功去除了do操作符。这意味着我们可以从数据中估计因果影响,即从第一层数据得出第二层和第三层结论。在前门调整公式中,我们也成功移除了未观测的混杂因素 U。在吸烟对癌症的案例中,我们可以在不包括吸烟基因的情况下估计吸烟的影响。

前门准则为从观察数据中挤出更多信息提供了可能性,即使面对不可测量的混杂因素。然而,由于现实世界通常比教科书中的因果图更复杂,应用过程中仍然存在挑战。例如,未观测的混杂因素也可能影响中介变量:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由作者根据《为何的书》第七章制作

即使存在这种复杂情况,只要 U 与 M 之间的关系较弱,使用“金标准”随机对照试验(RCT)估计 P(Y|do(X)) 作为基准,实证研究表明使用前门调整提供了比后门调整更好的估计,即使没有阻挡所有必要的后门。

Do-calculus 规则

使用上面的例子来说明前门调整,Pearl 还总结了三个规则,提供了移除 do 的一般指导:

  • 规则 1 表示如果变量 W 对 Y 无关,或其通往 Y 的直接路径被变量 Z 阻塞,那么移除或添加 W 将不会改变以下概率:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

Do-calculus 的规则 1

  • 规则 2 表示如果变量集 Z 阻塞了 X 到 Y 的所有后门路径,那么在 Z 的条件下将移除do操作:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

Do-calculus 的规则 2

  • 规则 3 表示如果 X 到 Y 没有因果路径,那么我们可以完全移除do操作:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

do-演算的规则 3

上述三条规则为***do-***演算奠定了基础,使我们能够从观察数据中推导出第 2 和第 3 层次的因果影响。规则 1 展示了从数据中收集哪些变量是有用的;规则 2 展示了如何从观察数据中推导干预,即第 2 层次的结论;规则 3 展示了干预是否有效。

更多数学警告:如果数学让你头痛,可以跳过

使用这三条规则,我们可以重新审视吸烟对癌症的因果图:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者基于《为什么的书》第七章的图像

并查看我们如何在更一般的情况下利用规则 2 和规则 3。步骤如下:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

基于《为什么的书》第七章的 7 个步骤

在这里,为了理解吸烟如何因果影响癌症,我们使用了前门路径吸烟 -> tar -> 癌症,共 7 步:

  • 第 1 步基于概率理论,我们引入中介变量 t(tar),以估计因果影响;

  • 第 2 步基于规则 2。由于 tar ->癌症的后门被 s(吸烟)阻塞,我们可以用(do(s), do(t))替代(do(s), t);

  • 第 3 步再次基于规则 2。由于吸烟 -> tar 没有后门,我们可以用 s 替代 do(s);

  • 第 4 步基于规则 3。由于吸烟仅通过 tar 对癌症产生因果影响,我们可以用 P(c|do(t))替代 P(c|do(s), do(t));

  • 第 5 步基于概率理论,我们在方程中引入了不同的 S 谱。在这种情况下,我们有吸烟与非吸烟的对比;

  • 第 6 步基于规则 2。同样,tar ->癌症的后门被 s 阻塞,所以我们用(t, s’)替代(do(t), s’);

  • 第 7 步基于规则 3。由于 tar 对吸烟没有因果影响,我们用 P(s’|do(t))替代 P(s’)

在最后一个方程中,我们可以看到我们已经完全移除了 do 操作,并且没有不可观测的变量。下一步将使用数据来计算因果影响。

工具变量

处理未观测混杂因素的另一种方法是寻找工具变量。工具变量的定义在因果图中得到更好的说明:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者基于《为什么的书》第七章的图像

假设在估计 X 对 Y 的因果影响时,存在一个不可观测的混杂因素 U,阻碍我们得到正确的估计。如果存在一个变量 Z 满足:

  • Z 和 U 不相关。在图中,Z 和 U 之间没有箭头;

  • Z 是 X 的直接原因;

  • Z 仅通过 X 影响 Y。换句话说,除了 Z->X->Y 之外,没有直接或间接的箭头从 Z 到 Y。

如果所有条件满足,Z 将是一个很好的工具变量。工具变量在许多科学领域中非常有用。在书中,Pearl 提供了一个使用工具变量研究临床试验中药物治疗效果的例子。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由AbsolutVision拍摄,发布在Unsplash上。

这药物是为了降低患者的胆固醇水平而发明的。虽然临床试验通常是随机的,但它们仍然面临非依从性的挑战,即受试者接受了药物但选择不服用。

他们不服药的决定可能取决于多个因素,如他们的病情,并且通常难以观察或测量。非依从性的存在会降低药物效果的估计,我们实际上没有很好的方法来预测临床试验中会有多少非依从性。

在这种情况下,研究人员在 RCT 设计中引入了工具变量“分配”。如果患者被随机分配接受药物,则“分配”的值为 1;如果他们接受的是安慰剂,则值为 0。我们将得到以下因果图:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者根据《为何之书》第七章制作的图片

“分配”是一个工具变量,因为:

  • 药物的分配在患者中是随机的,因此它与任何混杂变量 U 没有相关性;

  • 患者被分配到哪个组将决定他们接受的治疗,无论是药物还是安慰剂。因此,“分配”是“接收”的直接原因;

  • 患者是否被分配到安慰剂组不会直接影响其胆固醇水平。因此,“分配”仅通过“接收”影响胆固醇。

当我们找到或建立一个工具变量时,可以在观察数据中估计三个关系:

  • 分配对接收的因果影响;

  • 分配对胆固醇的因果影响;

  • 通过从“分配 -> 胆固醇”中去除“分配 -> 接收”影响,得到接收对胆固醇的因果影响。

这就是工具变量的定义及其示例。Pearl 的书中包含了更多关于如何使用工具变量来解决未观察混杂因素问题的例子。

反事实:可能发生了什么?

从因果关系阶梯的第 2 层到第 3 层,我们现在面临的问题是找到没有治疗情况下可能的结果。这与第 2 层的干预在两个方面有所不同:

  1. 从平均因果效应到个体因果效应:

到目前为止,我们讨论的因果影响主要集中在总体或子总体。例如,吸烟是否对所有人都导致癌症?然而,更具实际意义的问题,尤其是在解决现实世界问题时,是个体因果效应。例如,如果我开始吸烟,它是否会导致我得癌症?如果我给这位顾客打折,他或她会买更多产品吗?个性化的因果关系可以通过反事实分析来推断。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

照片由 Ilya lix 拍摄,来源于 Unsplash

2. 有两种类型的反事实:

在二元设置中,结果和潜在结果有两个选项。在吸烟与癌症的例子中,结果是得癌症或不得癌症。相应地,会有两种潜在结果和两组因果因素:

  • 必要因果关系: 如果一个人得了癌症,潜在的结果是没有癌症,我们确定吸烟是否是律师所称的“但为因果关系”:如果没有吸烟行为,癌症就不会发展。

  • 充分因果关系: 如果一个人没有得癌症,潜在的结果是癌症,我们确定是否吸烟行为会使这个人得癌症。

区分必要因果关系和充分因果关系不仅有助于机器人在确定因果关系时更像人类思考,还能帮助我们在不同目标和场景下找到更好的行动点——更多内容将在下一节讨论。

匹配与结构因果模型(SCM)

当面对“本可以发生什么”的问题时,我们可能将其视为缺失数据问题。以探讨教育水平如何提高不同人的收入为例。以下是总结在表格中的数据:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

表格由作者根据《为何的书》第八章,第 8.1 表制作

在这里,我们有来自不同员工的各种数据条目,包括他们当前的经验水平(EX)和最高教育水平(ED)。为了简化,我们假设有三个教育水平:0 = 高中学历,1 = 大学学历,2 = 研究生学历。

他们的薪资也被报告为 S0, S1, S2。请注意,由于每个员工在特定时间点只能有一个最高学历,因此所有员工的 S0, S1 和 S2 都会有两个缺失项。

如果我们将填补上述表格中的问号视为缺失数据问题,我们有两种方法:

  1. 匹配:

我们找到类似的员工,并匹配不同教育水平的薪资水平。例如,在表格中,我们只有一个额外的特征——经验,我们看到 Bert 和 Caroline 都有九年的经验,那么我们可以得出 S2(Bert) = S2(Caroline) = $97,000,且 S1(Caroline) = S1(Bert) = $92,500。两个缺失数据已填补!

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

照片由 Kara Eads 拍摄,来源于 Unsplash

2. 线性回归或更复杂的模型

在强假设所有数据来自某些未知的随机来源的前提下,标准统计模型找到最适合数据的模型。在线性回归模型中,我们可能会为这个特定问题找到一个这样的方程:

S = $65,000 + 2,500EX + 5,000ED

从系数来看,方程告诉我们,教育水平每增加一个级别,工资会增加$5,000\。当特征空间增加时,我们可以使用更复杂的模型。

这些方法有什么问题?从根本上说,它们都是数据驱动的方法,而不是模型驱动的方法。我们仍然在用第一层的方法解决第三层的问题。因此,无论我们的模型变得多么复杂以及我们获得了多少特征来预测结果变量,我们仍然面临着遗漏因果机制的根本缺陷。

在这个简单的例子中,我们面临的一个问题是经验和教育彼此之间并非独立。通常,更多的教育可能会减少个人的经验年限。如果 Bert 拥有研究生学位而不是目前的大学学位,他的经验水平将低于九年。因此,他将不再适合拥有九年经验的研究生学位持有者 Caroline。总之,我们将会有一个如下的因果图:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由作者基于《为什么的书》第八章制作

因果图表明教育不仅对工资有直接的因果影响,而且还通过中介变量经验影响工资。因此,我们将需要两个方程:

S = f_s(EX, ED, U_s)
where
EX = f_ex(ED,U_ex)

关于这些方程有几个要点,它们构成了这个问题的结构因果模型(SCM):

  • 工资是经验、教育和一些影响工资的不可观察变量 U_s 的函数。注意不可观察变量是外生的,意味着它们与教育和经验没有相关性;

  • 经验是教育和一些影响经验的不可观察变量 U_ex 的函数。

  • 由于没有函数显示教育作为经验的函数,这意味着没有经验对教育的因果影响。这是我们所做的假设。

  • 这两个方程假设了右侧(结果)和左侧(处理)之间的因果关系。

  • 无法观察到的变量 U_s 和 U_ex 量化了个体水平的 uncertainties,这与使用因果链接的概率来量化不确定性的贝叶斯网络不同。它们与经验或教育无关,可以由个人自定义。

  • 函数 f 表示特征与结果变量之间的关系。它可以是线性的,也可以是非线性的,具体取决于假设。

如果我们仍然假设线性关系,我们将基于数据和我们对教育如何影响经验的理解得到以下方程:

S = $65,000 + 2,500*EX + 5,000*ED + U_s
and
EX= 104*ED + U_ex

通过这些方程,我们可以计算每个人的个性化因素 U_s 和 U_ex 来预测反事实。以 Alice 为例。我们知道 S(Alice)是 81000,EX(Alice)是六,ED(Alice)是 0。首先,将这些代入第二个方程得到 U_ex。我们想代入第二个方程,因为它只包含 ED,这是我们关注的唯一因果因素:

6 = 104*0 + U_ex(Alice)
-> U_ex(Alice) = -4
Thus,
EX(Alice) = 104*ED(Alice) - 4

在这里,我们不仅仅是将 EX(Alice)代入方程,而是间接使用这个值。知道 EX(Alice)等于六有助于我们计算 U_ex(Alice)。然后我们将 S(Alice)、ED(Alice)和 U_ex(Alice)代入第一个方程来得到 U_s(Alice):

 81000 = 65000 + 2500*(104*ED + U_ex) + 5000*ED +  U_s(Alice) 
-> U_s(Alice) = -5000*ED - 2500*U_ex(Alice) -9000
-> U_s(Alice) = 0 - 2500*(-4) - 9000 = 1000

#Note **DO NOT** calculate by plug EX=6 directly:
81000 = 65000 + 2500*6 + 5000*0 +  U_s(Alice)

这里是关于 Alice 的函数:

S(Alice) = $65,000 + 2,500*EX + 5,000*ED + U_s(Alice)
and
EX= 104*ED + U_ex(Alice)

一旦 SCMs 准备好,我们就可以对 Alice 进行反事实分析。我们可以计算她如果上大学会得到什么样的工资。如果 ED(Alice)是 1 而不是 0,我们将首先计算:

EX(Alice) = 104+(-4) = 2

然后计算 S_1(Alice):

S_1(Alice) = $65,000 + 2,500*2 + 5,000*1 + 1000 = $76,000

注意我们从回归模型中得到的不同结果,其中我们代入 ED(Alice) = 1 和 EX(Alice)=6:

S = 65000 + 2500*6 + 5000*1 = $85,000 -> Biased estimation 

# If Alice has six years of experience and a high school degree now,
# She couldn't get six years of experience if she goes to college.

这是一个简单的利用 SCMs 理解因果影响和计算个体水平反事实的例子。总之,Pearl 称之为“因果推断的第一定律”:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

方程显示,潜在结果 Y_x(u)可以通过模型 M_x 进行推算,只要我们能去除所有到 X 的反向路径。这里,模型比线性回归模型有更大的灵活性,只要因果关系是基于因果图的。

必要原因 (PN) 与 充分原因 (PS)

为了理解反事实,我们有两个不同的测量:必要性概率 (PN) 和 充分性概率 (PS)。为了看到区别,我们用一个例子:房子着火是因为有人点燃了火柴,空气中有氧气:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

火柴和氧气都是房子着火的因果因素。然而,它们在 PN 和 PS 上有所不同。必要性的概率是:

PN = P(Y_x=0 = 0|x=1, Y=1) 

在这种情况下,房子着火了,火柴被点燃了 (Y=1, x=1)。这个概率问的是如果火柴没有点燃,房子是否不会着火 (Y_(x=0) = 0)。这个概率非常高,因为即使我们有足够的氧气,火柴引发的火也不会使房子着火。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

Erick ZajacUnsplash上的照片

当我们计算 x=氧气的 PN 时,同样的逻辑适用。如果房子里没有足够的氧气,即使我们点燃了火柴,火也不会发生。

必要性的概率显示了如果治疗没有发生,结果变量会发生什么。如果概率很高,这意味着该治疗是一个必要原因。在法庭上,证明如果被告没有采取某些行动受害者不会死就足以定罪。

另一方面,充分性概率是:

PS = P(Y_x=1 = 1|x=0, Y=0)

在这种情况下,房子没有着火,火柴也没有点燃(Y=0, x=0)。这个概率问的是如果火柴点燃,房子是否会着火(Y_(x=0) = 0)。这个概率也很高,因为通常,氧气无处不在,房子在有氧气和点燃火柴的情况下,很可能会着火。

然而,氧气的 PS 非常低。仅仅因为房子有氧气,火灾不太可能发生。氧气不足以引发房屋火灾。我们需要点燃火柴等其他火源的行动。

因此,充分性概率告诉我们如果治疗发生了,结果变量会发生什么。如果概率很高,这意味着这种治疗是一个充分原因。在这种情况下,点燃火柴符合引发房屋火灾的充分原因,但氧气不符合。

为什么要区分 PS 和 PN?

为什么要做这些区分? 简而言之,尽管多个变量可能是结果的因果因素,人脑会根据一些条件自动“排序”这些因素。心理学家发现,人们会想象如果做了某些行动可能会改变不理想的结果。他们更倾向于:

  • 想象一下,撤销一个稀有事件比撤销一个常见事件更难。例如,点燃火柴是一个比家中有氧气更稀有的事件。(希望如此!)

  • 人们更倾向于责怪自己的行为,即点燃火柴,而不是那些不受他们控制的事件。

正如 Pearl 所强调的,将 PN 和 PS 都嵌入模型中将建议一种系统化的方式来教机器人生成对观察到的结果有意义的解释。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由Brett Jordan拍摄,来源于Unsplash

此外,理解 PS 和 PN 之间的区别可以指导我们采取行动。研究和找出极端热浪原因的气候科学家可能会有两种不同的陈述:

  • PN: 有 90%的概率认为人为气候变化是热浪的必要原因;

  • PS: 气候变化至少有 80%的概率会在每 50 年内引发一次如此强烈的热浪。

PN 陈述告诉我们归因:谁对热浪负责——人为气候变化。它找出原因。PS 陈述告诉我们政策。它说明我们对热浪的准备更充分,因为更多的事件正在发生。谁导致了这些事件?没有具体说明。这两种陈述都很有信息量,只是方式不同。

真是篇长文! 这篇文章完成了“与我一起阅读”系列中关于朱迪亚·珀尔《为何之书》的第五篇文章。感谢你坚持读到最后。第七章和第八章的信息量确实很大。希望这篇文章对你有帮助。如果你还没有读过前四篇文章,可以在这里查看:

如果感兴趣,订阅我的邮件列表以加入当前进行的双周讨论。我们将发布一篇关于中介变量的文章,来结束这个系列:

## 什么使得人工智能强大?

“为何之书”第九章和第十章,“与我一起阅读”系列

towardsdatascience.com

还有一篇额外文章:

## 因果推断在学术界和工业界有何不同?

“为何之书”系列的额外文章

towardsdatascience.com

一如既往,我强烈鼓励你阅读、思考,并在这里或在你自己的博客上分享你的主要收获。

感谢阅读。如果你喜欢这篇文章,请不要忘记:

你暂时不需要数据领域……

原文:towardsdatascience.com/you-dont-need-data-domains-yet-23af8ffc3e69

解密数据网格领域

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Louise de Leyritz

·发表于Towards Data Science ·阅读时间 7 分钟·2023 年 5 月 23 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

你暂时不需要数据领域——图片由Castor提供

数据网格范式提倡使用数据领域作为将数据划分为有意义组的方式。领域最初引入是为了帮助创建明确的所有权结构并实现更好的数据发现。

尽管这种方法可能有效,但也可能会令人困惑。我已经研究了数据网格领域一年了,但仍然无法完全理解这个概念。因此,我一直在想:我们是否可以在不涉及数据领域的情况下实现更好的所有权和发现?

在一定程度上,我们可以做到。

在这篇文章中,我旨在阐明数据领域的目的,并探索实现这些目标的更简单的替代方法。

从根本上说,数据领域只是组织中数据分组的一种方式。我们为什么要对数据进行划分?有三个原因:

  • 使人们更容易找到他们需要的数据。

  • 使得在出现问题时能够明确分配所有权和责任。

  • 为了提供必要的背景,以便更好地理解数据。

好消息是,大多数公司无需考虑数据领域的概念来处理这三个要素。所以,停止阅读关于数据领域的所有文章,因为你可能不需要它们。或者至少,目前不需要。

为了应对上述要素,公司只需找到良好的数据分组方法,以便于分配所有权和数据发现。幸运的是,有两种你所在组织已经非常熟悉的数据分组方法:团队来源

这些分类方法简单但非常有效,在所有权、发现和理解方面可以实现 90%的工作。通过合理使用这些分类,公司可以在不引入相当复杂的数据领域概念的情况下改善数据管理。

在这篇文章中,我们将探讨如何根据团队来源对数据进行分组,以解决所有权和发现挑战。我们还将探讨在组织复杂度增长时何时需要引入数据域的概念。

从团队开始

团队进行数据分区将帮助你实现强大的数据发现,并为这些数据的上下文分配所有权。让我们仔细看看基于团队的分区如何帮助实现这两件事。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

基于团队的数据分区 — 图片来源于Castor

所有权

按照团队对数据进行分区在归属数据所有权和确保职责明确方面非常有意义。

在这里,所有权指的是上下文所有权,而非技术所有权。这意味着,当我们说市场团队拥有市场数据时,我们的意思是他们负责维护这些资产的文档和上下文。

既然我们已经建立了这一点,以下是将数据按团队分组的三个有力理由:

👉 轻松分配所有权:团队在组织中是非常明确的实体。这样可以轻松地将数据分配给一个团队。团队不会模糊不清,也不会重叠,这使得分配所有权变得非常简单。

👉 更好的文档:当团队拥有自己的数据时,数据周围的文档突然变得更有意义。团队对他们负责的过程、目标和指标有深刻的理解。这使他们成为为数据提供上下文的最佳人选。

👉 职责明确:按团队对数据进行分组,可以明确谁负责为数据资产提供上下文。当一个利益相关者对市场数据有疑问时,他们只需两秒钟就会意识到需要联系市场团队,而不是联系一些模糊领域的所有者。

发现

团队分区数据还可以改进数据发现过程,帮助人们更快地找到所需的数据。

每个人都熟悉团队。因此,基于团队的数据分区使数据发现变得更加熟悉。这降低了技术能力较差的人的门槛,他们可能会被数据吓到。

基于团队进行数据分区也满足了两类寻求数据上下文的人群:信息寻求者,他们在寻找特定的信息,以及探索者,他们希望浏览数据的全貌。

例如,如果一个信息寻求者在寻找市场数据,他们会知道市场团队负责这些数据,他们只需通过“团队 → 市场”来筛选数据即可找到。若他们在理解数据时遇到困难,可以直接向市场团队咨询。

类似地,如果一个探索者希望更好地理解数据策略,他们可能会想要探索市场团队处理的数据类型。通过浏览市场团队的仪表板和知识图谱,他们将更清楚地了解业务的运作方式。

继续处理来源

虽然基于团队的数据分区对于分配所有权很有帮助,但可能不足以满足发现目的。因此,另一种数据分区方法是基于来源,这是商业人员熟悉的另一个元素。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

基于来源的数据分区 — 图片来源于Castor

每个人都对数据来源感到熟悉,因为利益相关者使用它们来执行日常任务。例如,销售人员可能不是数据专家,但他们对 Salesforce 了如指掌。因此,他们在探索数据仓库时,应该能够按来源进行筛选。

基于来源的数据分区也满足了寻找文档的两类人群:信息寻求者探索者

例如,当一个信息寻求者在搜索特定信息时,他们会对要寻找的数据来源有清晰的认识。他们可以按来源(例如 Salesforce)过滤搜索,轻松找到所需的信息。这提升了发现体验,并减少了数据探索所需的时间。

同样,对于一个探索者,如果他们希望获得对数据全貌的更广泛理解,通过来源角度探索数据可能会很有帮助。例如,他们可以确定哪些来源具有最相关或最可信的数据,或识别冗余或不完整的数据。

这种额外的分区层次,与基于团队的分区结合,可以提高组织中数据发现的效率。

什么时候需要引入数据领域?

一旦你将数据分为这两类,你将有效解决所有所有权问题,以及大部分发现和理解相关的挑战。

然而,一些组织在更复杂的环境中运营,可能需要引入额外的复杂层次,以准确捕捉其业务的细节。

对于这些组织来说,基于团队来源的分区可能不足以确保可发现性和所有权。

比如你在 Airbnb 工作。公司已将数据所有权划分到团队,如营销、工程和客户支持。他们还根据数据的来源对数据进行了分组,区分了网站数据和移动应用数据。

但如果你需要了解平台上体验类项目的定价,例如旅游或烹饪课程,该怎么办?这些信息并不完全适合现有的团队或数据源。这不是营销团队拥有的内容,也不是仅从网站或应用程序中获得的内容。那么,你该怎么办呢?

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

有时候,数据无法完美地归入两个类别中的任何一个 — 图片由Castor提供。

这时,领域(domains)的概念就显得很重要。在这种情况下,Airbnb 需要创建一个新的领域,称为“体验”(Experiences),以捕捉这些数据并确保其得到适当的管理。通过引入这种新的分类层级,他们可以更好地组织数据,并确保重要信息不会被遗漏。

关键要点是,随着公司成长和变得更加复杂,它们可能需要引入新的分类层级来管理数据。

然而,这种新的数据划分方式只有在你拥有一些跨多个团队或来源的信息,并且这些信息无法完美地归入任何一个类别时,才应引入。

结论

去中心化数据可能是一个困难的过程,但你应该始终保持简洁。虽然数据领域可以是处理复杂业务需求的有力工具,但你不必在数据旅程的早期就引入它们。

在大多数情况下,团队来源足以进行数据划分,确保所有权和可发现性。因此,你应从这些更为熟悉的概念开始,并根据需要逐步增加复杂性。

目标不是引入不必要的复杂性,而是提供一个帮助团队更有效地处理数据的框架。

最初发布于 https://www.castordoc.com.

你的数据科学可视化将不再相同——Plotly 和 Dash

原文:towardsdatascience.com/your-data-science-visualizations-will-never-be-the-same-plotly-dash-6327d07d9efb

数据可视化

使用 Plotly 和 Dash 创建交互式仪表板

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Pol Marin

·发布于 Towards Data Science ·14 min read·2023 年 10 月 24 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

Isaac SmithUnsplash 提供的照片

不久前,我写了一篇简单的介绍,介绍了四个 Python 数据可视化库,展示了它们的优缺点,并通过实际示例展示了它们的能力。

当我们深入讨论我最喜欢的那些工具时,我强烈建议你先查看那篇文章,因为这篇文章会扩展其中展示的内容:

## 使用 Python 构建交互式数据可视化——讲故事的艺术

Seaborn、Bokeh、Plotly 和 Dash 用于有效传达数据洞察

towardsdatascience.com

今天我们将重点讨论Plotly[1]和Dash[2]。为什么是两个?因为它们是相辅相成的。正如我在上面链接的文章中所述,“Dash 本身不是一个绘图库。它是一个用于生成仪表板的出色框架。”

所以 Plotly 是我们用来绘图的库,而 Dash 是我们用来从这些图中生成酷炫交互式仪表板的框架。

这是我们今天构建仪表板的步骤:

  • 设置和安装——让我们进入正确的状态。

  • 一些简单的用例——展示 Plotly 如何工作

  • 使用 Dash 构建仪表板——创建最佳仪表板。

  • 结论——总结故事并查看结果。

在深入之前,我们需要讨论数据。我们需要某种数据才能进行可视化,对吧?跟上我最新的 Medium 内容,我将专注于体育,更具体地说,是足球(soccer)。

我将使用 Statsbomb 提供的 2015–16 赛季 LaLiga 的免费数据[3]。

那个赛季有很多数据,但我想要可视化巴萨球员的表现,主要集中在进攻方面:射门、进球、助攻……

目的可能会根据分析师的位置有所不同:你是皇马的分析师吗?那我敢肯定你会想要解码你的球队如何阻止梅西(剧透:你做不到)。

但如果你在巴萨组织工作,你可能想要检查一下你球员的数据,看看哪些球员的表现比其他球员更好。

无论是什么情况,始终确保在创建任何仪表板之前定义你的目标——你可以可视化的信息太多了,你必须有目的地选择你想查看的图表。

始终追求简单;非技术人员需要从你的仪表板中得出结论。

设置和安装

我喜欢保持事物的有序和结构化。所以我们要做的第一件事是创建一个新的目录在你希望托管应用程序的路径中。为了简单起见,我将在桌面上创建它。这是我在终端上运行的两个命令:

$ cd ~/Desktop
$ mkdir plotly-dash

现在,下一步自然是创建一个新的 Python 环境在新的目录中。我将使用pipenv [4],但你可以使用你喜欢的虚拟环境管理工具(或者不使用)。

如果你的机器上还没有安装 pipenv,那么首先运行这个命令:

$ pip install --user pipenv

然后,创建环境:

$ cd plotly-dash
$ pipenv shell

这将创建一个新的环境并自动激活它。你现在从那个终端安装的任何东西都会只安装在该环境中。

所以让我们开始使用 pip 安装库:

(plotly-dash) $ pip install dash pandas statsbombpy

是的,安装这三个库我们会有足够的功能。它们都有自己的依赖关系,我们将利用一些像PlotlyNumPy的库。

一切准备好后,我们现在可以开始探索 Plotly。

使用 Plotly 进行数据可视化

我在这里的建议是从 jupyter notebook 进行测试,因为这将使你的开发阶段更加流畅。在这种情况下,你也应该安装它——我保证这是我们运行的最后一个安装——我们也将打开它:

(plotly-dash) $ pip install notebook
... (installation outputs)

(plotly-dash) $ jupyter notebook

像往常一样,我们需要准备数据,我们将创建一个新的笔记本叫做plotly.ipynb。为了避免极大的笔记本和文件,我喜欢将我的代码模块化。因此,我在项目文件夹中创建了一个src目录,并在其中添加了两个新文件:functions.pyclasses.py。现在的结构如下:

- plotly-dash/
    - src/
        - classes.py
        - functions.py
    - plotly.ipynb

我将创建的第一个函数叫做prepare_team_data(),它将返回指定球队(在我们这个例子中是巴萨)的事件、射门和助攻数据。

由于函数本身对今天的目的并不重要,因为我们想专注于绘制和创建仪表盘,我不会放函数的代码。但你可以在资源部分[5]找到整个代码的链接。

# Third-party libraries
import pandas as pd
from statsbombpy import sb

# Custom modules
from src.functions import prepare_team_data

events, shots, assists = prepare_team_data('Barcelona')
shots.head()

这是投篮 DF 的一个快照。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

投篮 DF 屏幕截图——图像由作者提供

好的,让我们从投篮开始吧。我想绘制一个球员的投篮分布图,以查看他从哪里射门以及他的进球来自哪里。为此,我在classes.py模块中创建了一个FootballPitch类。

这个类允许我们绘制一个完整的足球场、它的一半(即进攻半场),或者甚至像我们将要做的那样绘制一个热图。

再次提醒,你可以在本文底部的资源部分找到 GitHub 链接[5]中的代码。但我们会稍微检查一下,因为这是我们使用了很多 Plotly 功能的地方。

这个类基本上有两个方法:plot_pitch()plot_heatmap。由于我们首先关注的是展示球员的投篮,所以我们从第一个方法开始,将其分解为小的代码块。

请注意,你会看到一些变量和类属性我们没有赋值。这些是函数参数或在创建对象时初始化的。

首先:让我们声明函数将使用的基本变量。

# Fig to update
fig = go.Figure()

# Internal variables
self.height_px = self.pitch_width*10*zoom_ratio
self.width_px = self.pitch_length*10*zoom_ratio

pitch_length_half = self.pitch_length/2 if not self.half else 0
pitch_width_half = self.pitch_width/2
corner_arc_radius = 1

centre_circle_radius = 9.15

goal = 7.32
goal_area_width = goal + (5.5*2)
goal_area_length = 5.5
penalty_area_width = goal_area_width + (11*2)
penalty_area_length = goal_area_length + 11
penalty_spot_dist = 11
penalty_circle_radius = 9.15

现在我们已经声明了图形,我们将一遍又一遍地添加轨迹或形状以根据我们的需要进行自定义。因此,例如,函数的第一步是绘制一个矩形形状,即足球场本身:

fig.add_trace(
    go.Scatter(
        x=[0, self.pitch_length, self.pitch_length, 0, 0], 
        y=[0, 0, self.pitch_width, self.pitch_width, 0], 
        mode='lines',
        hoverinfo='skip',
        marker_color=line_color,
        showlegend=False,
        fill="toself",
        fillcolor=bg_color
    )
)

在这里,我们添加了一个散点图轨迹,模式为lines——这意味着我们想要的是一条线,而不是一个有独立点的真正的散点图。参数是相当自解释的,例如 x 和 y(我们想绘制的数据),颜色……hoverinfo 标签用于确定当我们将鼠标悬停在这些线条上时想显示的内容。由于我们将足球场作为背景的一部分构建,并且它不会告诉我们任何关于我们想要分析的数据的信息,所以我将其设置为跳过。

然后我们在图的布局中设置了一些额外的配置:

fig.update_layout(
    yaxis_range=[-self._vertical_margin, self.pitch_width + self._vertical_margin], 
    xaxis_range=[-self._horizontal_margin, self.pitch_length + self._horizontal_margin],
    height=self.height_px,
    width=self.width_px,
    plot_bgcolor='rgba(0,0,0,0)',
    xaxis=dict(showgrid=False, visible=False),
    yaxis=dict(showgrid=False, visible=False)
)

这给我们带来了以下结果:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

足球场(仅草地)——图像由作者提供

现在我们已经绘制好了我们的足球场。还不是特别有意义……不过。

在 Plotly 中绘图确实如此简单!通过在图中添加更多的轨迹和形状,我的足球场背景最终看起来是这样的:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

足球场——图像由作者提供

现在,你可能对显示一个足球场不感兴趣。这就是为什么我没有把所有代码都放在这里……但出色的仪表盘是创造力和技能的结果,而绘制一个足球场是展示发生在场上的足球事件(如果我们对位置感兴趣)的一个很好的方式。

所以,让我们开始展示真实数据吧!

由于我们要显示射门——和进球——散点图似乎是一个合适的选择。记住,我们已经准备好了数据,我们只需要过滤和显示它。

让我们绘制梅西的射门和进球:

import plotly.graph_objects as go
from src.classes import FootballPitch

player = 'Leo Messi'

pitch = FootballPitch(half=True)
fig = pitch.plot_pitch(False, bg_color='#C1E1C1') 

player_shots = get_player_shots(player, shots.copy(), pitch)
scatter_colors = ["#E7E657", "#57C8E7"]

for i, group in enumerate([True, False]):
    fig.add_trace(go.Scatter(
        x=player_shots[player_shots['goal'] == group]['x'],
        y=player_shots[player_shots['goal'] == group]['y'],
        mode="markers",
        name='Goal' if group else 'No Goal',
        marker=dict(
            color=scatter_colors[i],
            size=8,
            line=dict(
                color='black',
                width=1
            )
        ),
    ))

fig.update_layout(
    title='Shot distribution'
)

第一部分不言而喻:我们只是声明变量,实例化球场,将图形存储在fig变量中,并运行一个函数来过滤shots数据框,以返回仅玩家的射门数据。

然后,在一个 2 次迭代的循环中,我们添加了两次散点图:一次用于未进球的射门(将显示为蓝色),另一次用于进球的射门。结果:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

梅西在 2015/16 赛季的射门和进球分布——图片由作者提供

使 plotly 令人惊叹的是,这个图形是完全互动的。我们可以通过悬停鼠标来查看真实的射门位置,隐藏非进球的射门以仅检查进球射门……

现在让我们继续,构建一个折线图。它当然会是互动的,我们将用它来按季度检查球员的射门情况,并将其与队友和球队的平均水平进行比较。

为此,我们将开始按季度(每 15 分钟一段)将射门分组。下一部分将是绘制这些值本身,并调整线条的不透明度,以突出当前的球员(梅西)。

player = 'Leo Messi'
max_shots = 0
fig = make_subplots()

for p in shots.player.unique():
    player_shots = get_player_shots(p, shots)

    xy = 15 * (player_shots[['float_time', 'minutes']]/15).round()
    xy = xy.groupby(['float_time']).count()[['minutes']]

    max_shots = xy.minutes.max() if xy.minutes.max() > max_shots else max_shots

    fig.add_trace(
        go.Scatter(
            name=p,
            x = xy.index, 
            y = xy.minutes,
            mode='lines',
            opacity=1 if p == player else 0.2
        )
    )

现在我们已经准备好了所有球员,我们将添加球队的平均值作为虚线。代码的功能与上面的代码片段完全相同,但使用的是团队级的数据。

# Add team's avg
xy = 15 * (shots[['float_time', 'minutes']]/15).round()
xy = xy.groupby(['float_time']).count()[['minutes']]/len(shots.player.unique())

fig.add_trace(
    go.Scatter(
        name="Team's Average",
        x = xy.index, 
        y = xy.minutes,
        line = go.scatter.Line(dash='dash'),
        marker=None,
        mode='lines'
    )
)

然后,我们将为布局添加一些样式:

fig.update_xaxes(range=[0, 91])
fig.update_layout(
    #title='Shots by Quarter',
    margin=dict(l=20, r=20, t=5, b=20),
    xaxis = dict(
        tickmode = 'array',
        tickvals = xy.index.values
    ),
    height=200,
    plot_bgcolor="#F9F9F9", 
    paper_bgcolor="#F9F9F9",
    yaxis_range=[-3,max_shots+5]
)

结果:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

绿色的高亮线是梅西的数据(当我悬停在第 60 分钟的射门计数上时,标签显示)。由于某些原因,也许是疲劳,梅西的射门在 60 到 75 分钟期间减少,但在比赛的最后几分钟却增加了。

我们看到球队的大多数射门在最后 15 分钟减少,但梅西却走向相反的方向。这显示了他对球队的影响和他的胜利欲望。

总之,介绍部分到此为止。我们已经成功绘制了两个不同的图形,并且为我们的图形创建了一个出色的背景。我认为我们已经覆盖了 Plotly 的基本知识。

创建仪表板

仪表板只是以有序和吸引人的方式显示的图形组合。我们已经创建了图形——我们在上一部分做的——所以我们只需将它们显示出来。

现在,这并不那么简单。我们需要对上面共享的代码片段进行一些更改,但我保证这些更改会很小。

为了完成仪表板,我将添加一些更多的图形和功能,使其完全互动。

安装了Dash之后,我将创建一个名为app.py的新文件:

- plotly-dash/
    - src/
        - classes.py
        - functions.py
    - plotly.ipynb
    - app.py

文件的模板将开始是这样的简单:

from dash import html, Dash, dcc, Input, Output, callback

app = Dash(__name__) 

if __name__ == '__main__':
    app.run(debug=True)

如果你继续执行文件(python app.py),你会在终端中看到类似以下的消息:

(plotly-dash) $ python app.py
Dash is running on http://127.0.0.1:8050/

 * Serving Flask app 'app'
 * Debug mode: on

请访问127.0.0.1:8050/。你会看到一个空白页面,但这实际上是你的仪表板。

让我们开始添加内容吧?看看下一个代码。

@callback(
    Output('shot_distribution', 'figure'),
    Input('player_dropdown', 'value')
)
def create_shot_distribution(player):
    pitch = FootballPitch(half=True)
    fig = pitch.plot_pitch(False, bg_color='#C1E1C1', zoom_ratio=0.8) 

    player_shots = get_player_shots(player, SHOTS.copy(), pitch)

    scatter_colors = ["#E7E657", "#57C8E7"]

    for i, group in enumerate([True, False]):
        fig.add_trace(go.Scatter(
            x=player_shots[player_shots['goal'] == group]['x'],
            y=player_shots[player_shots['goal'] == group]['y'],
            mode="markers",
            name='Goal' if group else 'No Goal',
            marker=dict(
                color=scatter_colors[i],
                size=8,
                line=dict(
                    color='black',
                    width=1
                )
            ),
            #marker_color=scatter_colors[i] # #E7E657 i #57C8E7  
        ))

    fig.update_layout(
        margin=dict(l=20, r=20, t=5, b=20),
    )

    return fig

到现在,这应该听起来很熟悉。这正是我们用来显示梅西射门的代码… 但现在,代替定义为梅西的玩家,它是函数参数。

那这个参数来自哪里?就在函数声明的上方,我们有回调装饰器。这些回调使 Dash 的仪表板具有交互性。

我们用它们来确定关联应用组件的输入和输出。在这种情况下,我们说明函数需要player参数,这个参数将来自名为player_dropdown的元素(我们还未定义)。

关于输出,我们让函数返回fig。多亏了回调装饰器,应用知道这将是我们仪表板中的shot_distribution元素所用的图形。

你现在可能有很多问题。怎么定义一个下拉框或任何可交互组件?我如何真正绘制shot_distribution元素?

从第一个问题开始:下拉框。Dash 有自己的核心组件(dcc),下拉框就是其中之一。创建它非常简单:

dcc.Dropdown(
    PLAYER_OPTIONS,
    'All players', 
    id='player_dropdown', 
    style={'width': '200px', 'margin': '20px auto', 'text-align': 'left'}
)

这将创建一个下拉框,使用所有玩家名称作为可能的选项,以All players作为默认值。但最重要的是id。在这里我们可以告诉 Dash 这个下拉框是与前一个函数的输入回调关联的。

换句话说,这个下拉框的值将是射击分布图上显示的玩家。

但我们仍然需要将这两个组件放入我们的仪表板中。页面依然是空白的。

你现在需要一些 HTML 知识,但基本知识就足够了(尽管它可以复杂到你想要的程度)。

我们需要将这些组件放在 HTML 代码中。Dash 再次使这一过程非常简单。对于下拉框,可以通过简单地用html.Div组件包裹代码,即将下拉框放在<div></div> HTML 元素中来实现:

filter = html.Div(
    [
        dcc.Dropdown(
            PLAYER_OPTIONS,
            'All players', 
            id='player_dropdown', 
            style={
                'width': '200px', 
                'margin': '20px auto', 
                'text-align': 'left'
            }
        )
    ],
    style={'display': 'inline-block'}
)

这个工作的方式是html.Div可以有许多子元素(因此是列表),然后我们可以使用 style 属性设置元素的 CSS 样式,它是一个字典。简单吧?

对于射击分布图,等效的代码如下:

shot_distribution_graph = html.Div(
    [
        html.H2('Shot Distribution'),
        dcc.Graph(id='shot_distribution', figure={})
    ], 
    style={
        'padding': '2%',
        'display': 'inline-block'
    }
)

结构相同,但为了显示图表,我们使用dcc.Graph组件,并且你可能猜到了,id 属性在这里也是关键。它将这个特定组件与我们声明的函数的输出回调关联起来。因此,那里计算的内容将在这里显示。

我们现在已经用 HTML 代码包裹了这些组件。但它们还没有显示出来。我们需要将它们添加到仪表板的布局中:

app.layout = html.Div([
    shot_distribution_graph, filter
], style={
    'width': '1650px', 
    'margin': 'auto'
})

这里没有秘密;结构相同,但层级更高。我们将之前的<div></div>元素放入一个大的容器中(整个网站容器),并提供一些额外的样式。现在,如果你刷新网站或重新启动应用,你会看到你的初步结果:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

投篮分布图 — 图片由作者提供

已经构建出的成果令人惊叹,对吧?这种互动性非常强大。

为了完成这一部分,让我们对另一个我们构建的图表做同样的处理。这一次,我会在这里粘贴整个代码,以便你可以一次性查看所有内容:

# Functions
@callback(
    Output('shots_by_quarter', 'figure'),
    Input('player_dropdown', 'value')
)
def create_shots_by_quarter(player):
    fig = make_subplots()

    max_shots = 0

    for p in SHOTS.player.unique():
        player_shots = get_player_shots(p, SHOTS)

        xy = 15 * (player_shots[['float_time', 'minutes']]/15).round()
        xy = xy.groupby(['float_time']).count()[['minutes']]

        max_shots = xy.minutes.max() if xy.minutes.max() > max_shots else max_shots

        fig.add_trace(
            go.Scatter(
                name=p,
                x = xy.index, 
                y = xy.minutes,
                mode='lines',
                opacity=1 if p == player else 0.2
            )
        )

    # Add team's avg
    xy = 15 * (SHOTS[['float_time', 'minutes']]/15).round()
    xy = xy.groupby(['float_time']).count()[['minutes']]/len(SHOTS.player.unique())

    fig.add_trace(
        go.Scatter(
            name="Team's Average",
            x = xy.index, 
            y = xy.minutes,
            line = go.scatter.Line(dash='dash'),
            marker=None,
            mode='lines'
        )
    )

    fig.update_xaxes(range=[0, 91])
    fig.update_layout(
        margin=dict(l=20, r=20, t=5, b=20),
        xaxis = dict(
            tickmode = 'array',
            tickvals = xy.index.values
        ),
        height=200,
        plot_bgcolor="#F9F9F9", 
        paper_bgcolor="#F9F9F9",
        yaxis_range=[-3,max_shots+5]
    )

    return fig

# Dashboard's layout components
shots_by_quarter = html.Div(
    [
        html.H2('Shots By Quarter', style={'margin-top': '20px'}),
        dcc.Graph(id='shots_by_quarter', figure={})
    ],
    style={
        'padding': '2%'
    }
)

# Create layout
app = Dash(__name__)
app.layout = html.Div([
    shot_distribution_graph, filter, shots_by_quarter
], style={'width': '1650px', 'margin': 'auto'})

# Run app
if __name__ == '__main__':
    app.run(debug=True)

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

结果仪表盘包含两个图表 — 图片由作者提供

现在,这个是功能性的。但它真的不太吸引人……HTML 和 CSS 将是使其视觉上更具吸引力的工具(尽管我在设计方面不太擅长)。

然而,这超出了我们的范围。我们的目标是创建一个仪表盘,我们已经实现了。这虽然非常简单,但如果你能够理解我们做的所有内容,最终的仪表盘是如何完成的,我在开始时分享了,也将在下一节再次分享,你就不会对它感到陌生(同样,底部的代码是自由获取的)。

总结

今天我们构建了一个包含两个图表和一个下拉框的仪表盘。但我们可以根据需要进行扩展。例如,了解如何放置一个下拉框后,我们就知道如何放置一个滑块。那两个呢?

我们今天学到的所有内容都可以应用于任何你想要可视化的数据, 从经济报告到医疗结果或广告活动洞察。我选择将其应用于足球,因为我对足球充满热情,但请将这些知识普及到其他领域。

了解如何放置两个图表后,我们可以创建更多的图表。而且是不同类型的:一个展示助攻,另一个展示球员在场上的影响,还有比较他的进球和预期进球……再加上一点 HTML 和 CSS,我们就能得到最终的仪表盘:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

最终仪表盘 — 图片由作者提供

我真的希望你能看到这个工具的强大之处。

Dash 和 Plotly 必须是任何数据分析师的技能库中的必备工具。它们是非常棒的库,我们可以用来以高度定制的方式分享我们的数据和见解——即根据你的需求进行调整——并且容易理解。

**Thanks for reading the post!** 

I really hope you enjoyed it and found it insightful.

Follow me and subscribe to my mail list for more 
content like this one, it helps a lot!

**@polmarin**

资源

[1] Plotly: 低代码数据应用开发

[2] Dash 文档和用户指南 — Plotly

[3] 免费数据 | StatsBomb

[4] Pipenv: 人性化的 Python 开发工作流程

[5] Plotly & Dash 项目代码 — GitHub

你的数据(终于)在云端了。现在,别再那么依赖本地了

原文:towardsdatascience.com/your-datas-finally-in-the-cloud-now-stop-acting-so-on-prem-bbb7b4f35529?source=collection_archive---------2-----------------------#2023-08-16

现代数据栈允许你以不同的方式操作,而不仅仅是在更大的规模上。充分利用它。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Barr Moses

·

关注 发表在 Towards Data Science ·10 分钟阅读·2023 年 8 月 16 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由 Massimo Botturi 提供,来自 Unsplash

想象一下,你在大部分职业生涯中一直用锤子和钉子建房子,而我给了你一把钉枪。但你却不是将它对准木头按下扳机,而是将其侧向放置,就像用锤子一样敲钉子。

你可能会认为这既昂贵又效果不佳,而网站的检查员则会正确地将其视为安全隐患。

好吧,那是因为你正在使用现代工具,但却带着遗留的思维和流程。虽然这个类比并不完美地概括了某些数据团队从本地到现代数据栈后的运作方式,但它很接近。

团队很快理解到超弹性计算和存储服务如何使他们能够处理以前从未见过的多样数据类型和速度,但他们并不总是理解云对其工作流程的影响。

所以也许对这些最近迁移的数据团队来说,一个更好的类比是,如果我给你 1,000 把钉子枪,然后看着你将它们全部横着放,以同时钉 1,000 个钉子。

无论如何,重要的是要理解现代数据栈不仅仅允许你更大更快地存储和处理数据,它允许你从根本上以不同的方式处理数据,以实现新的目标并提取不同类型的价值

这部分是由于规模和速度的增加,但也是由于更丰富的元数据和生态系统中的更无缝集成。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由 Shane Murray 和作者提供。

在这篇文章中,我突出了我看到数据团队在云端行为变化的三种更常见方式,以及五种它们没有做到(但应该做到)的方式。让我们深入探讨一下。

数据团队在云端变化的三种方式

数据团队转向现代数据栈的原因有很多(不仅仅是因为首席财务官最终解放了预算)。这些用例通常是数据团队进入云端后的第一个也是最容易的行为转变。它们包括:

从 ETL 迁移到 ELT 以加快洞察时间

你不能随便将任何东西加载到本地数据库中——特别是如果你希望查询在周末之前返回的话。因此,这些数据团队需要仔细考虑要提取哪些数据以及如何通过在 Python 中硬编码的管道将其转换成最终状态。

这就像是为每个数据消费者定制特定的餐点,而不是提供自助餐,就像任何去过邮轮的人都知道的那样,当你需要满足整个组织对数据的无限需求时,自助餐才是最好的选择。

这正是 AutoTrader UK 技术负责人 Edward Kent 与我的团队去年讨论数据信任和自助分析需求的情况。

“我们希望赋能 AutoTrader 及其客户,使其能够做出数据驱动的决策,并通过自助平台实现数据的民主化……在我们将受信任的本地系统迁移到云端时,那些旧系统的用户需要相信新的云技术与他们过去使用的旧系统一样可靠,”他说。

当数据团队迁移到现代数据堆栈时,他们欣然采用如 Fivetran 这样的自动化数据摄取工具或如 dbt 和 Spark 这样的转化工具,以及更复杂的数据策展策略。分析自助服务开启了一个全新的领域,谁来负责数据建模并不总是很明确,但总体而言,这是一种更高效的方式来解决分析(和其他!)使用案例。

实时数据用于操作决策

在现代数据堆栈中,数据可以快速移动,不再仅仅用于每日指标的脉搏检查。数据团队可以利用Delta 实时表SnowparkKafkaKinesis、微批处理等更多工具。

并非每个团队都有实时数据的使用案例,但那些有的团队通常都非常清楚。这些通常是需要操作支持的物流密集型公司,或者是将强大报告集成到产品中的科技公司(尽管后一类公司中有相当一部分是在云中诞生的)。

挑战仍然存在,这些挑战有时涉及并行架构(分析批处理和实时流)并试图达到大多数人希望的质量控制水平。但大多数数据领导者很快理解了能够更直接支持实时操作决策的价值。

生成式 AI 和机器学习

数据团队对生成式 AI 浪潮有深刻的认识,许多行业观察者怀疑这项新兴技术正在推动基础设施现代化和利用的巨大浪潮。

但在 ChatGPT 生成其第一篇文章之前,机器学习应用已经从前沿技术逐渐成为数据密集型行业的标准最佳实践,包括媒体、电商和广告。

目前,许多数据团队一有可扩展的存储和计算资源就会立即开始检查这些使用案例(尽管有些团队可能会从构建更好的基础中受益)。

如果你最近迁移到云端,但还没有询问业务如何更好地支持这些使用案例,请将其安排到日程中。本周,或者今天。你会感谢我的。

数据团队仍像在本地部署一样工作的 5 种方式

现在,让我们看看一些以前在本地的数据团队可能较慢利用的未实现的机会。

附注:我想澄清的是,虽然我之前的类比有些幽默,但我并不是在嘲笑那些仍在本地操作或在云中使用以下流程的团队。变革是困难的,尤其是在面对持续的积压和不断增加的需求时,变革更为艰难。

数据测试

本地的数据团队没有规模或来自中央查询日志或现代表格格式的丰富元数据,因此无法轻松运行机器学习驱动的异常检测(换句话说,数据可观察性)。

相反,他们与领域团队合作,以理解数据质量要求,并将这些要求转化为 SQL 规则或数据测试。例如,customer_id 应该永远不能为 NULL,或者 currency_conversion 应该永远不能有负值。还有一些本地工具旨在帮助加速和管理这一过程。

当这些数据团队迁移到云端时,他们首先想到的不是以不同的方式处理数据质量,而是以云规模执行数据测试。这是他们所熟悉的做法。

我看到过一些案例研究,读起来像恐怖故事(不,我不会提名字),数据工程团队在数千个 DAG 上运行数百万个任务,以监控数百个管道中的数据质量。哎呀!

当你运行五十万条数据测试时会发生什么?我告诉你。即使绝大多数测试通过,仍然会有数万条测试失败。而且这些测试明天还会失败,因为没有上下文来加快根本原因分析,甚至不知道从哪里开始分类。

你不知何故使你的团队产生了警报疲劳,却仍未达到所需的覆盖范围。更不用说大规模的数据测试既耗时又昂贵。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由作者提供。来源

相反,数据团队应该利用能够检测、分类和帮助根本原因分析的技术,同时将数据测试(或自定义监控)保留在最重要的表格中的最清晰的阈值上。

数据血缘的数据建模

支持中央数据模型有很多合理的理由,你可能在一篇精彩的 Chad Sanderson 博客中读到过这些理由。

但每隔一段时间,我会遇到在云端投入大量时间和资源来维护数据模型的团队,唯一的原因是维护和理解 数据血统。在本地部署时,这基本上是你最好的选择,除非你想阅读长篇的 SQL 代码,并创建一个满是记忆卡片和纱线的公告板,让你的另一半开始问你是否没事。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

照片由 Jason Goodman 提供,来自 Unsplash

(“不,Lior!我没事,我只是试图理解这个 WHERE 子句是如何改变这个 JOIN 中的列的!”)

现代数据堆栈中的多个工具——包括数据目录、数据可观测性平台和数据仓库——可以利用元数据来创建自动化的数据血统。这只是一个 选择风味 的问题。

客户细分

在旧的世界里,对客户的视角是平面的,而我们知道它实际上应该是一个 360 度的全球视图。

这种有限的客户视角是由于预建数据(ETL)、实验约束和本地数据库计算更复杂查询(独特计数、不同值)所需的时间较长。

不幸的是,数据团队在云端移除了这些约束之后,并不总是会从他们的客户视角中移除盲点。这通常有几个原因,但最大的罪魁祸首仍然是老式的 数据孤岛

市场营销团队运营的客户数据平台仍然充满活力。该团队可以通过丰富他们对潜在客户和客户的视角来从数据仓库/湖泊中的其他领域数据中受益,但多年来建立的习惯和责任感很难打破。

因此,与其根据最高的估计终身价值来定位潜在客户,不如将目标放在每个线索的成本或每次点击的成本上。这是数据团队以直接且高度可见的方式为组织贡献价值的一个错失机会。

导出外部数据共享

复制和导出数据是最糟糕的。这不仅耗时、增加成本,还会导致版本控制问题,使得访问控制几乎不可能。

与其利用你的现代数据堆栈创建一个将数据以极快速度导出到典型合作伙伴的管道,更多的云端数据团队应当利用零拷贝数据共享。就像管理云文件的权限在很大程度上取代了电子邮件附件,零拷贝数据共享允许在不将数据从宿主环境中移走的情况下访问数据。

SnowflakeDatabricks 在过去两年的年度峰会上宣布并大力推广了他们的数据共享技术,更多的数据团队需要开始加以利用。

成本和性能优化

在许多本地系统中,数据库管理员负责监督所有可能影响整体性能的变量,并根据需要进行调节。

然而,在现代数据堆栈中,你通常会看到两种极端情况。

在一些情况下,DBA 的角色仍然存在,或者外包给一个中央数据平台团队,如果管理不善,可能会造成瓶颈。然而,更常见的情况是,成本或性能优化变成了“无人区”,直到一笔特别高额的账单送到 CFO 的桌上。

这通常发生在数据团队没有正确的成本监控工具时,且出现了特别激进的异常事件(可能是错误的代码或爆炸性 JOIN)。

此外,一些数据团队未能充分利用“按使用付费”的模式,而是选择承诺预定数量的积分(通常有折扣)……然后超出这个数额。虽然积分承诺合同本身没有问题,但如果不加以注意,这种预留时间可能会形成一些坏习惯,随着时间的推移逐渐累积。

云计算使得 DevOps/DataOps 可以采用更加连续、协作和集成的方法,FinOps 也是如此。我看到的最成功的团队是那些将成本优化融入日常工作流,并激励与成本最相关的人员的团队。

“消费型定价的兴起使得这一点更加关键,因为新功能的发布可能导致成本指数级上升,”Tenable 的 Tom Milner 说。“作为我的团队的负责人,我每天检查我们的 Snowflake 成本,并将任何费用激增作为我们待办事项的优先项。”

这会创建反馈循环、共享学习以及成千上万的小型快速修复,从而带来重大成果。

Stijn Zanders 在 Aiven 说:“我们设立了警报,当有人查询任何可能花费我们超过 $1 的内容时。这是一个相当低的阈值,但我们发现费用不需要超过这个数额。我们发现这是一个良好的反馈循环。[当这个警报出现时],通常是有人忘记了在分区或聚集列上设置过滤器,他们可以迅速学习。”

最后,在团队之间部署收费回收模型,在云计算之前的时代几乎是不可想象的,这是一项复杂但最终值得的工作,我希望看到更多的数据团队对此进行评估。

成为一个“学习型”人

微软首席执行官 Satya Nadella 曾谈到 他如何有意将公司的组织文化从“全知型”转变为“学习型”。这将是我对数据领导者的最佳建议,无论你是刚刚迁移还是多年来一直处于数据现代化的前沿。

我理解这可能是多么令人不知所措。新技术的出现迅猛而猛烈,供应商的推销也不容忽视。最终,这不仅仅是关于拥有行业内“最现代化”的数据堆栈,而是关于在现代工具、顶尖人才和最佳实践之间创造对齐。

为了做到这一点,始终准备学习你的同行如何应对你面临的许多挑战。参与社交媒体,阅读 Medium,关注分析师,并参加会议。我会在那里见到你!

在云环境中,其他哪些本地数据工程活动不再有意义?如有任何意见或问题,请联系 Barr ,通过 LinkedIn。

你的数据集有缺失值?什么都不做!

原文:towardsdatascience.com/your-dataset-has-missing-values-do-nothing-10d1633b3727

模型能够比填补方法更有效地处理缺失值。这是一个实证证明。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Samuele Mazzanti

·发表于 Towards Data Science ·10 分钟阅读·2023 年 10 月 9 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

[作者提供的图片]

缺失值在真实数据集中非常常见。随着时间的推移,许多方法被提出以解决这个问题。通常,这些方法要么是删除包含缺失值的数据,要么是使用一些技术进行填补。

在这篇文章中,我将测试第三种替代方案:

什么都不做。

实际上,最适合表格数据集的模型(即 XGBoost、LightGBM 和 CatBoost)可以原生处理缺失值。因此,我将尝试回答的问题是:

这些模型是否能有效处理缺失值,还是通过预处理填补能获得更好的结果?

谁说我们应该关心缺失值?

似乎存在一种普遍的信念,认为我们必须对缺失值做一些事情。例如,我问了 ChatGPT,如果我的数据集包含缺失值应该怎么办,它建议了 10 种不同的解决方法(你可以在这里阅读完整回答)。

那么,这种信念来自哪里呢?

通常,这些观点源于历史模型,特别是线性回归。这也是如此。让我们看看原因。

假设我们有这样一个数据集:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

一个缺失值的数据集。[作者提供的图片]

如果我们尝试在这些特征上训练线性回归,会出现错误。事实上,为了能够进行预测,线性回归需要将每个特征乘以一个数值系数。如果一个或多个特征缺失,就无法对该行进行预测。

这就是为什么提出了许多填补方法。例如,最简单的方法之一是用特征的均值替换缺失值。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

使用特征均值填补。[图片由作者提供]

另一种更复杂的方法是利用变量之间的关系来预测填补特定条目的最可能值。这意味着为每个特征训练一个预测模型(使用其他特征作为预测变量)。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

迭代填补:每个特征通过其他所有特征来估计。[图片由作者提供]

然而,并非所有模型都像线性回归一样。

实际上,碰巧的是,表格任务中表现最好的模型(即树模型,如 XGBoost、LightGBM 和 CatBoost)可以原生处理缺失值。

这怎么可能?

因为在树状结构中,缺失值可以像其他值一样处理,即模型可以将它们分配到树的某个分支上。例如,这张图片来自 XGBoost 文档:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

树模型如何处理缺失值。[图片来自XGBoost 文档]

正如你所见,对于每个分裂,XGBoost 会选择一个默认分支,将缺失值(如果有的话)路由到该分支上。

所以,我们是否说因为这些模型可以处理缺失值,我们就应该避免填补?我从未这么说过。

作为数据科学家,我们通常希望在实践中有深刻的了解。因此,接下来的段落将重点比较有无填补的模型性能,并观察哪种方法表现更好。

使用真实数据集进行实验

我们的目标是比较两种方法:

  1. 具有缺失值的数据集上训练和测试模型。

  2. 在没有缺失值的数据集上训练和测试模型(在这些缺失值通过某种填补方法填补后)。

我将使用 14 个在Pycaret中提供的真实数据集(这是一个受MIT 许可证保护的 Python 库)。

这些数据集不包含缺失值,所以我们需要人为制造这些缺失值。我将这个过程称为空值播撒,因为它涉及在原始数据集中散布空值(即“取消”一些原始值)。

我们要取消多少值?为了确保实验具有代表性,我们将尝试不同百分比的空值:5%、10%、20% 和 50%。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

不同百分比的空值播撒。[图片由作者提供]

我使用两种不同的策略来创建空值:

  • 随机:为数据集的每个条目分配一个介于 0 和 1 之间的随机值,如果小于阈值,则该条目被取消。

  • 非随机:对于每个特征,我将值按升序或降序排序,并根据值在排序序列中的位置分配一个相应的取消概率。

此外,对于每种组合,我将尝试 25 种不同的随机训练/测试划分。这样做的目的是确保我们观察到的结果在多次重复中是一致的,而不仅仅是偶然的。

总结一下,这些都是我将尝试的所有组合。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

形成我们实验的组合。[图像来源:作者]

对于这 2,800 种组合中的每一种,我将计算 LightGBM 在三种不同方法下的平均精度(在测试集上):

  1. 使用原始数据集(没有缺失值)。我将称这个平均精度为ap_original

  2. 使用包含我已经撒入缺失值的数据集。我将称这个平均精度为ap_noimpute

  3. 使用已用 Scikit-Learn 的 IterativeImputer 填充缺失值的数据集。我将称这个平均精度为ap_impute

让我们借助图示来阐明这三种指标之间的差异:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

3 个版本的数据集。每个版本上训练了不同的 LightGBM。[图像来源:作者]

结果

我尝试了上述所有 2,800 种组合,并跟踪了每一种的ap_originalap_noimputeap_impute。我将结果存储在类似这样的表格中:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

实验结果。[图像来源:作者]

这个表格有 112 行(即 14 个数据集 x 2 种缺失值撒入策略 x 4 种缺失频率)。前三列指示哪个数据集、缺失值撒入策略和缺失频率标识那一行。然后,有三列存储该方法在 25 次 bootstrap 迭代中报告的平均精度。

为了更加清楚,让我们看看第一行的第四列。这是一个包含 25 个元素的数组:每个元素都是 LightGBM 在原始数据集上(因此称之为ap_original)在特定训练/测试划分(即 bootstrap 迭代)中实现的平均精度,数据集“bank”包含 5%的随机撒入缺失值。

当然,我们对每一个单独的 bootstrap 迭代并不感兴趣,所以我们需要以某种方式聚合这些数组。由于我们希望比较三种方法,最简单的聚合方式是统计一种方法优于另一种方法的次数(即它具有更高的平均精度)。

因此,对于表格中的每一行,我计算了:

  • ap_original > ap_noimpute,这意味着在原始数据集上的模型优于在包含缺失值的数据集上的模型。

  • ap_noimpute > ap_impute,这意味着在包含缺失值的数据集上的模型优于在使用 IterativeImputer 填充的数据集上的模型。

这是结果:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

实验结果。[图像来源:作者]

例如,如果我们取第一行的最后一列,这意味着在 25 次迭代中,ap_noimpute在 60%的情况下大于ap_impute

让我们看看名为original_>_noimpute的列。显然,大多数情况下是 100%。这也是合理的:模型在原始数据集上训练的效果优于在取消了一些条目的数据集上训练的模型。

但对我们来说最重要的信息在最后一列。实际上:

  • noimpute_>_impute大于 50%时,这意味着在大多数情况下,基于空数据集训练的模型优于基于填补数据集训练的模型。

  • noimpute_>_impute小于 50%时,情况正好相反。

所以我们可能会倾向于简单地取这一列的均值,并根据全局均值是高于还是低于 50%来决定哪种方法效果更好。

但是如果我们这样做,我们可能会被偶然性所误导。实际上,如果一个值接近 50%,比如 48%或 52%,这可能很容易是由于随机性造成的。为了考虑这一点,我们需要将其框定为一个统计测试。

设定统计测试

为了避免被偶然性欺骗,我将采取几个预防措施。

首先,我只保留那些原始数据集上的平均精度大于缺失数据集上的平均精度超过 90%的情况(换句话说,我只保留original_>_noimpute > .90 的行)。

我这样做是因为我想保留只有那些缺失值对模型性能有明显负面影响的情况。在这样做之后,最初 112 行的表格中,只剩下 48 行。

其次,我需要计算一个“显著性阈值”,帮助我们理解一个值是否显著。

我们已经说过,当一个值非常接近 50%时,比如 48%或 52%,那么这很可能只是由于偶然性造成的。但是“非常接近”到底有多接近呢?要回答这个问题,我们需要一些统计数据。

我们想要测试的假设(即零假设)是填补或不填补没有区别。这就像说ap_impute大于ap_noimpute(或反之)的概率是 50%。由于我们有 25 次独立的迭代,我们可以通过二项分布计算获得特定结果的概率。

例如,假设在 25 次迭代中,ap_impute在 10 次迭代中优于ap_noimpute。这个结果与我们的假设有多兼容?

from scipy.stats import binom

binom.cdf(k=10, n=25, p=.50) * 2

# result: 0.42435622

因此,获得一个与 10 一样极端的结果的概率,在我们的假设下,是 42%。

请注意,我将二项分布的累积分布函数乘以 2,因为我们对双尾 p 值感兴趣。我们可以通过查看与任何可能结果相关的 p 值来再次确认这一点:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

与任何可能结果相关的 p 值(来自 25 次迭代)。 [图片来源:作者]

与 12 和 13 相关的 p 值恰好是 100%。这也是合理的:获得至少与 12 或 13 一样极端的结果的概率必然是 100%,因为它们是最不极端的结果,鉴于我们的假设。

那么,我们的显著性水平是多少?按照惯例,我将取 1%的显著性水平。然而,我们必须考虑到我们有许多次运行,而不仅仅是一次,因此我们必须相应地调整显著性水平。由于我们有 48 次运行,我将使用Bonferroni 校正,简单地将 1%除以 48,得到最终显著性水平 0.0002。

将这个数字与上述 p 值进行比较,这意味着我们将只在运行中成功次数少于 3 次或多于 22 次(包括)时考虑显著。

现在我们已经采取措施以避免被偶然因素欺骗,我们准备查看结果。

为了简化起见,我们将度量标准设为ap_noimpute大于ap_impute的次数百分比。例如,如果基于包含缺失值的数据集的模型在 25 次迭代中的 10 次中表现出比插补数据集的模型更好的平均精度,则该度量为 40%。

由于我们有 48 次运行,我们将得到 48 个值。因此,让我们创建一个柱状图来查看所有值。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

实验结果。[作者提供的图片]

红色虚线标识了显著性阈值:任何小于 12%(包括)或大于 88%(包括)的值都是显著的(12%和 88%分别对应 3/25 和 22/25)。

从柱状图中,我们可以看到当我们不插补缺失值时,在 48 次运行中的 30 次中获得了更好的结果(63%)。在这 30 个案例中的 7 个,结果极端到根据 1% p 值及 Bonferroni 调整也是统计显著的。相反,在 18 个插补胜出的案例中,结果从未显著地不同于纯粹的偶然。

基于这些结果,我们可以说插补与不插补之间的差异要么不显著,要么显著地偏向于不插补

简而言之,没有理由插补缺失值。

结论

在这篇文章中,我们通过实证证明了插补与不插补之间的差异要么不显著,要么显著地偏向于不插补。如果你还考虑到不插补缺失值会使你的管道更干净、更快速,那么当可以做到这一点时,留空数据集中的空值应该是标准。

当然,这并非总是可能的。一些不能处理缺失值的模型,比如线性回归或 K 均值。

不过,好消息是,当你使用最常见的表格任务模型(即基于树的模型)时,不插补缺失值而让模型处理它们是最有效和高效的方法。

你可以通过 这个笔记本复现本文中使用的所有代码。

感谢阅读!

如果你觉得我的工作有用,可以订阅 每次我发布新文章时收到邮件 (通常每月一次)

想要对我的工作表示支持吗?你可以 请我喝杯卡布奇诺

如果你愿意, 在 Linkedin 上添加我

你的特征重要吗?这并不意味着它们是好的

原文:towardsdatascience.com/your-features-are-important-it-doesnt-mean-they-are-good-ff468ae2e3d4

“特征重要性”是不够的。如果你想知道哪些特征对模型有益,你还需要关注“错误贡献”。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Samuele Mazzanti

·发表于 Towards Data Science ·10 分钟阅读·2023 年 8 月 21 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

[作者提供的图片]

“重要”和“好”并不是同义词

“特征重要性”这一概念在机器学习中被广泛使用,是最基本的模型可解释性类型。例如,它在递归特征消除(RFE)中被使用,用于迭代地删除模型中最不重要的特征。

然而,对于这一点存在一种误解。

一个特征重要并不意味着它对模型有益!

实际上,当我们说一个特征重要时,这仅仅意味着该特征对模型的预测贡献很高。但我们应考虑到这种贡献可能是错误的

举个简单的例子:一个数据科学家不小心忘记了模型特征中的客户 ID。模型将客户 ID 作为一个高度预测性特征。因此,即使这个特征实际上在降低模型性能,因为它无法在未见数据上良好运行,这个特征的重要性也会很高。

为了让事情更清楚,我们需要区分两个概念:

  • 预测贡献:预测中有多少部分是由于特征;这等同于特征重要性。

  • 错误贡献:预测错误中有多少部分是由于模型中存在该特征。

在本文中,我们将探讨如何计算这些量,并如何利用它们获得有关预测模型的有价值的见解(并加以改进)。

注意:本文专注于回归案例。如果你对分类案例更感兴趣,可以阅读“哪些特征对你的分类模型有害?”

从一个玩具示例开始

假设我们建立了一个模型,以根据人们的工作、年龄和国籍来预测收入。现在我们使用该模型对三个人进行预测。

因此,我们有实际值、模型预测和结果误差:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

实际值、模型预测和绝对误差(以千美元计)。[作者提供的图片]

计算“预测贡献”

当我们有一个预测模型时,我们可以始终将模型预测分解为各个特征带来的贡献。这可以通过 SHAP 值来完成(如果你不知道 SHAP 值的工作原理,可以阅读我的文章:SHAP 值解释:正如你希望有人向你解释的那样)。

所以,假设这些是我们模型对三个人的 SHAP 值。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

我们模型预测的 SHAP 值(以千美元计)。[作者提供的图片]

SHAP 值的主要特性是它们是可加的。这意味着——通过对每一行求和——我们将得到该个体的模型预测。例如,如果我们取第二行:72k $ +3k $ -22k $ = 53k $,这正是模型对第二个体的预测。

现在,SHAP 值是特征对我们预测重要性的良好指标。确实,SHAP 值的绝对值越高,特征对该特定个体预测的影响越大。注意,我在这里讨论的是绝对 SHAP 值,因为符号并不重要:一个特征无论是使预测值上升还是下降,其重要性是相同的。

因此,特征的预测贡献等于该特征绝对 SHAP 值的均值。如果你在 Pandas 数据框中存储了 SHAP 值,这非常简单:

prediction_contribution = shap_values.abs().mean()

在我们的示例中,结果如下:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

预测贡献。[作者提供的图片]

如你所见,工作(job)显然是最重要的特征,因为它在平均情况下占据了最终预测的 71.67k $。国籍(nationality)和年龄(age)分别是第二和第三重要的特征。

然而,某个特征对最终预测的重要性并不能说明该特征的表现。为了考虑这一方面,我们需要计算“误差贡献”。

计算“误差贡献”

假设我们想回答以下问题:“如果模型没有特征工作,会做出什么预测?”SHAP 值允许我们回答这个问题。事实上,由于它们是可加的,我们只需从模型做出的预测中减去与特征工作相关的 SHAP 值即可。

当然,我们可以对每个特征重复这一过程。在 Pandas 中:

y_pred_wo_feature = shap_values.apply(lambda feature: y_pred - feature)

结果如下:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

如果去掉相应特征后得到的预测值。[图片由作者提供]

这意味着,如果我们没有工作这个特征,那么模型会预测第一位个体 20k 美元,第二位个体-19k 美元,第三位个体-8k 美元。相反,如果我们没有年龄这个特征,模型会预测第一位个体 73k 美元,第二位个体 50k 美元,以此类推。

如你所见,如果我们去掉不同的特征,每个个体的预测变化很大。因此,预测误差也会非常不同。我们可以轻松地计算它们:

abs_error_wo_feature = y_pred_wo_feature.apply(lambda feature: (y_true - feature).abs())

结果如下:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

如果去掉相应特征后得到的绝对误差。[图片由作者提供]

这些是我们如果去掉相应特征后会得到的错误。直观地说,如果误差很小,那么去掉该特征对模型没有问题——甚至是有益的。如果误差很高,那么去掉特征则不是一个好主意。

但我们还可以做更多。事实上,我们可以计算完整模型的错误与去掉特征后得到的错误之间的差异:

error_diff = abs_error_wo_feature.apply(lambda feature: abs_error - feature)

这就是:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

模型的错误与去掉特征后我们会得到的错误之间的差异。[图片由作者提供]

如果这个数字是:

  • 如果是负的,那么特征的存在会减少预测误差,因此该特征对该观察结果很有效!

  • 如果是正的,那么特征的存在会导致预测误差增加,因此该特征对该观察结果是不利的。

我们可以计算“误差贡献”,作为每个特征这些值的均值。在 Pandas 中:

error_contribution = error_diff.mean()

结果如下:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

误差贡献。[图片由作者提供]

如果这个值是正的,那么这意味着,平均而言,特征的存在导致模型的错误增加。因此,没有这个特征,预测会更好。换句话说,这个特征的负面影响大于正面影响!

相反,这个值越负,特征对预测的益处越大,因为其存在导致更小的误差。

让我们尝试在实际数据集上使用这些概念。

预测黄金回报

从现在起,我将使用来自Pycaret(一个MIT 许可的 Python 库)的数据集。该数据集名为“Gold”,包含财务数据的时间序列。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

数据集样本。特征都以百分比表示,因此 -4.07 意味着 -4.07% 的回报。[图片由作者提供]

特征由观察时刻前 22、14、7 和 1 天的金融资产回报组成(“T-22”、“T-14”、“T-7”、“T-1”)。这是所有用作预测特征的金融资产的详尽列表:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

可用资产列表。每个资产在时间 -22、-14、-7 和 -1 被观察到。[作者提供的图片]

总共有 120 个特征。

目标是预测 22 天后的黄金价格(回报) (“Gold_T+22”)。让我们来看看目标变量。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

变量的直方图。[作者提供的图片]

一旦加载了数据集,这些是我进行的步骤:

  1. 随机拆分完整数据集:33%分配给训练数据集,另 33%分配给验证数据集,其余 33%分配给测试数据集。

  2. 在训练数据集上训练一个 LightGBM 回归模型。

  3. 使用前一步训练的模型对训练、验证和测试数据集进行预测。

  4. 计算训练、验证和测试数据集的 SHAP 值,使用 Python 库“shap”。

  5. 计算每个特征在每个数据集上的预测贡献和错误贡献(训练集、验证集和测试集),使用我们在前一段看到的代码。

比较预测贡献和错误贡献

比较训练数据集中错误贡献和预测贡献。我们将使用散点图,其中点表示模型的 120 个特征。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

预测贡献与错误贡献(在训练数据集上)。[作者提供的图片]

在训练集中,预测贡献与错误贡献之间存在高度负相关。

这很有意义:因为模型在训练数据集上学习,它倾向于将高重要性(即高预测贡献)分配给那些导致预测错误大幅减少的特征(即高度负的错误贡献)

但这并没有增加我们对知识的了解,对吧?

确实,我们真正关心的是验证数据集。验证数据集实际上是我们可以用来了解特征在新数据上表现的最佳代理。因此,让我们在验证集上做相同的比较。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

预测贡献与错误贡献(在验证数据集上)。[作者提供的图片]

从这个图中,我们可以提取出一些更有趣的信息。

图表右下角的特征是模型正确地赋予高重要性的特征,因为这些特征实际上降低了预测错误。

同时,请注意“Gold_T-22”(观察期前 22 天的黄金回报)与模型赋予的权重相比表现非常好。这意味着这个特征可能存在欠拟合。这一点特别有趣,因为黄金是我们试图预测的资产(“Gold_T+22”)。

另一方面,误差贡献高于 0 的特征使我们的预测变得更差。例如,“US Bond ETF_T-1”平均改变了模型预测 0.092%(预测贡献),但使模型预测比没有该特征时平均差 0.013%(误差贡献)。

我们可以假设所有具有高误差贡献(相对于其预测贡献)的特征可能存在过拟合,或者通常它们在训练集和验证集中的表现不同。

让我们看看哪些特征的误差贡献最大。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

按误差贡献递减排序的特征。[图片来源]

现在是误差贡献最低的特征:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

按误差贡献递增排序的特征。[图片来源]

有趣的是,我们可以观察到所有高误差贡献的特征与 T-1(观察时刻前 1 天)相关,而几乎所有低误差贡献的特征与 T-22(观察时刻前 22 天)相关。

这似乎表明最新的特征容易过拟合,而时间更久远的特征往往更容易泛化

请注意,如果没有误差贡献,我们永远不会知道这个洞察。

使用误差贡献的 RFE

传统递归特征消除(RFE)方法基于移除不重要的特征。这相当于首先移除预测贡献小的特征。

然而,根据我们在上一段所说的,首先移除误差贡献最大的特征会更有意义。

为了验证我们的直觉,让我们比较这两种方法:

  • 传统 RFE:首先移除无用特征(预测贡献最低)。

  • 我们的 RFE:首先移除有害特征(误差贡献最高)。

让我们看看验证集上的结果:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

两种策略在验证集上的平均绝对误差。[图片来源]

每种方法的最佳迭代已被圈出:传统 RFE(蓝线)的模型有 19 个特征,而我们的 RFE(橙线)的模型有 17 个特征。

一般而言,我们的方法似乎效果良好:移除误差贡献最大的特征比移除预测贡献最大的特征会导致一致较小的 MAE。

然而,你可能会认为这只因为我们对验证集进行了过拟合。毕竟,我们关心的是在测试集上获得的结果。

所以让我们看看在测试集上的相同比较。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

测试集上两种策略的均绝对误差。[作者提供的图片]

结果与之前的相似。即使两条线之间的距离较小,但通过移除最高误差贡献者获得的 MAE 显然优于通过移除最低预测贡献者获得的 MAE。

由于我们选择了在验证集上导致最小 MAE 的模型,让我们看看它们在测试集上的结果:

  • RFE-预测贡献(19 个特征)。测试集上的 MAE:2.04。

  • RFE-误差贡献(17 个特征)。测试集上的 MAE:1.94。

因此,使用我们的方法的最佳 MAE 比传统 RFE 提高了 5%!

结论

特征重要性概念在机器学习中起着基础性的作用。然而,“重要性”这一概念常常被误解为“优越性”。

为了区分这两个方面,我们引入了两个概念:预测贡献和误差贡献。这两个概念都基于验证数据集的 SHAP 值,文章中我们展示了计算这些值的 Python 代码。

我们还在一个真实的金融数据集上进行了尝试(该任务是预测黄金价格),并证明基于误差贡献的递归特征消除方法相比于基于预测贡献的传统 RFE,能使均绝对误差提高 5%。

所有用于本文的代码可以在 这个笔记本找到。

感谢阅读!

如果你觉得我的工作有用,你可以 每次我发布新文章时收到邮件 (通常是每月一次)。

如果你想支持我的工作,你可以 请我喝咖啡

如果你愿意, 可以在 Linkedin 加我

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值