解释性神经基础模型
原文:
towardsdatascience.com/neural-basis-models-for-interpretability-fd04ac958ff2
解读 Meta AI 提出的新解释性模型
·发表于 Towards Data Science ·6 分钟阅读·2023 年 10 月 11 日
–
机器学习和人工智能在各个领域的广泛应用带来了更高的风险和伦理评估挑战。如在 ProPublica 报道的刑事再犯模型 中所见,机器学习算法可能存在严重的偏见,因此需要强有力的解释机制,以确保这些模型在高风险领域的信任和安全。
那么,我们如何在解释性、准确性和模型表现力之间取得平衡呢?Meta AI 的研究人员提出了一种新的方法,称为神经基础模型(NBMs),这是一种广义加性模型的子家族,在基准数据集上实现了最先进的性能,同时保持了透明的解释性。
在这篇文章中,我旨在解释 NBM 及其为何是一个有益的模型。像往常一样,我鼓励大家阅读原始论文。
如果你对解释性机器学习和其他伦理 AI 方面感兴趣,考虑查看我的其他文章并关注我!
解释性与伦理 AI
查看列表5 篇故事
背景:GAMs
NBM 被认为是广义加性模型(GAM)。GAM 本质上是可解释的模型,为每个特征学习一个形状函数,预测是通过“查询”形状函数来完成的。由于这些形状函数是独立的,通过可视化这些形状函数可以理解特征对预测的影响,使得模型高度可解释。变量之间的交互通过将多个变量传递到同一函数中并基于此构建形状函数来建模(通常将变量数量限制为 2 以提高互操作性),这种配置称为 GA2M。
GAMs 和 GA2Ms 的方程(图源自 Radenovic 等人 [1])
各种 GAM 和 GA2M 模型使用不同的机制来开发这些形状函数。可解释增强机(EBM)[2] 使用一组对每个特征进行训练的提升树,神经加性模型(NAMs)[3] 对每个特征使用深度神经网络,而 NODE-GAM [4] 使用 无意识神经树[6] 的集合。我推荐阅读以下文章以获得对这些模型的更详细解释:EBM 和 NODE-GAM/NAM。
NBM 方法
神经基础模型(NBM)是一类新的广义加性模型(GAMs)子家族,利用形状函数的基础分解。
NBM 架构(图源自 Radenovic 等人 [1])
与其他 GAM 模型(如 NAM[3])不同,后者有效地为每个特征训练独立模型以构建形状函数,而 NBM 架构则依赖于少量的基础函数,这些函数在所有特征中共享,并为特定任务共同学习。这些函数是什么?它们是函数逼近的瑞士军刀:深度神经网络。
实际上,一个通用的 MLP 主干网络接受 1 个输入并输出 B 个值,这些值被训练并应用于每个输入特征。这些输出然后被线性组合以形成给定特征的最终预测,而线性组合的权重对每个特征不同。另一种思考这种架构的方法是通过编码器-解码器网络的视角。所有特征共享相同的 编码器(通用 MLP 主干),但每个特征都有其自己的 解码器(编码的线性变换)。每个特征的解码值然后被加总以生成最终预测。
这可以很容易地扩展到包括特征交互。如果我们想要建模配对交互,我们可以包括一个接受两个输入的 MLP,而不是一个。
NBM 和 NB2M 方程(图源自 Radenovic 等人 [1])
使用共享 MLP 主干而不是为每个特征使用不同的 MLP 的一个好处是模型的显著较小的尺寸。这使得 NBM 非常适合处理极高维度数据的任务。
性能与优势
为了测试他们的架构,Radenovic 等人(2022 年)将 NBM 与各种其他模型进行了比较,如线性回归、EBM [2]、NAM [3]、XGBoost[5] 和 MLP。他们的首次评估是在混合的表格和图像数据集上进行的。
基准性能比较(图来自 Radenovic 等人 [1])
总的来说,NBM 牢牢站稳脚跟,超越了其他可解释模型,甚至在某些数据集上超过了 MLP。
Radenovic 等人(2022 年)还在纯表格数据集上进行了另一项评估,重点是对比 SOTA GAM 模型。
与其他 GAMs 的性能比较:(图来自 Radenovic 等人 [1])
这个比较清楚地展示了 NBM 的强大,几乎在每个数据集上都击败了竞争对手。如前所述,NBM 的可扩展性也非常出色。如下所示,在高维数据任务中,NBM 的参数数量几乎是 NAM 的 70 分之一。
NAM/NBM 参数比较。X 轴是数据的维度(图来自 Radenovic 等人 [1])
结论
总体而言,NBM 是极其强大且轻量级的模型,由于其为 GAM,本质上是可解释的。然而,这并不意味着它是解决高风险机器学习问题的灵丹妙药。在使用这些模型时,仍然需要考虑许多因素。例如,一个本质上可解释的模型几乎没有意义,如果输入到模型中的特征不可解释。
此外,虽然 NBM 的规模相较于 NAM 扩展得很好,但可解释性却没有。没有人能查看数千个特征归因图,特别是当这些归因图还包含成对交互时。这意味着在大参数空间下,仍然需要预处理方法,如特征选择,甚至作者也承认了这一点。然而,这并不贬低作者,因为这是一个仍然非常有用且相对容易实现和调整的模型。
该模型是 GAM 的事实对于机器学习应用(如移动设备和其他性能较低的设备)也非常有利,因为用户可以训练模型并部署生成的特征归因函数,而不是完整模型,从而实现极其快速且内存轻量的推理,而不会损失准确性。
资源与参考文献
-
NBM 开放评论:
openreview.net/forum?id=fpfDusqKZF
-
如果你对可解释的机器学习或时间序列预测感兴趣,可以考虑关注我:
medium.com/@upadhyan
。 -
查看我关于可解释机器学习的其他文章:
medium.com/@upadhyan/list/interpretable-and-ethical-ai-f6ee1f0b476d
参考文献
[1] Radenovic, F.、Dubey, A. 和 Mahajan, D.(2022)。用于可解释性的神经基础模型。神经信息处理系统进展,35,8414–8426。
[2] Yin L.、Rich C.、Johannes G. 和 Giles H.(2013)准确可理解的模型与配对交互。在第 19 届 ACM SIGKDD 国际会议上的知识发现与数据挖掘论文集,623–631. 2013。
[3] Agarwal, R.、Melnick, L.、Frosst, N.、Zhang, X.、Lengerich, B.、Caruana, R. 和 Hinton, G. E.(2021)。神经加性模型:使用神经网络的可解释机器学习。神经信息处理系统进展,34,4699–4711。
[4] Chang, C.H.、Caruana, R. 和 Goldenberg, A.(2022)。NODE-GAM:用于可解释深度学习的神经广义加性模型。在国际学习表征会议。
[5] Chen, T. 和 Guestrin, C.(2016 年 8 月)。Xgboost:一个可扩展的树提升系统。在第 22 届 ACM SIGKDD 国际会议上的知识发现与数据挖掘论文集(第 785–794 页)。
[6] Popov, S.、Morozov, S. 和 Babenko, A.(2019)。神经网络忽略决策集成用于表格数据的深度学习。在第八届国际学习表征会议。
神经图数据库
图神经网络数据库的最新进展
图数据管理的新里程碑
·
关注 发表在 Towards Data Science ·14 分钟阅读·2023 年 3 月 28 日
–
我们引入了神经图数据库的概念,作为图数据库发展的下一步。神经图数据库专为大规模不完整图设计,并利用图表示学习进行缺失边的即时推理。神经推理保持了较高的表达能力,支持类似于标准图查询语言的复杂逻辑查询。
图片由作者提供,辅助工具为 Stable Diffusion。
本文由 Hongyu Ren, Michael Cochez* 和* Zhaocheng Zhu 共同撰写,基于我们最新的论文 Neural Graph Reasoning: Complex Logical Query Answering Meets Graph Databases。你也可以关注 我, Hongyu, Michael* 和* Zhaocheng 在 Twitter 上的动态。查看我们的 项目网站 获取更多资料。
概述:
-
神经图数据库:什么和为什么?
-
NGDBs 的蓝图
-
神经图存储
-
神经查询引擎
-
查询引擎的神经图推理
-
NGDBs 的开放挑战
-
了解更多
神经图数据库:什么和为什么?
🍨香草图数据库几乎随处可见,这要归功于不断增长的生产图、灵活的图数据模型和富有表现力的查询语言。经典的符号图数据库在一个重要假设下运行得又快又酷:
完整性。查询引擎假设经典图数据库中的图是完整的。
在完整性假设下,我们可以构建索引,以多种读写优化格式存储图,并期望数据库返回有什么。
但这一假设在实际中往往不成立(我们会说,几乎总是不成立)。例如在一些突出的知识图谱(KGs)中:在 Freebase 中,93.8%的人没有出生地,78.5%没有国籍,约 68%的人没有任何职业,而在 Wikidata 中,约 50%的艺术家没有出生日期,只有0.4%的已知建筑有高度信息。这仅仅是由数百名爱好者公开编辑的最大 KG,100M 节点和 1B 语句并不是行业中最大的图,所以你可以想象其不完整性的程度。
显然,为了考虑不完整性,除了**“有什么?”我们还必须问“缺少什么?”**(或“可以有什么?”)。让我们来看一个例子:
(a) - 输入查询;(b) — 带有预测边(虚线)的不完整图;© — 通过图遍历返回一个答案(UofT)的 SPARQL 查询;(d) — 神经执行恢复缺失的边,并返回两个新答案(UdeM, NYU)。图片来源:作者。
在这里,给定一个不完整的图(缺失边 (Turing Award, win, Bengio)
和 (Deep Learning, field, LeCun)
)以及一个查询 “在深度学习领域的图灵奖得主在哪些大学工作?”(以逻辑形式或类似 SPARQL 的语言表达),符号图数据库只会返回一个通过图遍历得到的答案 UofT。我们将这种答案称为 简单 答案,或现有答案。考虑到缺失的边,我们可以恢复两个更多的答案 UdeM 和 NYU(困难 答案,或推断答案)。
如何推断缺失的边?
-
在经典数据库中,我们选择不多。基于 RDF 的数据库具有一些形式语义,可以由庞大的 OWL 本体支持,但根据图的大小和推理的复杂性,在 SPARQL 推理规则 中完成推理可能需要无限的时间。标记属性图(LPG)数据库完全没有内置的推断缺失边的手段。
-
得益于图机器学习的进展,我们通常可以在潜在(嵌入)空间中以线性时间执行链接预测!然后,我们可以将这种机制扩展到在嵌入空间中执行复杂的、类似数据库的查询。
神经图数据库结合了传统图数据库和现代图机器学习的优势。
即,数据库原则如(1)图作为一等公民,(2)高效存储,以及(3)统一查询接口,现在由图 ML 技术支持,如(1)几何表示,(2)对噪声输入的鲁棒性,(3)大规模预训练和微调,以弥合不完整性差距并实现神经图推理和推断。
一般来说,NGDB 的设计原则包括:
-
数据不完整性假设 — 潜在数据可能在节点、链接和图级别上缺少信息,我们希望推断并在查询回答中加以利用;
-
归纳性和可更新性 — 类似于传统数据库,允许更新和即时查询,构建图潜变量的表示学习算法必须具有归纳性,并以零样本(或少样本)方式对未见数据(新实体和关系)进行泛化,以防止昂贵的再训练(例如,浅层节点嵌入);
-
表达能力 — 潜在表示在数据中编码逻辑和语义关系的能力,类似于 FOL(或其片段),并在查询回答中加以利用。实际上,神经推理支持的逻辑操作符集应接近或等同于标准图数据库语言,如 SPARQL 或 Cypher;
-
超越知识图谱的多模态性——任何可以作为节点或记录存储在经典数据库中的图结构数据(例如图像、文本、分子图或带时间戳的序列),并且可以赋予向量表示的,都是神经图存储和神经查询引擎的有效来源。
解决 NGDB 原则的关键方法是:
-
向量表示作为原子元素——虽然传统的图数据库在许多索引中对邻接矩阵(或边列表)进行哈希处理,但不完全性假设意味着给定的边和图潜在(向量表示)都成为真理的来源,在神经图存储中。
-
在潜在空间中的神经查询执行——由于不完全性假设,基本操作如边遍历不能仅通过符号操作来执行。相反,神经查询引擎在邻接和图潜在空间上操作,以将可能缺失的数据纳入查询回答中;
实际上,通过在潜在空间中回答查询(且不牺牲遍历性能),我们可以完全抛弃符号数据库索引。
符号图数据库和神经图数据库之间的主要区别:传统的数据库通过边遍历回答“有什么?”的问题,而神经图数据库还会回答“缺少什么?”的问题。图像来源:作者。
NGDBs 的蓝图
在深入了解 NGDBs 之前,我们先来看一下神经数据库的一般情况——事实证明它们已经存在了一段时间,你可能已经注意到了。许多机器学习系统在数据被编码为模型参数时,已经在这一范式下运行,而查询相当于前向传播,可以为下游任务输出新的表示或预测。
神经数据库概述
神经数据库的现状如何?它的不同种类之间有什么区别,NGDBs(神经图数据库)有什么特别之处?
向量数据库、自然语言数据库和神经图数据库之间的区别。图像来源:作者
-
向量数据库属于存储导向的系统,这些系统通常基于近似最近邻库(ANN),如Faiss或ScaNN(或定制解决方案)来回答基于距离的查询,使用最大内积搜索(MIPS)、L1、L2 或其他距离。由于向量数据库与编码器无关(即,任何生成向量表示的编码器,如 ResNet 或 BERT,都可以作为来源),它们速度很快,但缺乏复杂的查询回答能力。
-
最近,随着大规模预训练模型的崛起——或称为基础模型——我们见证了它们在自然语言处理和计算机视觉任务中的巨大成功。我们认为,这些基础模型也是神经数据库的一个重要例子。在这些模型中,存储模块可能直接以模型参数的形式呈现,或者外包给一个外部索引,这在检索增强模型中常常使用,因为将所有世界知识编码到即便是数十亿个模型参数中也是困难的。查询模块通过填充编码器模型(BERT 或 T5 风格)中的空白或通过解码器模型(GPT 风格)中的提示,进行上下文学习,这些提示可以跨越多种模式,例如视觉应用的可学习标记或甚至调用外部工具。
-
Thorne et al介绍的**自然语言数据库 (NLDB)**将原子元素建模为通过预训练语言模型(LM)编码为向量的文本事实。对 NLDB 的查询以自然语言表达的形式发送,这些查询被编码为向量,查询处理采用检索器-阅读器方法。
神经图数据库并不是一个新名词——许多图机器学习方法尝试将图嵌入与数据库索引结合起来,或许RDF2Vec和LPG2Vec是一些最显著的例子,展示了如何将嵌入插件到现有图数据库中,并在符号索引之上运行。
相比之下,我们认为 NGDB 可以在潜在空间中无需符号索引直接工作。如下面所示,存在能够模拟嵌入空间中精确边遍历行为的机器学习算法,以检索“那里有什么”,并进行神经推理以回答“缺少什么”。
神经图数据库:架构
神经图数据库的概念图。输入查询由神经查询引擎处理,其中规划器导出查询的计算图,执行器在潜在空间中执行查询。神经图存储使用图存储和特征存储在嵌入存储中获取潜在表示。执行器与嵌入存储通信,以检索和返回结果。图像来源于作者
在更高层次上,NGDB 包含两个主要组件:神经图存储和神经查询引擎。查询回答流程从某些应用程序或下游任务发送的已结构化格式的查询开始(例如,通过语义解析将初始自然语言查询转换为结构化格式)。
查询首先到达神经查询引擎,特别是到查询规划器模块。查询规划器的任务是根据查询复杂性、预测任务和底层数据存储(如可能的图划分)生成一个高效的原子操作计算图。
生成的计划随后被送往查询执行器,该执行器将查询编码到潜在空间中,执行对底层图及其潜在表示的原子操作,并将原子操作的结果聚合成最终答案集。执行是通过与神经图存储通信的检索模块完成的。
存储层包括
1️⃣ 图存储 用于以空间和时间高效的方式保存多关系邻接矩阵(例如,以各种稀疏格式如 COO 和 CSR)。
2️⃣ 特征存储 用于保存与底层图相关的节点和边级多模态特征。
3️⃣ 嵌入存储 利用编码器模块生成基于底层邻接和相关特征的潜在空间中的图表示。
检索模块查询编码后的图表示,以构建潜在答案的分布。
神经图存储
在传统的图数据库(右侧),查询被优化为一个计划(通常是一个连接操作符的树),并执行于数据库索引的存储中。在神经图数据库(左侧)中,我们将查询(或其步骤)编码到一个潜在空间中,并在底层图的潜在空间中执行。图像由作者提供。
在传统的图数据库中,存储设计通常取决于图建模范式。
两种最流行的范式是资源描述框架(RDF)图和标记属性图(LPG)。然而,我们认为新的 RDF-star(及其伴随的 SPARQL-star)将统一这两种范式,将 RDF 图的逻辑表达性与 LPG 的属性特性融合起来。许多现有的知识图谱已经遵循了类似 RDF-star 的范式,如 超关系知识图谱 和 Wikidata Statement Model。
如果我们展望未来几年的骨干图建模范式,我们会选择 RDF-star。
在神经图存储中,输入图及其向量表示都是事实来源。为了在潜在空间中回答查询,我们需要:
-
查询编码器
-
图编码器
-
检索机制用于将查询表示与图表示进行匹配
图编码(嵌入)过程可以视为一个压缩步骤,但保留了实体/关系的语义和结构相似性。嵌入空间中实体/关系之间的距离应该与语义/结构相似性正相关。编码器的架构有很多选择——我们建议坚持使用归纳型的,以遵循 NGDB 设计原则。在我们最近的NeurIPS 2022 工作中,我们展示了两个这样的归纳模型。
查询编码通常与自然图编码相匹配,使得它们处于同一空间。一旦我们有了潜在表示,检索模块就会启动以提取相关答案。
检索过程可以被视为在嵌入空间中对输入向量的最近邻搜索,并且具有 3 个直接好处:
-
每个检索项的置信度评分——多亏了嵌入空间中预定义的距离函数。
-
潜在空间和距离函数的不同定义——针对不同的图,例如,树状图在双曲空间中更易于处理。
-
效率和可扩展性——检索可以扩展到包含数十亿节点和边的极大图。
神经查询引擎
NGDBs(左)和传统图数据库(右)的查询规划。NGDB 规划(假设图不完整)可以逐步自回归执行(1)或完全生成一个步骤(2)。传统数据库规划是基于成本的,并且依赖于元数据(假设图完整并从中提取),例如中间答案的数量来构建连接操作符的树。图片由作者提供
在传统数据库中,典型的查询引擎执行三个主要操作。(1) 查询解析以验证语法正确性(通常会进行更深层次的语义分析);(2) 查询规划和优化以得出一个有效的查询计划(通常是关系操作符的树),以最小化计算成本;(3) 查询执行根据查询计划扫描存储并处理中间结果。
将这些操作扩展到 NGDBs 是相当简单的。
1️⃣ 查询解析可以通过语义解析转化为结构化查询格式。我们故意将 NGDBs 的查询语言讨论留待未来的工作和热烈的公众讨论😉
2️⃣ 查询规划器得出原子操作(投影和逻辑操作符)的有效查询计划,最大化完整性(必须返回所有现有边上的答案)和推断(即时预测缺失边)同时考虑查询复杂性和底层图。
3️⃣ 一旦查询计划完成,查询执行器将查询(或其部分)编码到潜在空间,与图存储及其检索模块进行通信,并将中间结果聚合到最终答案集中。查询执行存在两种常见机制:
-
原子,类似于传统数据库,当查询计划按顺序执行,通过编码原子模式、检索其答案和执行逻辑操作作为中间步骤;
-
全局,当整个查询图在一个步骤中被编码并在潜在空间中执行。
神经查询执行的主要挑战是将查询表达能力与 SPARQL 或 Cypher 等符号语言匹配——迄今为止,神经方法可以执行接近一阶逻辑表达能力的查询,但在符号语言方面还差一半。
神经图推理的分类学用于查询引擎
自 2018 年以来,关于复杂逻辑查询回答的神经方法(即查询嵌入)的文献不断增加,特别是 Hamilton 等人 在图查询嵌入(GQE)方面的开创性 NeurIPS 工作。GQE 能够回答带有交集的联接查询,并实时预测缺失的链接。
GQE 可以被视为对 NGDBs 的神经查询引擎的第一次尝试。
GQE 开创了图机器学习的整个子领域,随后出现了一些著名的例子,如 Query2Box (ICLR 2020) 和 Continuous Query Decomposition (ICLR 2021)。我们进行了一项重大工作,将所有这些(约 50 项)工作按 3 个主要方向进行了分类:
⚛️ 图——我们回答查询的基础结构是什么;
🛠️ 建模——我们如何回答查询以及采用了哪些归纳偏差;
🗣️ 查询——我们回答什么,查询结构是什么,以及预期的答案是什么。
复杂逻辑查询回答的神经方法分类。有关更多详细信息,请参见。图像由作者提供
⚛️ 说到图,我们进一步将其细分为模态(经典的三元组图、超关系图、超图等)、推理领域(离散实体或包括连续输出)和语义(神经编码器如何捕捉更高阶关系,如 OWL 本体)。
🛠️ 在建模中,我们遵循编码器-处理器-解码器范式,对现有模型的归纳偏差进行分类,例如,具有神经或神经符号处理器的传递性或归纳编码器。
🗣️ 在 查询 中,我们的目标是将神经方法能够回答的查询集与符号图查询语言的查询集进行映射。我们讨论查询操作符(超越标准的与/或/非),查询模式(从链状查询到 DAG 和循环模式),以及投影变量(你喜欢的关系代数)。
NGDB 的开放挑战
分析分类法时,我们发现目前没有银弹,例如,大多数处理器只能在离散模式下处理基于树的查询。但这也意味着未来有很大的工作空间——可能包括你的贡献!
更具体地说,以下是未来几年 NGDB 的主要挑战。
沿着 图 分支:
-
模态:支持更多图的模态:从经典的仅三元组图到超关系图、超图以及结合图、文本、图像等的多模态源。
-
推理领域:支持对时间和连续(文本和数值)数据进行逻辑推理和神经查询回答——字面量构成了图的大部分以及对字面量的相关查询。
-
背景语义:支持复杂公理和形式语义,这些语义编码了(潜在的)实体类及其层次结构之间的高阶关系,例如,支持对描述逻辑和 OWL 片段进行神经推理。
在 建模 分支:
-
编码器:支持在推理时处理未见过的关系——这是(1)可更新性的关键,无需重新训练即可更新神经数据库;(2)启用预训练-微调策略,将查询回答推广到具有自定义关系模式的自定义图。
-
处理器:表达性处理器网络能够有效且高效地执行类似于 SPARQL 和 Cypher 操作符的复杂查询操作符。提高神经处理器的样本效率对于训练时间与质量权衡至关重要——在保持高预测质量的同时减少训练时间。
-
解码器:迄今为止,所有神经查询回答解码器仅在离散节点上操作。扩展答案范围到连续输出对于回答现实世界的查询至关重要。
-
复杂性:由于处理器网络的主要计算瓶颈是嵌入空间的维度(对于纯神经模型)和/或节点数(对于神经-符号模型),新型高效的神经逻辑操作符和检索方法是将 NGDB 扩展到数十亿节点和万亿边的关键。
在 查询 中:
-
操作符:使更复杂的查询操作符具备与声明式图查询语言相匹配的表达能力,例如,支持克林星号和加号、属性路径、过滤器。
-
模式:回答比树状查询更复杂的模式,包括 DAG 和循环图。
-
投影变量:允许投影超出最终叶节点实体,即允许返回中间变量、关系以及组织在元组(绑定)中的多个变量。
-
表达能力:回答超出简单 EPFO 和 EFO 片段的查询,并追求数据库语言的表达能力。
最后,在数据集和评估方面:
-
需要更大且多样化的基准,涵盖更多图模式、更具表现力的查询语义、更多查询操作符和查询模式。
-
由于现有的评估协议似乎有限(仅关注推断硬答案),需要一个更有原则的评估框架和指标,涵盖查询回答工作流的各个方面。
关于神经图存储和 NGDB 的一般情况,我们识别出以下挑战:
-
需要一个可扩展检索机制来将神经推理扩展到数十亿节点的图。检索与查询处理器及其建模先验紧密相关。现有的可扩展 ANN 库只能处理基本的 L1、L2 和余弦距离,这限制了神经查询引擎中可能的处理器空间。
-
目前,所有复杂查询数据集提供了一个硬编码的查询执行计划,可能不是最优的。需要一个神经查询规划器,能够将输入查询转换为最优执行序列,考虑预测任务、查询复杂性、神经处理器类型和存储层配置。
由于编码器的归纳性和可更新性而无需重新训练,运行推理时在比训练图更大的图上存在需要缓解持续学习、灾难性遗忘和规模泛化的问题。
了解更多
NGDB 仍然是一个新兴概念,面临许多未来研究的挑战。如果你想了解更多关于 NGDB 的内容,可以查看
我们还将组织研讨会,请关注最新动态!
神经网络 — 初学者指南 (1.1)
建立关于神经网络的直觉
·
关注 发表在 Towards Data Science ·10 min read·2023 年 3 月 20 日
–
照片由 La-Rel Easter 拍摄,来源于 Unsplash
深度学习在过去十年中经历了巨大的增长。它在图像分类、语音识别、文本转语音、自驾车等方面都有应用,深度学习解决的问题列表非常重要。因此,理解神经网络的基本结构和工作原理对于欣赏这些进展是必要的。
让我们深入探讨学习。
1. 神经网络的构建模块
神经网络是一个计算学习系统,通过使用底层的非线性映射函数来将输入变量映射到输出变量。
它包含五个基本组件:
a. 节点和层
b. 激活函数
c. 损失函数
d. 优化器
我们将详细了解这些组件。
层:
简而言之,神经网络是一系列相互连接的层。神经网络中有三种层类型:输入层 — 接受输入数据,隐藏层 — 转换输入数据,输出层 — 在应用转换后为给定的输入生成预测。接近输入层的层称为下层,接近输出层的层称为上层。
每一层由多个神经元组成,也称为节点。给定层中的每个节点与下一层中的每个节点相连。节点接收来自上一层的加权输入总和,应用非线性激活函数,并生成一个输出,该输出随后成为下一层节点的输入。
考虑一个常见的分类问题,例如预测贷款申请者是否会违约。输入变量包括申请者年龄、就业类型、赡养人数、居住地、贷款价值比等。这些变量将组成输入层。
输入层中的节点数量对应于数据中的独立变量数量。隐藏层的数量以及这些层中的节点数是超参数,通常是问题复杂性和可用数据的函数。
在复杂问题中,层的数量和每层中的节点数量将更多,每个隐藏层将学习在上一层未学到的表示。这些神经网络被称为‘深度神经网络’。
对于回归问题,输出层中的节点数量为 1;对于多分类问题,输出层中的节点数量等于标签/类别数量;对于二分类问题,输出层中的节点数量为 1。
神经网络的工作原理可以归结为给定层中的单个节点。
神经网络中单个节点的工作原理(图片由作者提供)
如上所示,单个节点接受以下输入 — 偏置 b 和输入变量 x1 及 x2。它还接受另一个参数作为输入 — 每个独立变量的权重。权重表示输入变量的重要性。
节点将处理加权输入总和,如下所示:
z = w1x1 + w2x2 + bias(公式 1)
然后在给定层中的每个节点上应用激活函数以生成输出。应用激活函数后由节点生成的输出是 a。
f(z) = a(公式 2)
这是神经网络中单层单节点的工作原理。具有多个层和节点的网络也按照相同的原则运行。
2 层神经网络(作者提供的图片)
除了加权输入外,我们还可以看到在上述公式 1 中有一个叫做偏置 ‘b’ 的项。偏置在神经网络中有什么作用?
偏置 是一个帮助激活节点的变量。偏置是激活节点所需的阈值的负值。在给定层中的所有节点中使用一个单独的偏置值。
批次数据通过输入层传递,输入层将其发送到第一个隐藏层。第一个隐藏层中的神经元将基于激活函数的输出进行激活,激活函数接收输入的加权和与偏置并计算特定范围内的一个数字。
这引出了下一个问题 — 什么是激活函数,我们为什么需要它?
激活函数
用简单的术语来说:
激活函数用于将节点的输入转换为传递到下一个隐藏层节点的输出值。
用技术术语来说:
激活函数,也称为传递函数,定义了如何将输入的加权和与偏置转换为给定层中的节点的输出。它将输出值映射到特定范围,即 0 到 1 或 -1 到 +1,具体取决于所用的函数类型。
神经网络中使用的激活函数有两种类型 — 线性和非线性。
- 线性激活函数:
公式为 f(x) = b + Sigma( wi * xi),对所有输入变量 (i) 进行索引。
该函数的范围是:— 无穷大到 + 无穷大。
线性激活函数用于神经网络的外层,以解决回归问题。在输入层或隐藏层中使用它不是一个好主意,因为网络将无法捕捉底层数据中的复杂关系。
2. 非线性激活函数:
非线性激活函数默认是深度学习中最常用的激活函数。这些包括 Sigmoid 或 Logistic 函数、修正线性激活函数(ReLU)和双曲正切函数(Tanh)。
让我们更详细地了解每个。
- Sigmoid 激活函数:
也称为 Logistic 函数,它接受任何实值作为输入,并在 0 和 1 之间给出输出。
公式为 y = 1/(1+ e^-z),具有 S 形曲线。这里的 z = b + sigma(xi * wi),对 i 输入变量进行索引。
对于非常大的正数 z,e^-z 将为 0,函数的输出将为 1。对于非常大的负数 z,e^-z 将是一个大数,因此函数的输出将为 0。
2. 修正线性激活函数 (ReLU):
它是今天使用最广泛的激活函数。ReLU 具有对所有大于 0 的输入值是线性的,而对其他值则是非线性的属性。
它表示为 f(x) = max(0,x)
3. 双曲正切激活函数:
类似于逻辑函数,它接受任何实数作为输入,并输出范围在 -1 和 +1 之间的值。
它表示为:f(x) = (e^z — e^-z) / (ez+e-z)。其中 z = b + sigma(xi * wi),索引为 i 个输入变量。
Tanh 函数的形状也是 S 形的,但范围不同。
通常在所有层中使用一个激活函数,唯一的例外是输出层。输出层使用的激活函数取决于问题陈述是否要求我们预测一个连续值,即回归,或一个分类值,即二分类或多标签分类。
因此,神经元可以定义为一个包含两个部分的操作——线性组件和激活组件,即神经元 = 线性 + 激活。
上述所有函数及其变体都有一些限制,我将在下一篇文章中介绍。
那么神经网络是如何学习的?
所有参数的权重都以一些随机值进行初始化。加权和被传递到网络的第一个隐藏层。
第一个隐藏层将计算所有神经元的输出,并将其传递给下一个隐藏层中的神经元。请注意,每层的输入值都通过激活函数进行转换,然后发送到下一层。
这种流动会持续到达最后一层,然后计算最终的预测。这种从输入层到输出层的单向流动称为‘前向传播’或‘前向传递’。
我们的网络现在已经生成了最终输出。接下来发生什么?
损失函数
将预测值与实际值进行比较并计算误差。误差的大小由损失函数给出。
损失函数将估计预测值的分布与训练数据中实际目标变量的分布的接近程度。
最大似然估计(MLE)框架用于计算整个训练数据上的误差。它通过估计预测的分布与训练数据中目标变量的分布的匹配程度来完成这一点。
在 MLE 框架下,分类问题的损失函数是 交叉熵,回归问题的损失函数是 均方误差。
交叉熵 量度了两个概率分布之间的差异。在神经网络的背景下,它表示预测概率分布与训练数据集中目标变量分布之间的差异 对于给定的一组权重或参数。
对于二分类问题,使用的损失函数是二分类交叉熵;对于多分类问题,使用的损失函数是类别交叉熵。
例如,考虑一个与客户贷款违约相关的二分类问题。假设训练数据包含 5 个客户。
神经网络在第一次前向传播中将计算客户违约的概率。网络为所有 5 个客户生成的输出分别是[0.65, 0.25, 0.9, 0.33, 0.45]。
训练数据中观测值的实际值为[1, 1, 1, 1, 1]。
交叉熵损失定义如下:
图片由作者提供
使用这个方程,上述问题的交叉熵损失(CEL)计算如下:
图片由作者提供
在这里,二分类交叉熵计算了一个分数,这个分数总结了实际概率分布和预测概率分布之间的平均差异,以预测类别 1。给定目标变量的实际值和预测值的损失为0.404。我们如何解释这个值?它有一个相对的解释。最终模型的损失值将远低于 0.404。第五个也是最后一个构建块将帮助我们达到那个最优值。它通过寻找最优的权重和偏置值来最小化损失函数,从而实现这一点。
在多分类问题中,其中目标变量编码为 1 到 n-1 类,类别交叉熵将计算一个分数,这个分数总结了所有类别的实际概率分布和预测概率分布之间的平均差异。
类似地,对于回归问题,均方误差(MSE)是最常用的回归损失函数。MSE 计算为目标变量的预测值与实际值之间的平方差的平均值。由于它是误差的平方,输出总是为正。
MSE 有一些变体,如均方对数误差损失(MSLE)和均值绝对误差(MAE)。选择取决于多个因素,如异常值的存在、目标变量的分布等。
网络在第一次前向传播中生成的输出是由初始化为某些随机值的权重决定的。损失函数比较实际值和预测值并计算误差。下一步是通过改变权重来最小化误差。网络如何实现这一点?
这将我们引入神经网络的最后一个构建块,即优化器。
5. 优化器
如前面部分所讨论的,在神经网络中,学习发生在权重中。训练神经网络涉及到学习所有层中所有神经元的正确权重。这通过使用随机梯度下降算法结合反向传播算法来实现。
鉴于这是一个比上述内容更复杂的概念,我们将在下一篇文章中详细探讨这一点。这里涉及的所有构建块也值得在后续文章中做更详细的解释。
本文的关键要点是 最终的神经网络模型是整体架构的函数,即节点数量、层数等,以及参数(也称为权重)的最佳值。一旦我们解决了这两个组件,就可以自信地预测目标变量。
这里是我找到的一些在理解这个概念方面非常有帮助的链接。
-
youtu.be/PySo_6S4ZAg
— 这是斯坦福大学 CS230 神经网络课程,由 Andrew Ng 主讲。 -
amzn.eu/d/6U4c3GR
— 《用 Python 进行深度学习(第 2 版)》。一本很棒的书。概念用非常简单的语言解释。 -
machinelearningmastery.com/
— 这是一个涵盖深度学习和机器学习所有基础和中级问题的资源。
希望到现在你对神经网络有了一些理解,并且了解了各种构建块如何结合在一起解决深度学习问题。请告诉我你的想法。
神经网络作为决策树
图片由 Jens Lelie 提供,来源于 Unsplash
将神经网络的强大功能与决策树的可解释结构结合起来
·
关注 发布于 数据科学之路 · 9 分钟阅读 · 2023 年 4 月 3 日
–
人工智能的近期繁荣清楚地展示了深度神经网络在各种任务中的强大能力,尤其是在数据维度高且与目标变量之间存在复杂非线性关系的分类问题领域。然而,解释任何神经分类器的决策是一个非常困难的问题。虽然许多后置方法如 DeepLift [2] 和 Layer-Wise Relevance Propagation [3] 可以帮助解释单个决策,但解释全局决策机制(即模型通常寻找的内容)则更加困难。
因此,许多高风险领域的从业者更倾向于选择更具可解释性的模型,如基本的决策树,因为决策层级可以被利益相关者清晰地可视化和理解。然而,基本的决策树往往不能提供足够的准确性,通常会使用集成方法如 Bagging 或 Boosting 来提高模型的性能。不过,这又牺牲了一些可解释性,因为要理解一个单一的决策,从业者需要查看数百棵树。然而,这些方法仍然比深度网络更受欢迎,因为至少特征重要性(无论是局部还是全局)可以被轻松提取和展示。
因此,目前的问题是我们想要神经网络的区分能力,但又希望具备决策树的可解释性。那么,为什么不把网络结构化成一棵树呢?这就是 Fross 和 Hinton(2017)在他们的论文“将神经网络提炼成软决策树”[1]中采用的主要方法。在这篇文章中,我将深入探讨神经决策树背后的关键机制,并解释这种方法的一些优点以及在实际应用中可能需要考虑的一些因素。虽然我们主要讨论分类树,但详细的方法也可以应用于回归树,只需进行一些相对较小的调整。
方法论
软决策树与硬决策树
在深入了解如何将神经网络构建成软决策树之前,让我们首先定义什么是软决策树。
当人们想到决策树(例如 sklearn 中实现的决策树)时,他们想到的是每个决策都是确定性的硬决策树。
硬决策树的示例(图片由作者提供)
如果满足某个条件,我们将走向左分支,否则我们走向右分支。每个叶子节点都有一个类别,通过简单地遍历树并选择我们最终到达的类别来进行预测。我们允许树生长得越大,可以采取的路径就越多,从而实现最终决策。
软决策树有许多相似之处,但工作方式略有不同
软决策树的示例(图片由作者提供)
在硬决策树中,每个分支是确定性的,而软决策树定义了在满足条件的情况下进入某个分支的概率。因此,虽然硬决策树输出一个单一值,软决策树则输出所有可能类别的概率分布,其中类别的概率是我们通过到达叶子的概率的乘积。 例如,上面树的批准概率等于 P(b1|X)(1-P(b2|X)) + (1-P(b2|X))(P(b3|X))。 分类决策就是选择具有最高概率的类别。
这种结构有许多优点。首先,非确定性决策使用户了解给定分类中的不确定性。此外,从技术上讲,硬树只是软树的特殊变体,其中所有分支概率都等于 1。
这些树的一个缺点是解释性略有下降。从利益相关者的角度来看,“我们批准了一个贷款,因为个人年收入为 $100k,债务少于 $400k”比起:
如果收入是 $110k,我们有 0.7 的概率向右分支;如果债务低于 400k,我们有 0.8 的概率批准,这样结果就是 0.56 的概率加上左分支中发生的情况。
这并不意味着这些树不可解释(因为仍然可以确切看到模型关注的内容),只是需要模型开发者提供更多的帮助。
倾斜决策树
在了解神经决策树之前,第二个需要掌握的概念是“倾斜”决策树的概念。
传统的决策树被认为是“正交”树,因为它们的决策是相对于给定的轴进行的。简单来说,每次决策中只使用一个变量。另一方面,倾斜树在决策过程中使用多个变量,通常是线性组合的形式。
倾斜决策边界的示例(图源:Zhang et. al 2017 [4])
决策节点中的一些示例值可能是“收入 — 债务 > 0”。这可以导致更强的决策边界。一个缺点是,如果没有适当的正则化,这些边界可能会变得越来越复杂。
将它们结合起来
现在我们理解了软决策树和倾斜决策树,我们可以将它们结合起来理解神经公式。
第一个组成部分是决策节点。对于每个节点,我们需要基于输入值的一些概率。为了实现这一点,我们可以使用神经网络的基本工具:权重和激活。在每个决策节点中,我们首先对输入变量进行线性组合,然后对总和应用一个 sigmoid 函数,得到分支概率。
为了防止极软的决策(使决策树更像硬决策树),可以使用温和的 sigmoid(或在应用 sigmoid 之前对线性组合进行乘法运算)。
每个叶节点包含一个 N 维张量,其中 N 是类别的数量。这个张量表示样本属于某一类别的概率分布。
神经网络作为决策树(图像复制自 Frosst & Hinton 2017 [1])
与软决策树一样,这棵神经树的输出是类别的概率分布。输出分布等于分布的总和乘以到达该分布的路径概率。
训练树
神经树的一个好处是可以通过像梯度下降这样的连续优化算法进行训练,而不是像普通决策树需要构建的贪婪算法。我们需要做的就是定义损失函数:
神经树的损失函数(图像来自 Frosst & Hinton 2017 [1])
这棵树的损失函数类似于交叉熵损失。在这个方程中,P^l(x) 是在给定数据点 x 的情况下到达叶节点 l 的概率,T_k 是目标类别 k 的概率(1 或 0),而 Q_k^l 是叶节点 l 中与类别 k 对应的张量(概率分布)元素。
关于这一结构的一个重要说明是树形结构是固定的。与使用贪婪算法逐个拆分节点并生长树的普通决策树不同,使用这种软决策树时,我们首先设置树的大小,然后使用梯度下降同时更新所有参数。这种方法的一个好处是更容易在不损失太多判别能力的情况下约束树的大小。
在训练过程中可能遇到的一个潜在陷阱是模型可能过度偏向单个分支,而未能利用树的全部力量。为了避免陷入不良解决方案,建议在损失函数中引入惩罚,鼓励树同时利用左右子树。
惩罚是期望平均分布(左右树各 50/50)与实际平均分布(定义为 alpha)之间的交叉熵。
节点 i 的 alpha 定义(图像来自 Frosst & Hinton 2017 [1])
在这个方程中,P^i(x) 是从根节点到节点 i 的路径概率。我们然后对所有内部节点的惩罚进行求和。
惩罚的定义(图像来自 Frosst & Hinton 2017 [1])
在这个方程中,lambda 是一个超参数,决定了惩罚的强度。然而,这可能会导致一些问题,因为随着树的下降,数据分裂成 50/50 的机会减少,因此建议使用根据树的深度变化的自适应 lambda。这将修改惩罚为:
修改后的惩罚函数(图片由作者提供,摘自 Frosst & Hinton 2017 [1])
当我们深入树中时,建议根据 2^-d 比例衰减 lambda。
结果可视化
虽然将神经网络重新表述为树形结构很有趣,但追求这种方法的主要原因是提供更多的模型可解释性。
首先看看经典问题的解释——MNIST 中的数字分类:
MNIST 示例(图片来自 Frosst & Hinton 2017 [1])
在上图中,内部节点的图像是学习到的过滤器,叶节点的图像是学习到的类别概率分布的可视化。对于每个叶节点和节点,最可能的分类用蓝色标注。
从这棵树来看,我们可以看到一些有趣的特征。例如,如果我们查看最右边的内部节点,潜在的分类是 3 和 8。实际上,我们可以在决策节点可视化中看到 3 的轮廓。白色区域似乎表明模型寻找能够闭合 3 的内部循环的线条,从而将其转换为 8。我们还可以看到模型在左侧倒数第三个节点中寻找 0 的形状。
另一个有趣的例子是预测 Connect4 游戏中的胜利
可视化神经决策树前 2 层预测 Connect4 游戏胜者(图片来自 Frosst & Hinton 2017)
这个例子中的学习到的过滤器表明,游戏可以分为两种不同类型:一种是玩家主要集中在棋盘边缘的游戏,另一种是玩家在棋盘中心放置棋子的游戏。
结论
将神经网络构建为软决策树使我们能够利用神经网络的强大能力,同时仍保留一些可解释性。正如 MNIST 数据集上的结果所示,学习到的过滤器可以提供局部和全局的解释能力,这对于高风险任务也是一种受欢迎且有帮助的特性。此外,训练方法(一次优化和更新整个树)使我们能够在保持树的大小固定的情况下获得更多的区分能力,这是我们在正常决策树中无法实现的。
尽管如此,神经树仍然不完美。树的软性特征意味着使用这些树的数据科学家需要在向非技术利益相关者展示之前“预处理”树,而普通决策树可以直接展示(因为它们相对自解释)。此外,虽然树的斜向特性有助于准确性,但在给定节点中变量过多会使解释变得更加困难。这意味着正则化不仅是推荐的,而且在一定程度上是必要的。此外,无论模型多么可解释,仍然存在对利益相关者可理解的解释性特征的需求。
然而,这些缺点并未削弱这些模型在推动解释性与性能前沿方面的潜力。我强烈建议大家在下一个数据科学任务中尝试这些模型。我也推荐大家阅读原始论文。
资源与参考文献
-
想了解更多关于 XAI 和时间序列预测的信息,请关注
参考文献
[1] N. Frosst, G. Hinton. 将神经网络蒸馏为软决策树 (2017). 2017 人工智能行动会议
[2] A. Shrikumar, P. Greenside, A. Jundjae. 通过传播激活差异学习重要特征 (2017). 国际机器学习会议 PMLR 2017。
[3] S.Bach, A. Binder, G. Montavon, F. Klauschen, K-R. Muller, W. Samek. 基于层级相关传播的非线性分类器像素级解释 (2015). PloS one, 10(7), e0130140
[4] L. Zhang, J. Varadarajan, P. N. Suganthan, N. Ahuja, P. Moulin. 使用斜向随机森林的鲁棒视觉跟踪 (2017). 2017 计算机视觉与模式识别会议。
具有多个数据源的神经网络
如何使用 Tensorflow 设计一个具有多个数据源输入的神经网络
·
关注 发表在 数据科学前沿 ·5 分钟阅读·2023 年 1 月 6 日
–
具有多个数据源的卷积神经网络。图片来源:作者。
在许多使用场景中,神经网络需要并行训练多个数据源。这些包括医学应用场景,其中可能会有一张或多张图像与结构化的患者数据一起使用,或者多图像应用场景,其中不同对象的图像贡献到单一输出。例如,使用个人房屋和汽车的独立照片来预测他们的收入。
集体数据不能一体处理,因为每个数据源都有其独特的属性和形状。为了成功设计一个网络,每个输入流需要单独处理和训练。
使用具有多个独立输入的 CNN 已被证明比单一图像输入提高了准确性。在一项研究中[1],处理了三个不同的图像输入分支,并将它们合并,结果比单独处理图像提高了 8% 的准确性。
此外,还显示出在 CNN 设计中,晚期合并网络分支也能产生更好的准确性[2]。这种晚期合并意味着在实际操作中,输入分支应该在合并到最终模型并生成预测之前,几乎完全作为独立网络进行处理。
我们将详细讨论如何设计这种类型的卷积神经网络(CNN),通过一个理论上的患者数据示例,其中包含一个数据 CSV 文件和一张图像。我们将只考虑一个图像输入,但这种方法也可以用于每个患者的多个图像。
首先,必须加载源文件并将其处理为 Pandas 数据框。下方示例中,加载了一个简单的数据集,包括患者 ID、患者年龄和一个标志,表示是否已诊断出癌症。
需要注意数据框的形状,因为这将影响后续网络的设计。
接下来,我们必须为每个患者加载一张图像。这是通过对患者数据框进行迭代来完成的,以保持记录的顺序。
图像数据也被转换为 numpy 数组,以保持与从文件中加载的患者数据的一致性。
加载数据的形状。图片由作者提供
我们现在需要考虑已加载数据的形状。对于图像,如果每个图像的尺寸为 512x512 像素,并且我们有n张图像,那么数据的形状为(n, 512, 512)。对于具有多个通道的图像,可能会添加进一步的维度,但我们将保持这个示例简单。
对于结构化的患者数据,我们的文件中有三列和n条记录。这将导致数据形状为(n, 3)。患者 ID 列在训练中是不需要的,因此这列可能会被删除,从而得到最终训练数据的形状为(n, 2)。
数据的进一步预处理,如缩放,不在本讨论范围之内。对于本示例,我们将直接使用原始数据。
然而,在设计神经网络之前,还需要一步。这一步是将数据分割成训练集和测试集。这需要在一个步骤中完成,以保持数据集的顺序和分割。下面的示例演示了如何使用 scikit-learn 来完成这一步骤:
一旦分割完成,我们就可以从两个数据集中提取目标特征作为我们的‘y’数据集。检查两个结果训练数据集的形状应该产生类似于以下的输出:
(1200, 512, 512)
(1200, 3)
记录数为 1,200。两个数据集需要具有相同数量的记录,以便它们可以在神经网络的输出中合并。
现在我们可以使用 Keras 函数式 API 开始设计神经网络。首先,我们将从结构化的患者数据开始:
网络的设计可以有所不同,但最好包括一个归一化层。归一化层仅适应于训练数据。
重要的是,输入层的形状设置为数据中的列数(在此示例中为 3)。
输出层的形状也很关键,因为这是将与图像处理分支合并的形状。这由最后的 Dense 层决定。在此示例中,输出层的形状将是:
(无, 64)
其中‘None’是 Keras 对记录数量的解释,未指定。
数据分支现在已经完成,我们可以查看图像处理分支。虽然可以设计自己的网络,但在实践中使用预设计的模型更为方便。在此示例中,我们将使用 Keras Applications 中的 Resnet-50。
如上所示,输入形状是每个图像的大小,加上一个额外的维度用于图像通道(在此案例中为 1)。
在 Resnet 模型的末尾添加一个全连接的 Dense 层,以使输出与数据分支的形状相同:
(无, 64)
因为我们在整个过程中都注意了数据的形状,所以现在能够合并两个分支的输出:
两个分支被连接在一起,最后添加一个全连接的 Dense 层以将模型减少到最终的预测。这里使用的激活函数可以有所不同。在此示例中,使用线性激活函数来输出实际的类别概率。
CNN 的最终设计总结如下:
最终 CNN 设计。图片由作者提供
如上所示,如果仔细考虑数据的形状,可以成功地合并多个分支。然后,可以使用这个合并后的模型从多个数据源生成单一的预测。
感谢阅读。
参考文献:
[1] Yu Sun, Lin Zhu, Guan Wang, Fang Zhao, “Multi-Input Convolutional Neural Network for Flower Grading”, Journal of Electrical and Computer Engineering, vol. 2017, Article ID 9240407, 8 pages, 2017. doi.org/10.1155/2017/9240407
[2] Seeland M, Mäder P (2021) Multi-view classification with convolutional neural networks. PLoS ONE 16(1): e0245230. doi.org/10.1371/journal.pone.0245230
神经原型树
原文:
towardsdatascience.com/neural-prototype-trees-f7bac36437a9
通过模仿人类推理来实现可解释的图像分类。
·发布于 Towards Data Science ·6 分钟阅读·2023 年 6 月 2 日
–
机器学习和人工智能现在被应用于大量领域,但随着使用的增加,模型面临着更多的风险和伦理测试。
让我们通过最近的新闻来激发思考,关于一辆特斯拉在自动驾驶模式下撞上树木的事件。根据当局的说法,司机表示车辆在她启用自动驾驶模式后向右偏移,驶离了道路,并撞上了一棵树。目前,这些说法正在调查中,但想象一下,识别汽车突然做出怪异决策的原因是多么困难。它是否发生了误分类?它看到了什么让它困惑的东西?对于传统的黑箱模型来说,调查模型内部非常困难且昂贵。
那么,替代方案是什么?是否有一种可解释的图像分类方法?是的,通过原型学习[2]和神经原型树[1]!借助这些架构,模型采用了一种非常直观的预测方法:识别看起来熟悉的部分。那只鸟有长长的喙吗?它有红色的喉咙吗?那一定是一只蜂鸟!
在本文中,我旨在提供有关这些模型如何工作的资讯,并讨论使用这些模型的一些优缺点。我将频繁引用的两篇主要论文如下,我强烈建议有兴趣的读者阅读这些论文:
当我们面临具有挑战性的图像分类任务时,我们通常通过解构图像来解释我们的推理……
arxiv.org openaccess.thecvf.com [## CVPR 2021 开放获取库
神经原型树用于可解释的细粒度图像识别 Meike Nauta、Ron van Bree、Christin Seifert 等
这是我撰写的关于神经决策树变体的第二篇文章。如果你还没有读过第一篇文章,我强烈建议你浏览一下,因为这里阐述的许多概念都是基于标准神经树构建的。
towardsdatascience.com ## 神经网络作为决策树
利用神经网络的强大能力和决策树的可解释结构
[towardsdatascience.com
什么是原型?
图像识别中的原型思想首次由 Chen & Li 等(2019)在他们的论文“这看起来像那样:用于可解释图像识别的深度学习” [2] 中提出。这是一种潜在表示,表示与给定类别相关的某些训练图像补丁。正如名字所示,该模型通过解剖输入图像并找到提供证据的原型部分来工作,以表明图像属于某一类别。网络简单地计算欧几里得距离,并将其反转以创建相似度分数。这些分数随后通过一个全连接层生成最终的类别概率:
图 1: ProtoPNet 架构(图源自 Chen & Li 等,2019 [2])
一旦模型训练完成,用户可以简单地将学到的原型与训练集中的补丁进行匹配,从而创建对任何预测的非常可解释的解释:
图 2: 鸟类预测的推理过程(图源自 Chen & Li 等,2019 [2])
神经原型树
使用原型包模型(例如来自原始原型论文 [2] 的 ProtoPNet)的一个问题是,原型匹配是同时进行的,但人类图像识别依赖于一系列步骤。如果某物没有爪子或耳朵但有尾巴,那它可能不是猫,因此网络不应该给它打上猫的标签。这就是神经原型树 [2] 发挥作用的地方。与原型包不同,Nauta 等人 [1] 选择将他们的模型设计为神经决策树。这种决策树提供了这种顺序决策,并提供了全局可解释性,而不仅仅是局部可解释性。
软决策树的回顾
神经原型树是软决策树,而不是硬决策树。虽然硬决策树强制执行确定性分支(你要么向左走,要么向右走),但软决策树使用的是概率分支(你有 p 的概率向左走和 1-p 的概率向右走)。此外,虽然硬决策树输出一个单一的值,但软决策树输出的是所有可能类别的概率分布,其中类别的概率是到达叶子节点所经过概率的乘积,分类决策则是概率最高的类别。
原型树中的决策制定
像标准神经树一样,每个叶子节点包含一个关于类别的概率分布。分支决策是通过计算图像补丁到节点中给定原型的距离来做出的。每张图像的得分是图像中补丁与原型之间找到的最小距离,然后转换为概率。简单来说:如果在图像中找到原型,就向右走;如果找不到该原型,就向左走!
图 3:ProtoTree 中的预测机制(图源自 Nauta 等人 2021 [1])
这个机制显然允许我们使用与可视化普通原型模型相同的机制来可视化出极其可解释的模型。
图 4:原型树的全局解释(图源自 Nauta 等人 2021 [1])
学习叶子分布
在普通决策树中,叶子的标签是通过查看最终到达该叶子的样本来学习的,但在软树中,叶子中的分布是全局学习问题的一部分。然而,作者注意到将叶子的学习与原型的学习结合在一起会导致分类结果不佳。为了解决这个问题,他们利用了一种无导数策略来获取叶子概率的更新方案:
图 5:更新方案。c_l^t 是第 t 轮中叶子 l 的叶子概率。y 是真实值,yhat 是预测值。pi 是到该叶子的路径概率。(图自 Nauta 等,2021 [1])
该更新方案与小批量梯度相结合,以学习原型和卷积参数,从而创建一个高效的学习过程。
剪枝
为了提高可解释性,作者还引入了剪枝机制。如果一个叶节点包含有效的均匀分布,它的区分能力不强,因此最好对其进行剪枝,因为较小的树更易于阅读和解释。从数学上讲,作者定义了一个阈值 t 并移除所有最高类别概率小于 t(max(c_l) ≤ t*)的叶子。如果一个子树中的所有叶子都被移除,则可以移除该子树及其相关原型,从而使树变得更加紧凑。通常,t = 1/K + epsilon 其中 K 是类别数,epsilon 是一个非常小的数值,表示容差。
图 5:剪枝可视化(图自 Nauta 等,2021 [1])
性能
图 6:平均准确率和标准差。ProtoTree ens. 是 3 棵或 5 棵原型树的集合。 (图自 Nauta 等,2021 [1])
作者使用 CARS 和 CUBS 数据集对他们的方法进行了基准测试,与其他可解释的图像识别方法(如基于注意力的可解释性方法)进行比较。他们发现,通过使用相对较小的树木集合(9 棵和 11 棵),他们能够接近 SOTA 准确率。
结论
可解释的深度学习图像分类器相对于黑箱模型提供了许多优势。它们可以帮助建立信任,改善调试,并解释预测。此外,它们还可以用于探索数据,了解不同特征之间的关系。
总的来说,神经原型树是一种有前途的新方法,用于以可信赖的方式进行图像识别。如果医生能够检查模型所观察到的图像的特征,他更可能相信癌症检测模型。这些原型树甚至可以通过添加注意力等措施进一步提高准确性!
资源和参考文献
-
神经原型树的 Github:
github.com/M-Nauta/ProtoTree
-
如果你对可解释的机器学习和人工智能感兴趣,可以考虑关注我:
medium.com/@upadhyan
。
参考文献
[1] M. Nauta, R.v. Bree, C. Seifert. 神经原型树用于可解释的细粒度图像识别 (2021). IEEE/CVF 计算机视觉与模式识别会议(CVPR),2021
[2] C. Chen, O. Li, C. Tao. A.J. Barnett, J. Su, C. Rudin. 这看起来像那样:可解释图像识别的深度学习 (2019). 第 33 届神经信息处理系统会议。
新的 ChatGPT 提示工程技术:程序模拟
·
关注 发表在 Towards Data Science · 9 分钟阅读 · 2023 年 9 月 3 日
–
来源:作者提供的图像,使用 MidJourney 生成
提示工程的世界在各个层面上都非常迷人,并且有很多巧妙的方法可以引导像 ChatGPT 这样的代理生成特定类型的响应。诸如链式思维(CoT)、基于指令、N-shot、Few-shot 甚至像奉承/角色分配这样的技巧都是灵感的来源,激发了许多满足各种需求的提示库。
在这篇文章中,我将深入探讨一种技术,根据我的研究,这种技术可能尚未被充分探索。虽然我会暂时将其标记为“新”,但我会避免称其为“创新”。鉴于提示工程中的创新速度之快以及新方法的易于开发,这种技术可能已经以某种形式存在。
该技术的本质在于使 ChatGPT 以模拟程序的方式运行。我们知道,程序由一系列指令组成,这些指令通常打包成函数以执行特定任务。从某种程度上说,这种技术是基于指令和基于角色的提示技术的结合。但与这些方法不同的是,它寻求利用一个可重复且静态的指令框架,使得一个函数的输出可以影响另一个函数,并且整个交互保持在程序的范围内。这种模式应该与像 ChatGPT 这样的代理中的提示-完成机制很好地契合。
来源:作者提供的图像
为了说明这项技术,让我们在 ChatGPT4 中指定一个迷你应用程序的参数,该应用程序旨在作为互动创新工作坊。我们的迷你应用程序将包含以下功能和特点:
-
处理新想法
-
扩展想法
-
总结想法
-
检索想法
-
继续处理先前的想法
-
Token/“记忆”使用统计
需要明确的是,我们不会要求 ChatGPT 用任何特定编程语言编写迷你应用程序,我们将在我们的程序参数中反映这一点。
根据这个程序大纲,让我们开始编写启动提示,以在 ChatGPT 中实例化我们的互动创新工作坊迷你应用程序。
程序模拟启动提示
Innovator’s Interactive Workshop Program
I want you to simulate an Innovator’s Interactive Workshop application whose core features are defined as follows:
1\. Work on New Idea: Prompt user to work on new idea. At any point when a user is ready to work through a new idea the program will suggest that a date or some time reference be provided. Here is additional detail on the options:
a. Start from Scratch: Asks the user for the idea they would like to work on.
b. Get Inspired: The program assists user interactively to come up with an idea to work on. The program will ask if the user has a general sense of an area to focus on or whether the program should present options. At all times the user is given the option to go directly to working on an idea.
2\. Expand on Idea: Program interactively helps user expand on an idea.
3\. Summarize Idea: Program proposes a summary of the idea regardless of whether or not it has been expanded upon and proposes a title. The user may choose to rewrite or edit the summary. Once the user is satisfied with the summary, the program will "save" the idea summary.
4\. Retrieve Ideas: Program retrieves the titles of the idea summaries that were generated during the session. User is given the option to show a summary of one of the ideas or Continue Working on a Previous Idea.
5\. Continue Working on Previous Idea: Program retrieves the titles of the idea summaries that were generated during the session. User is asked to choose an idea to continue working on.
6\. Token/Memory Usage: Program displays the current token count and its percentage relative to the token limit of 32,000 tokens.
Other program parameters and considerations:
1\. All output should be presented in the form of text and embedded windows with code or markdown should not be used.
2\. The user flow and user experience should emulate that of a real program but nevertheless be conversational just like ChatGPT is.
3\. The Program should use emojis in helping convey context around the output. But this should be employed sparingly and without getting too carried away. The menu should however always have emojis and they should remain consistent throughout the conversation.
Once this prompt is received, the program will start with Main Menu and a short inspirational welcome message the program devises. Functions are selected by typing the number corresponding to the function or text that approximates to the function in question. "Help" or "Menu" can be typed at any time to return to this menu.
如果你想以更互动的方式跟随并自己测试,可以随意将提示加载到 ChatGPT4 中。
这是 ChatGPT 对提示的完成结果。
到目前为止,一切顺利。我们已经启动了我们的“迷你应用程序”,收到了振奋人心的欢迎消息,并且展示了一个与我们的程序参数一致的功能菜单。让我们通过提交“1”来测试我们的迷你应用程序的功能,以启动“处理新想法”功能。
对话继续很好地遵循我们设定的“程序”结构,适当地提供了符合参数的完成。让我们继续从零开始构建一个想法,并让程序与我们合作,开发一种用于生长建筑物而非建造它们的技术。
有趣的是,我们注意到“程序”在没有明确指示的情况下自动调用“Expand on Idea”功能。鉴于程序的目标,这种行为并不不当,可能受到我们最初上下文设置的影响,这些设置引导聊天代理像程序一样运行。让我们继续深入探讨一下增长建筑所需的技术。
现在让我们检查一下用于增长建筑的材料。
我继续沿着这些思路前进,现在,让我们看看是否可以返回到菜单。
菜单仍然完整。让我们尝试让程序执行 Summarize Idea 功能。
我对这个标题和摘要暂时满意,所以让我们“保存”它。
很快,我们将测试检索我们“保存”的想法,以检查我们在实现数据持久性方面的努力是否成功。另一方面,调整我们的“迷你应用”以省略保存后的重复摘要可能会有所帮助。
角色启动作为程序的结果是在输出中包含主菜单——这种行为在程序的背景下是合理的,即使它在我们的程序定义中没有被明确配置。
接下来,让我们测试我们的令牌计数功能。
为了核实准确性,我转向 OpenAI 的分词器工具。
令牌计数不准确,证据在于显著的差异——我们的程序报告大约有 1,200 个令牌,而分词器工具显示为 2,730。鉴于这种不匹配,明智的做法是从程序中移除此功能。我不会深入讨论为何这种任务通常对语言模型来说是个问题,以及功能损失相对较小。最终,我预计这样的功能将会原生集成到 ChatGPT 中,特别是考虑到令牌计数信息在后台不断传递。
接下来,让我们深入研究“Get Inspired”功能以生成新想法。为了简洁起见,我将进一步展示对话。正如你所见,我选择深入探讨一个我们的程序建议的废物转化为能源的无人机概念,概括了这个想法,并让我们的程序“保存”了它。
一切看起来都很好,系统甚至擅自给我们的想法命名为“SolarSky”。为了更有效地实现这一点,我们可能会在程序定义中为此任务加入一个独立的函数,或者在“工作在新想法”或“扩展新想法”函数中提供更具体的指示。同样,我们在完成中看到菜单,这从程序流的角度看是合乎逻辑的。
现在让我们看看是否可以“检索想法”。
这似乎符合我们的原始指示,仅提供了所请求的标题。它还提示我们继续工作一个想法,即使这并没有明确地编程到迷你应用程序中。接下来,让我们评估它是否保持了根菜单索引。为此,我将输入“5”,对应于“继续工作在之前的想法”功能,看看是否有效。
显然,索引在对话上下文中被维护,并且相应地调用了函数。这一观察值得注意,特别是在考虑到多个索引可能处于活动状态的情况下。这引发了有关“程序”在这种条件下如何表现的有趣问题。你可能错过了,但在我们互动的早期,程序在征求用户对想法扩展选择的输入时实际上采用了索引技术。
让我们继续工作在我们的建筑构想上。
再次看起来不错。“程序”的行为如预期,并且也跟踪了我们在想法扩展过程中暂停的确切点。
让我们在这里停止测试我们的提示,看看通过这种技术我们学到了什么。
结论与观察
坦白说,这次练习虽然在范围和功能上都有限,但超出了我的预期。我们本可以让 ChatGPT 用 Python 等语言编写这个迷你应用程序,然后利用代码解释器(现在称为高级数据分析)在持续的 Python 会话中运行它。然而,这种方法会引入一种刚性,使得启用我们迷你应用程序中固有的对话功能变得困难。更不用说,特别是在具有多个重叠功能的程序中,我们立即面临着代码无法正常工作的风险。
ChatGPT 的表现尤其令人印象深刻,因为它以高度逼真的方式模拟了程序行为。提示完成保持在程序定义的边界内,即使在函数行为没有明确规定的情况下,完成也在迷你应用程序目的的上下文中有逻辑性。
这种程序模拟技术可能与 ChatGPT 的“自定义指令”功能配合良好,尽管值得一提的是,这样做会将程序的行为应用于所有后续的互动中。
我的下一步包括对这种技术进行更深入的研究,以评估是否可以通过一个全面的测试框架来了解这种方法相对于其他提示工程技术的表现。这种练习也可能帮助确定这种技术最适合哪些特定任务(或任务类别)。敬请关注更多信息。
与此同时,希望你在互动中发现这种技术和提示有帮助。如果你想进一步讨论这种技术,请随时通过LinkedIn与我联系。
除非另有说明,本文中的所有图片均由作者提供。
新数据表明 2023 年是有史以来最热的夏天
原文:
towardsdatascience.com/new-data-demonstrates-that-2023-was-the-hottest-summer-ever-d92d500a8f01
气候变化|数据可视化
我们在 Python 和 Plotly 中开发可视化,以分析 2023 年 6 月至 8 月期间记录的最高气温
·发布于面向数据科学 ·阅读时间 11 分钟·2023 年 9 月 28 日
–
今年夏天比 1880 年以来的任何时候都要热!
数据科学家如何帮助展示我们的气候正在迅速变化,并帮助传达情况的严重性?我们将探讨如何通过分析和可视化有效地呈现数据,并对数据的表示进行优化。
但首先,让我们简要探讨一下全球变暖的一些后果。然后,我们将考虑如何使用 Plotly 和 Python 有效地可视化数据,并展示这些数据如何与 CO₂排放相关。
这张地图展示了 2023 年气象夏季(6 月、7 月和 8 月)的全球温度异常情况。它显示了地球不同区域相对于 1951 年至 1980 年基准平均值的温暖或凉爽程度。来源:NASA 地球观测台/劳伦·多芬,经授权使用
这张来自 NASA 的地图显示了与 1951 年至 1980 年平均值相比,今年夏天的全球温度异常情况,我并不感到惊讶,因为我所在的欧洲地区是全球温度变化最高的区域之一——西班牙的气温高达 40 摄氏度(约 104 华氏度)并不罕见。
影响
高温加剧了加拿大、夏威夷、欧洲部分地区以及其他地方的野火,并可能促成了全球范围内的强降雨事件。
“Drevenochoria” 火灾从阿提卡的伊利翁看到的图像,拍摄时间大约为 7 月 18 日凌晨 2 点。图片由Sthivaios提供 CC BY-SA 4.0
欧洲的野火在夏季并不罕见,但今年特别猛烈,尤其是在希腊,岛屿如罗德岛和科孚岛等地区进行了疏散,许多人死伤。而且,在夏威夷,一个小镇被完全摧毁了。
除了野火,欧洲还遭遇了大规模的暴雨(特别是来自丹尼尔风暴的暴雨),希腊再次受到严重影响——洪水造成了数百万欧元的损失。
2023 年 9 月 9 日,丹尼尔气旋(也称为丹尼尔风暴)在利比亚北部。来自 NOAA-20 卫星的 VIIRS 影像——worldview.earthdata.nasa.gov/
,公共领域
由于丹尼尔风暴,利比亚的洪水造成了巨大的破坏,并造成了数千人遇难,当时两个大坝崩溃,摧毁了地中海沿岸的德尔纳大部分地区。电视新闻播报了幸存者从曾经的家园废墟中被救出的悲惨场景。
虽然不可能明确证明这些灾难与气候变化之间的联系,一项由 NASA 主导的研究 确认了随着气温升高,严重干旱和过度降水的发生频率增加。
而Carbon Brief(一个总部位于英国的网站,涵盖气候科学、气候政策和能源政策的最新发展)建议,93% 的极端高温事件经科学家评估都因为气候变化而变得更可能发生或更严重。
这并不令人惊讶:较高的温度意味着森林变得更干燥,更易燃,而当下雨时,这些较高的温度意味着大气能容纳更多的水蒸气,从而有更多的 H₂O 可降雨。
统计数据展示
根据他们的新闻稿,今年 6 月、7 月和 8 月的综合温度比任何其他记录中的夏季高出 0.23 摄氏度,比 1951 年至 1980 年的平均夏季温度高出 1.2 摄氏度。
数据由 NASA 戈达德太空研究所(GISS)[1]的科学家在纽约记录,记录了 6 月、7 月和 8 月的温度异常——这些月份被认为是北半球的气象夏季。
数据覆盖了从 1880 年至当前年份,并记录了与 1951 年至 1980 年计算的平均值相比的夏季温度变化。
我们可以绘制一个简单的折线图,从中可以清晰地看到温度的逐渐上升。1980 年后的温度逐渐上升,而 1951 年之前的温度大多低于平均水平,而 1951 年至 1980 年之间的温度则趋于接近平均温度。
2023 年 6 月、7 月和 8 月的全球温度。数据来源于 NASA 的 GISS [1] — 作者图片
然而,由于数据取自每年三个月的时间段,其连续性不如代表一系列相邻月份的数据。因此,也许对于这种数据,柱状图可能是更好的选择。
下图更清楚地展示了自 1880 年以来全球温度在 6 月、7 月和 8 月的变化情况,柱状图更好地表现了数据。你可以更容易地看到,今年的温度明显高于任何近期年份,并且比 1951 年至 1980 年的平均夏季温度高出很多。
2023 年 6 月、7 月和 8 月的全球温度。数据来源于 NASA 的 GISS [1] — 作者图片
因此,柱状图可能比折线图更好地传达数据。但我们仍然可以通过使用颜色来使其更加清晰。
2023 年 6 月、7 月和 8 月的全球温度。数据来源于 NASA 的 GISS [1] — 作者图片
上面的图表比之前的图表更戏剧性地展示了温度变化。
图表中的颜色编码(尽管严格来说是多余的)强调了上升趋势,较高的温度由更暖的颜色表示,深色主题与浅色形成鲜明对比,并突出了近期更热的年份。
我使用 Plotly 创建了柱状图,采用了深色方案和 Plotly 的‘inferno’颜色范围——我们将在下面看到 Python 代码。
数据
数据[1]覆盖了 1880 年至 2023 年,包括每年每个月的温度异常以及例如我们上面看到的 6 月至 8 月的月均温度。有一个全球数据文件,还有两个分别针对北半球和南半球的文件。
我已将数据复制到一个 GitHub 仓库,并编写了一个 Jupyter Notebook,该 Notebook 读取这些数据并生成我在这里使用的所有图表(文章顶部的地图除外——那张图片是由 NASA/GISS 的好人们制作的)。你可以在那里下载任何或所有的材料,我将在本文末尾提供一个链接。
这是一个将全球数据读取到 Pandas 数据框中的代码副本。
title, df = readdata.read_GLB()
我创建了一些辅助函数来读取文件,这些函数将文件的标题和数据分开,你将在仓库中找到它们。
这是数据的截图:
2023 年 6 月、7 月和 8 月的全球气温作为数据框。数据来自 NASA 的 GISS [1] — 图片由作者提供
一些数据缺失——主要是尚未发生的月份——但我们不必担心这些,因为我们不会使用任何包含缺失数据的列。
创建条形图只需三行代码。
period = 'JJA'
scale = 'inferno'
px.bar(df, x='Year', y = period, color="JJA", title = f"{title} - {period}",
color_continuous_scale=scale, template='plotly_dark')
图表绘制了北半球夏季的数据,这些数据在 JJA
列中找到。由于我在尝试不同的颜色比例,所以我还使用了一个变量 scale
来定义这个比例。该时期和标题(在之前的代码中设置)用于为图表创建标题,除此之外,它只是一个简单的 Plotly Express 条形图。
线图的代码类似于:
px.line(df, x='Year', y = period, title = f"{title} - {period}",
template='plotly_white')
我对默认的 Plotly 颜色方案不太感冒,所以在这里我使用了我更喜欢的 plotly_white
。
CO₂ 排放
我们能展示全球气温持续上升的原因吗?恐怕不能简单地做到这一点。但我们可以展示 CO₂ 排放量的上升与气温上升之间的相关性,并指出我们知道 CO₂ 排放量由于人为活动而增加,并且大气中的 CO₂ 增加会导致变暖。
我使用了一个 Our World in Data 的 GitHub 仓库 [3] 来创建我们接下来将使用的数据。同样,它被复制到我自己的一个仓库中,我将其处理成较小的文件。链接将在文章末尾出现。
我们将创建一个这样的全球 CO₂ 排放线图:
全球 CO2 排放。数据来自 OWID [3],图片由作者提供
首先,让我们获取数据:
f = "https://raw.githubusercontent.com/alanjones2/CO2/master/data/world_df.csv"
co2 = pd.read_csv(f)
它看起来像这样:
我们只对“年份”和“年度 CO₂ 排放量”列感兴趣,从中我们可以绘制图表。
该图表类似于温度变化图。它有相同的平坦开始,并且在图表的后半部分上升更陡。它们在这里一起展示。
全球 CO2 排放。数据来自 OWID [3],图片由作者提供
2023 年 6 月、7 月和 8 月的全球气温。数据来自 NASA 的 GISS [1] — 图片由作者提供
我们可以接受科学共识,即 CO₂排放增加了全球变暖,但同时我们也需要接受其他因素的影响。
温度变化不仅仅是由于人类将温室气体排放到大气中。正如美国环境保护署(EPA)所明确指出的,还有其他因素,例如太阳活动和由于例如森林砍伐造成的地球反射率变化。还有除了二氧化碳之外的其他温室气体,如甲烷和一氧化二氮。
美国环境保护署(EPA)还明确表示,除了人为排放的温室气体之外,没有其他原因能够解释当前气候变化的水平。
然而,这些其他因素使图表不完全相同。温度线的波动上下,这些不太剧烈的影响可能是原因。
数据相关性
数据科学家或统计学家确定相关性的方法可能是绘制 CO₂排放与温度变化的散点图,并在点之间绘制趋势线。
我们可以稍后再查看,但我不确定散点图是否为一般读者所理解,仅仅将两个图放在一起可能对非专业读者来说是更好的方法。
温度和排放图表的相似性很容易看出,但如果我们能在同一图表上绘制这些数据会更有用。这并不是完全简单,因为虽然两个图的 x 轴都是类似的年份区间,但 y 轴差异很大。温度异常覆盖了几度摄氏度,而 CO₂排放量在大约 20 到 40 亿吨之间。
双轴图
解决方案是绘制一个双轴图,其中有两个 y 轴和一个共同的 x 轴。不幸的是,Plotly Express 不支持双轴图,因此我们利用了 Plotly Express 构建的 Graph Objects 包。
如下代码所示,我们首先创建一个包含次级 y 轴的空图,然后将两个数据轨迹添加到该图中,第一个轨迹是温度数据,第二个轨迹是 CO₂数据。
你可能会发现 emissions 轨迹是一个散点图,但 Graph Objects 中的Scatter
轨迹默认用线连接这些点,所以它实际上是一个线图。
代码的其余部分仅仅设置了标题和标签。
import plotly.graph_objects as go
from plotly.subplots import make_subplots
# Create figure with secondary y-axis
fig = make_subplots(specs=[[{"secondary_y": True}]])
# Add traces
fig.add_trace(
go.Bar(x=df['Year'], y=df['JJA'], name="Temp anomaly"),
secondary_y=False,
)
fig.add_trace(
go.Scatter(x=co2['Year'], y=co2['Annual CO₂ emissions'], name="CO2 Emissions"),
secondary_y=True,
)
# Add figure title
fig.update_layout(
title_text="Temperature / CO2 Emissions"
)
# Set x-axis title
fig.update_xaxes(title_text="Year")
# Set y-axes titles
fig.update_yaxes(title_text="Temperature ºC", secondary_y=False)
fig.update_yaxes(title_text="CO2 Emissions tonnes", secondary_y=True)
fig.update_layout(template='plotly_dark')
fig.show()
结果如下:
2023 年 6 月、7 月和 8 月的全球温度。数据来自 NASA 的 GISS [1],全球 CO2 排放数据来自 OWID [3] —— 作者提供的图像
两组数据之间的关系相当明确。
数据科学家的散点图相关性
为了完整性,我们还需要一个散点图,不是吗?这就是它:
2023 年 6 月、7 月和 8 月的全球温度。来自 NASA 的 GISS[1]与全球 CO2 排放数据,数据来自 OWID[3] — 图像由作者提供
要绘制此图,我们需要匹配数据集的长度,因此我们需要将两个数据集都截断,从 1880 年(温度数据的第一年)到 2021 年(CO₂数据的最后一年)
# To draw a scatter plot we need to make the data the same length
# So we need to truncate both from 1880 (the first temp yr) to 2021 (the last co2 yr)
# Check that years are correct
co2yrs=list(co2['Year'][30:])
tyrs = list(df['Year'][:-2])
print(f"CO2 Years {min(co2yrs)} to {max(co2yrs)}")
print(f"Temperature Years {min(tyrs)} to {max(tyrs)}")
CO2 Years 1880 to 2021
Temperature Years 1880 to 2021
这完成了任务并检查了范围是否相同。
散点图还显示了两个数据集之间的相关性,但至少对于一般观众,我认为双线和条形图更具说服力。
结论
关于气候变化的争论常常充满情感、政治和经济动机以及个人偏见。因此,我们有责任以尽可能清晰的方式呈现事实,以便科学论点能够占上风。
我将最后的话留给气候科学家和 GISS 主任加文·施密特。在NASA 新闻稿中,他被引用说,“不幸的是,气候变化正在发生。我们说会发生的事情正在发生,”他补充道,“如果我们继续向大气中排放二氧化碳和其他温室气体,情况会变得更糟。”
下载
感谢阅读,希望这篇文章对你有帮助,并且你会查看我 GitHub 存储库中的数据和代码。
你可以在我网站的链接中找到数据和包含本文所有代码(及更多)的 Jupyter Notebook。作为额外福利,还有一些图表的 Matplotlib 版本,包括双轴图。
编码、数据科学和数据可视化 - 文章和教程
你也可以订阅我的数据可视化、数据科学和 Python通讯以获取更多内容。
参考文献
-
GISTEMP 团队,2023: GISS 地表温度分析(GISTEMP),第 4 版。NASA 戈达德太空研究所。数据集于 2023-09–19 访问,地址为 data.giss.nasa.gov/gistemp/。请注意,NASA 的数据集没有特定的使用许可。NASA 将其免费提供用于非商业目的,但应给出归属(如上所示)。
-
Lenssen, N., G. Schmidt, J. Hansen, M. Menne, A. Persin, R. Ruedy, 和 D. Zyss, 2019: GISTEMP 不确定性模型的改进。J. Geophys. Res. Atmos., 124, 第 12 期, 6307–6326, doi:10.1029/2018JD029522.
-
全球 CO₂排放数据源自我们世界的数据(OWID)co2-data GitHub 存储库,创作共用 BY 许可
最新的 DeepMind 工作揭示了语言模型的极致提示种子
如何通过计算优化的提示使语言模型表现出色,以及这如何影响提示工程
LucianoSphere (Luciano Abriata, PhD)
·发表于 Towards Data Science ·阅读时间 11 分钟·2023 年 11 月 8 日
–
图片由 Ali Shah Lakhani 提供,来源于 Unsplash
随着我们见证人工智能(AI)的稳步进步,每个月完成越来越困难的任务,人们普遍关注未来的劳动市场。如果 AI 继续自动化许多当前由人类执行的任务,未来的职业会是什么样的呢?有一种观点认为“编程这些系统将是人类的工作多年”,或者“我们将始终需要人类来维护和重新训练 AI 模型”,或者“设计有效的提示以正确引导 AI 模型是人类的技能”。本文的重点就是后者,这促使了“提示工程”作为一种“职业”的出现。确实,编写高效提示以使 AI 模型准确地执行期望的操作,或使其“思考”得足够好以改善其答案,尤其是对问题,是一种技巧。看看这个作为一个例子:
提炼出关键点,基于超过 2 年的经验以及 AI 开发者自己的教程、实践和示例。
towardsdatascience.com
然而,这些人工干预中的任何一种很可能都不会永远保持相关性。特别是在提示工程方面,这些技能看起来很快就不会再那么重要了。继续阅读以了解原因,并在过程中了解 DeepMind 在最近的预印本中报告的非常有趣的发现,当使用大型语言模型(LLMs)时,你可以立即将这些发现应用于自身利益,我会通过 ChatGPT 免费版的实践示例来展示给你。
提示优化
在深入之前,我们需要了解“提示”的概念。提示是传递给 AI 模型的指令,用来告诉它们我们希望它们做什么。AI 模型响应用户生成的文本输入或提示,以生成其输出——文本、图形、音频等。输入提示的质量和具体性显著影响模型生成的内容和质量。此外,不同的用户可能会有不同的请求或提问方式,并不是所有的方式都能高效地产生预期的答案和正确的信息。
过去,制定有效提示的艺术并不被很好地理解。早在 2019 年,OpenAI 就揭示了在文本输入末尾添加“tl;dr”(通常用于请求总结)可以使模型总结前面的文本。尽管当时这只是一个学术好奇,但随着时间的推移,研究人员和爱好者开始发现特定的措辞可以解锁这些 AI 模型的增强潜力。正是这一点催生了“提示工程(学)”这一“领域”和“职业”。
提示工程师成为了制定能够引发 AI 模型期望回应的提示的专家。他们在互联网上分享了他们的“魔法短语”和技巧,从而有效地创造了一个新的专业领域。这些专业人员的需求迅速增长,提示工程师的职位开始出现,突显了这些人在 AI 领域的重要影响。
提示工程师基本上是为了最佳输出而优化提示。但是,优化是计算机做得非常好的任务;因此,它们很可能会取代人工提示工程师。这可能即将发生:在他们最近的预印本中,DeepMind 展示了 AI 模型可以用来优化自身的输入,从而使得“提示工程”在传统意义上变得有些过时和低效。
大型语言模型作为自身提示的优化器
最近 DeepMind 的论文《大型语言模型作为优化器》探讨了如何有效地优化……它们自己的输入提示!
让我们看看这是什么,它如何与训练和优化相关,以及 DeepMind 的工作究竟发现了什么。哦,我们还可以尝试一些计算机优化的提示。你会感到惊讶。
传统上,在机器学习中,优化过程涉及调整模型的内部参数(“权重”,即描述不同人工神经元如何连接的大量数字)以最小化误差。通常,在训练阶段,模型会接收到大量已知的输入-输出对,并且所有权重都按照传统的数学方式进行优化,以使模型“学习”。然而,一旦模型训练完成,我们仍然可以通过以不同方式提供输入来使用它,这些输入将根据训练过程中计算的权重集以不同的方式传播,因此它们的输出也会不同……有些更好,有些更差。然后,可以优化如何提供输入,在语言模型的特殊情况下,就是优化提示。像任何优化协议一样,我们人类可以通过试错来完成这一过程,但显然计算机比我们更擅长——这只是一个让他们去做的事情!
DeepMind 工作的核心思想是使用各种 AI 模型生成特定任务的提示,然后测试不同提示在实现期望结果方面的有效性。例如,如果任务是解决一个数学问题,用户可以输入一个非常简单的提示,直接提出问题,或者他/她可能会在问题前面加上一句种子句,如“逐步解决这个数学问题。”,或者“我们一步步解决这个问题。”等。在某些情况下,种子句可能有助于改善 AI 的输出,即使问题本身以完全相同的方式提出。
基于这一思想,DeepMind 的工程师尝试了不同的种子提示,然后用完全相同的问题进行测试,并评估答案的质量(正确性)。他们重复了这个过程,将一系列提示应用于多个不同的问题,并最终统计了每个种子句产生的正确答案数量,以比较对每个问题的影响。
从经验来看,我们知道这种策略应该有效于找到如何优化提示。然而,对于我们人类来说,大规模地进行样本测试是不可能的。通过使用计算机架构,DeepMind 能够大规模地运行这个过程,并且在不同类型的问题上进行测试。
结果令人惊讶:如果没有使用或使用了不好的种子句,问题只能在约 50%的情况下正确解决,而当使用了好的种子句时,正确率可以达到 80%。
DeepMind 的方法包括使用这些测试中的指标来指导更好提示的创建。他们得出的结论是,通过持续迭代提示并考虑模型的反馈,AI 系统可以改进其提示生成过程,对输出产生非常重要的影响。
请注意,这种方法并不涉及更新模型的内部参数,如同训练模型时所做的那样,而是专注于优化输入本身。
一些有趣的动手示例
这是一些我自己测试中灵感来自 DeepMind 工作的有趣示例,你也可以立即使用 ChatGPT 免费版尝试(该版本由 DeepMind 预印本中显示为 GPT-3.5-Turbo 的模型提供支持)。
示例 1,复制 DeepMind 预印本中展示的一个问题的思路
DeepMind 论文中引起我注意的第一个例子是要求进行线性回归,因为既然这些模型原则上无法进行数学运算,那么我期望它们永远无法工作,无论提供什么提示。
如果你要求 ChatGPT 对非常简单的数据进行线性回归,比如 x = [1, 2, 3, 4] 和 y = [2, 4, 6, 8],你会发现它会立即正确地解决。但如果我们用更难的线性回归来挑战它呢?让我们看看。
在这里,我生成了不同 x 值的合成数据,这些值在 0 和 12 之间随机分布,并使用方程 y = 3.5 x — 11.5 在常规电子表格程序中计算了没有噪声的 y 值。然后,我要求程序“找到描述这些数据的线性方程:”并跟上 x 和 y 对。就像这样:
从与 ChatGPT 免费版交互时的截图,进行测试时使用。
这是我得到的答案:
从与 ChatGPT 免费版交互时的截图,进行测试时使用。
你清楚地看到答案是错误的,并且它经过了一些我没有要求的编码。更令人困惑的是,这段代码看起来本身是正确的,并且可能会导致正确的解决方案,但文本生成所呈现的“结果”是错误的。
我尝试重新生成问题,得到了另一个不正确的答案,这次没有调用任何代码生成,而是尝试直接处理:
从与 ChatGPT 免费版交互时的截图,进行测试时使用。
现在来看看“魔法”。
如果我们用 DeepMind 报告能显著改善 GPT-3.5-Turbo 答案的以下句子来引导提示会怎么样?(摘自预印本的表 1):
“一点算术和逻辑方法将帮助我们快速得到这个问题的解决方案。”?
我们来试试:
从与 ChatGPT 免费版交互时的截图,进行测试时使用。
严格来说,这不算回归,但逻辑是完美的,结果是正确的!
示例 2,使用 DeepMind 发现的种子句子在 GSM8K 上表现最佳的化学问题
GSM8K 是一个由人工问题编写者创建的高质量且语言多样的小学数学单词问题的大型数据集。DeepMind 使用这个数据集来评估几个 LLM 的能力,发现对于 GPT-3.5-Turbo,最佳的提示开始方式是:
分析给定的信息,将问题分解为可管理的步骤,应用合适的数学运算,并提供一个清晰、准确、简洁的解决方案,如有必要,确保精确四舍五入。考虑所有变量,仔细考虑问题的背景,以便有效地解决问题。
所以,我选择了一个关于化学计量学的问题(在化学中,就是根据给定的试剂量或反之,计算得到的产物量),并要求 ChatGPT 解决,首先是不带种子句子,然后是带上种子句子。
这是没有任何种子提示的情况:
这是与 ChatGPT 免费版本互动的截图,在进行测试时(在这种情况下,侧边并排显示)。
答案完全错误,因为在(a)中我们要求的是分子数量,而不是摩尔数,而在(b)中我们要求的是质量,但数字是错的。
正确答案是 5.337E22 分子和 10.41 克 Zn(CN)2。
现在让我们看看当我在问题前加上种子句子时发生了什么:
这是与 ChatGPT 免费版本互动的截图,在进行测试时(在这种情况下,侧边并排显示)。
我。感到。惊讶。
两个答案都是完全正确的!(并且计算过程非常详细。)
我想我得重新修订并重写我很久以前写的这篇博客,当时 GPT-3 刚刚发布:
学生们能从 OpenAI 最新的语言模型中学习,并将其用作全天候顾问吗?学生们能用它来……
[towardsdatascience.com
手动提示工程的终结?
这些以及其他近期的发展表明,我们可能正在见证手动提示工程时代的结束,以及一个新的时代的开始,在这个时代中,你无需掌握每个 AI 的提示语言就能获得所需的结果。如果你尝试了 DALL-E 3,你一定会发现它在理解你的意图方面做得更好,即使使用的是你一直使用的相同提示。用户可以越来越自然地指示模型,甚至让 AI 系统自动生成能够产生所需结果的提示。
AI 优化 AI:似乎非常强大。那么呢?工作呢?
从历史上看,以前的工业和技术革命产生了新的工作形式,通常伴随着重大的经济和社会变革,带来了不可预测的后果。在那些时代,随着工作的自动化,新的角色应运而生,以满足每个时代的需求——因此工作虽然改变了,但始终存在。
但随着人工智能革命的到来,这可能会有所不同。我们可能首次面临一种能够适应新挑战并学习未来将出现的任务和工作的自动化力量,包括如何控制自身。
我们在短期内会得到什么?
仅仅触及 AI 模型的一种可改进的特征,DeepMind 的预印本显示,单凭优化提示而不进行实际的再训练或微调,我们就可以让 LLMs 表现得更好,其他生成性 AI 模型也必定是如此。
这种提示的优化很难被人类匹配,特别是当优化的提示有意义但其确切形式并不明显时——见我上面提供的第二个示例,其中最优的种子提示相当长且非常具体。
这些发现和进展表明,我们正朝着一个未来发展,在这个未来中,人工智能系统将能够自行生成非常有效的提示,从而减少对手动提示工程的需求。因此,提示工程师的角色可能会发生变化,那些无法适应新自动化程序并学习在哪里仍然有人工干预空间的提示工程师将被淘汰。
区分可能具有不同未来的两种角色至关重要。一方面,如果你的工作是手动创建提示,随着人工智能模型在生成自身提示方面变得更加熟练,你的角色可能变得不那么相关且需求减少。另一方面,如果你的角色涉及在更广泛的系统中优化 AI 模型的使用,其中你将其与非 AI 软件或代码的输入和输出相结合,并可能在其他场景中使用,那么你在劳动力市场上的价值可能对自动化更具韧性,并在长期内仍然有用。
总结我对这一问题的关键观点,就像一些人曾经快速地跟上了提示工程的潮流一样,他们现在必须保持警觉,因为人工智能及其与人类的互动正迅速发展。因此,跟踪如 DeepMind 的预印本或这篇博客文章等文献,对于了解如何在人工智能工具和其他人工智能工具发展时最佳地调整自己的掌握是至关重要的。
参考文献
DeepMind 在 arXiv 上的预印本:
优化无处不在。虽然基于导数的算法已成为解决各种问题的强大工具,…
另一位作家在 Medium 上的一篇有趣的相关博客文章:
为什么这篇论文很重要
medium.com](https://medium.com/@minh.hoque/large-language-models-as-optimizers-explained-a20dc5e5c5af?source=post_page-----e95fb7f4903c--------------------------------)
www.lucianoabriata.com 我写的内容涵盖了我广泛兴趣领域的一切:自然、科学、技术、编程等。 订阅以获取我的新故事 通过电子邮件*。要* 咨询小项目 请查看我的 服务页面。你可以 在这里联系我。 你也可以 在这里给我小费*。
音频机器学习的新领域
·
关注 发表在 Towards Data Science · 发送 新闻简报 · 阅读时间 3 分钟 · Apr 20, 2023
–
不久以前,任何涉及处理音频文件的工作流程——甚至是像转录播客剧集这样相对简单的任务——都伴随着一系列艰难的选择。你可以选择手动操作(在过程中浪费数小时甚至数天的时间),依赖于几款笨拙且最终令人失望的应用程序,或者拼凑出类似于弗兰肯斯坦怪物的工具和代码组合。
那些日子已经过去。强大的模型和易于访问的 AI 界面的兴起使得处理音频和音乐变得更加高效,新的视野每天都在不断开启。为了帮助您跟上音频聚焦机器学习的最新进展,我们从过去几周收集了一些突出的文章,涵盖了各种方法和用例。过滤掉噪音,深入了解吧!
-
揭示音乐标签 AI 的黑箱。随着每天在 Spotify 和 Apple Music 等平台上添加数千首歌曲,您是否曾想过这些服务如何知道为每首歌分配哪种音乐流派?Max Hilsdorf的项目利用 Shapley 值确定特定乐器的存在如何影响 AI 系统标记新曲目的方式。
-
探索基于深度学习的鸟鸣识别方法。Leonie Monigatti最近的贡献涵盖了去年的 BirdCLEF2022 Kaggle 竞赛,参赛者需创建鸟鸣录音的分类器。Leonie 向我们展示了一种巧妙的方法,将音频波形转换为梅尔频谱图,使深度学习模型可以像处理图像一样处理它们。
图片由Oskars Sylwan提供,来源于Unsplash
-
从长 YouTube 视频中自动生成摘要。如果您是一个完美主义者,您会欣赏Bildea Ana使用 OpenAI 的 Whisper 模型和 Hugging Face 进行音频转录的简化流程,然后使用开源的 BART 编码器进行总结。您可以将此方法应用于自己的录音和语音备忘录,或者任何其他音频文件(前提是其所有者允许,当然——始终仔细检查您希望使用的数据的版权和许可状态)。
-
将转录提升到一个新水平。Luís Roque的最新项目与 Ana 的项目有相似之处,但有所不同。它也依赖于 Whisper 来转录音频文件,但随后通过部署 PyAnnotate 进行说话者分离,“即识别和区分不同说话者的语音的过程”。
你说“请不要停止音乐”?我们很乐意满足——这里是我们最近一些最喜欢的关于非音频相关主题的文章。请享用!
-
“学习神经网络不应该是解读误导性图示的练习,” Aaron Master和 Doron Bergman 表示,他们提出了一种建设性的新方法来创建更好、更准确的神经网络。
-
从推广设计到库存分析,Idil Ismiguzel 展示了关联规则挖掘的力量:一种赋能数据专业人士发现数据集中的频繁模式的技术。
-
对于无监督学习和 K-means 聚类的动手实践,不要错过Nabanita Roy的最新教程,该教程专注于按颜色分组图像像素的使用案例。
-
如果你发现人工智能、政府监管和加拿大官僚主义的交集很吸引人(谁会不感兴趣呢?),Mathieu Lemay的深度剖析是你本周绝对不容错过的一篇文章。
-
随着合成数据在多个领域的作用不断发展(和增长),Miriam Santos的实用 CTGAN 生成合成数据指南依然时效性和实用性十足。
-
我们绝对不能在一整周内没有一个以 GPT 为主题的推荐;如果你还没读过,我们强烈推荐Henry Lai对这些备受欢迎的模型背后的数据驱动 AI 概念的概述。
感谢您本周收听《Variable》!如果您喜欢在 TDS 上阅读的文章,请考虑成为 Medium 会员——如果您是符合条件国家的学生,不要错过享受会员大幅折扣的机会。
下期《Variable》见,
TDS 编辑团队
新版 Scikit-Learn 更适合数据分析
原文:
towardsdatascience.com/new-scikit-learn-is-more-suitable-for-data-analysis-8ca418e7bf1c
Scikit-Learn 版本 ≥1.2.0 的 Pandas 兼容性及更多
·发表于 Towards Data Science ·阅读时间 5 分钟·2023 年 3 月 8 日
–
新版 Sklearn 的一些非常酷的更新!(来源:作者笔记本)
大约在去年 12 月,Scikit-Learn 发布了一个重要的 稳定更新 (v. 1.2.0–1),我终于可以尝试一些突出的新功能。现在它与 Pandas 兼容性更好,还有一些新功能将帮助我们进行回归和分类任务。接下来,我将介绍一些新更新及其使用示例。我们开始吧!
与 Pandas 的兼容性:
在使用数据进行训练 ML 模型(如回归或神经网络)之前应用数据标准化是一种常见技术,以确保具有不同范围的特征在预测中获得相等的重要性(如果或当需要时)。Scikit-Learn 提供了各种预处理 API,如 StandardScaler
、MaxAbsScaler
等。随着新版本的发布,可以在预处理后保持 Dataframe 格式,让我们看看下面:
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split
########################
X, y = load_wine(as_frame=True, return_X_y=True)
# available from version >=0.23; as_frame
########################
X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y,
random_state=0)
X_train.head(3)
数据集 Wine 的 Dataframe 格式
新版本包括一个选项,即使在标准化之后也能保持 Dataframe 格式:
############
# v1.2.0
############
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler().set_output(transform="pandas")
## change here
scaler.fit(X_train)
X_test_scaled = scaler.transform(X_test)
X_test_scaled.head(3)
即使在标准化之后,Dataframe 格式仍保持不变。
之前,它会将格式转换为 Numpy 数组:
###########
# v 0.24
###########
scaler.fit(X_train)
X_test_scaled = scaler.transform(X_test)
print (type(X_test_scaled))
>>> <class 'numpy.ndarray'>
由于 Dataframe 格式保持不变,我们不需要像处理 Numpy 数组格式时那样关注列。分析和绘图变得更容易:
fig = plt.figure(figsize=(8, 5))
fig.add_subplot(121)
plt.scatter(X_test['proline'], X_test['hue'],
c=X_test['alcohol'], alpha=0.8, cmap='bwr')
clb = plt.colorbar()
plt.xlabel('Proline', fontsize=11)
plt.ylabel('Hue', fontsize=11)
fig.add_subplot(122)
plt.scatter(X_test_scaled['proline'], X_test_scaled['hue'],
c=X_test_scaled['alcohol'], alpha=0.8, cmap='bwr')
# pretty easy now in the newer version to see the effect
plt.xlabel('Proline (Standardized)', fontsize=11)
plt.ylabel('Hue (Standardized)', fontsize=11)
clb = plt.colorbar()
clb.ax.set_title('Alcohol', fontsize=8)
plt.tight_layout()
plt.show()
图 1:标准化前后特征的依赖关系!(来源:作者笔记本)
即使我们建立了一个管道,管道中的每个转换器也可以配置为返回数据框,如下所示:
from sklearn.pipeline import make_pipeline
from sklearn.svm import SVC
clf = make_pipeline(StandardScaler(), SVC())
clf.set_output(transform="pandas") # change here
svm_fit = clf.fit(X_train, y_train)
print (clf[:-1]) # StandardScaler
print ('check that set_output format indeed remains even after we build a pipleline: ', '\n')
X_test_transformed = clf[:-1].transform(X_test)
X_test_transformed.head(3)
数据框格式即使在管道中也可以保持不变!
数据集获取更快、更高效:
OpenML是一个开放的数据集分享平台,而 Sklearn 中的数据集 API 提供了fetch_openml
函数来获取数据;随着 Sklearn 的更新,这一步在内存和时间上更高效。
from sklearn.datasets import fetch_openml
start_t = time.time()
X, y = fetch_openml("titanic", version=1, as_frame=True,
return_X_y=True, parser="pandas")
# # parser pandas is the addition in the version 1.2.0
X = X.select_dtypes(["number", "category"]).drop(columns=["body"])
print ('check types: ', type(X), '\n', X.head(3))
print ('check shapes: ', X.shape)
end_t = time.time()
print ('time taken: ', end_t-start_t)
使用parser='pandas'
可以显著提高运行时间和内存消耗的效率。可以通过psutil
库轻松检查内存消耗,如下所示:
print(psutil.cpu_percent())
部分依赖图:分类特征
部分依赖图之前也存在,但仅限于数值特征,现在已经扩展到分类特征。
如 Sklearn 文档中所述:
部分依赖图显示目标与感兴趣的一组输入特征之间的依赖关系,边际化所有其他输入特征(‘补充’特征)的值。直观上,我们可以将部分依赖性解释为目标响应与感兴趣的输入特征的函数。
使用上述的‘titanic’数据集,我们可以轻松绘制分类特征的部分依赖性:
使用上述代码块,我们可以得到如下的部分依赖图:
图 2:分类变量的部分依赖图。(来源:作者笔记)
在 0.24 版本中,我们会遇到分类变量的值错误:
>>> ValueError: could not convert string to float: ‘female’
直接绘制残差(回归模型):
在分析分类模型的性能时,Sklearn 的度量 API 中,像PrecisionRecallDisplay
、RocCurveDisplay
这样的绘图例程在旧版本(0.24)中存在;在新版中,回归模型也可以进行类似的操作。下面是一个示例:
可以直接使用 Sklearn 绘制线性模型拟合及其对应的残差。(来源:作者笔记)
尽管总是可以使用 matplotlib 或 seaborn 绘制拟合线和残差,但在我们确定了最佳模型后,能够快速直接在 Sklearn 环境中检查结果是很棒的。
新版 Sklearn 中还有一些其他的改进/新增功能,但我发现这 4 个主要改进对标准数据分析特别有用。
参考文献:
[2] Sklearn 版本亮点: 视频
[3]所有图表和代码: 我的 GitHub
如果你对进一步的基础机器学习概念及更多内容感兴趣,可以考虑使用 我的链接加入 Medium。你不会支付额外费用,但我将获得一小笔佣金。感谢大家!!
[## 使用我的推荐链接加入 Medium - Saptashwa Bhattacharyya
更多来自 Saptashwa(以及 Medium 上所有其他作者)的内容。你的会员费用直接支持 Saptashwa 和其他作者…
medium.com](https://medium.com/@saptashwa/membership?source=post_page-----8ca418e7bf1c--------------------------------)
新的 SHAP 图:小提琴图和热图
原文:
towardsdatascience.com/new-shap-plots-violin-and-heatmap-20f647313b64
SHAP 版本 0.42.1 中的图表可以告诉你关于模型的哪些信息
·发表于 Towards Data Science ·6 分钟阅读·2023 年 8 月 14 日
–
(来源:作者)
对于 SHAP 最大的担忧之一与软件包本身有关。它已经有一段时间没有更新了,GitHub 上的问题也不断增加。让许多用户感到欣慰的是,贡献者们变得更加活跃。事实上,他们给我们带来了新的图表——小提琴图和热图。我们将:
-
提供这些图的代码
-
讨论我们可以从中获得哪些新见解
你还可以观看关于这个主题的简介:
现有的 SHAP 图
我们从之前的 SHAP 教程继续。你可以在下面的文章中找到这篇教程。你还可以在 GitHub 上找到完整的项目。要使用新的图表,你需要更新 SHAP 软件包。我使用的是版本 0.42.1。
如何创建和解释 SHAP 图:瀑布图、力图、平均 SHAP 图、蜜蜂散点图和依赖图
towardsdatascience.com
总结来说,我们使用 SHAP 来解释一个基于 abalone 数据集 构建的模型。该数据集包含 4,177 个实例,你可以在下方看到特征的示例。我们使用这 8 个特征来预测 y——螺旋纹数。
X 特征矩阵(来源:UCI 机器学习库)(许可证:CC0:公共领域)
本教程继续计算 SHAP 值并显示各种 SHAP 图。理解其中的一些图对于理解新的 SHAP 图是有帮助的。我们将看到它们提供了类似的信息。
第一个是均值 SHAP图,见图 1。对于每个特征,这给出了所有实例的绝对均值 SHAP 值。对预测贡献显著的特征,其均值 SHAP 值会很高。换句话说,这张图告诉我们哪些特征在一般情况下最为重要。
图 1:绝对均值图(来源:作者)
另一种图是蜜蜂散点图,见图 2。这是所有 SHAP 值的可视化。在 y 轴上,值按特征分组。对于每个组,点的颜色由特征值决定(即特征值较高的点颜色较红)。现在,让我们看看新的 SHAP 图与这些图的比较情况。
图 2:蜜蜂散点图(来源:作者)
SHAP 小提琴图
小提琴图的代码类似于我们在其他 SHAP 图中看到的内容。我们只需输入我们的shap_values对象(第 2 行)。为了明确,这些值是我们在之前的教程中计算的。你可以在图 3中查看输出。与图 2相比,我们可以看到小提琴图是蜜蜂散点图的一种不同风格。
# violin plot
shap.plots.violin(shap_values)
图 3:小提琴图(来源:作者)
另一种风格是分层小提琴图,见图 4。在这种图中,每个 SHAP 值下的特征值变化更为清晰。也就是说,与原始的小提琴图和蜜蜂散点图相比。
# layered violin plot
shap.plots.violin(shap_values, plot_type="layered_violin")
图 4:分层小提琴图(来源:作者)
由于相似性,我们从这些图中获得的见解类似于蜜蜂散点图。这些图可以突出显示重要的关系,因为我们可以看到哪些特征往往具有较大的 SHAP 值。通过按特征值着色,我们还可以开始理解特征与模型预测之间的关系。现在,让我们看看热图是否能提供更多见解。
SHAP 热图
你可以在图 5中看到热图函数的输出。这里有很多内容:
-
在 x 轴上,我们对所有 4,177 个实例进行了标记
-
y 轴表示特征
-
每个实例上方的线条按该特征的SHAP 值进行着色
-
f(x) 线表示该实例的预测环数
-
右侧的条形图显示了我们在图 1中看到的平均 SHAP 值
与蜜蜂散点图类似,这是一种每个 shap 值的图。但现在我们关注的是 SHAP 值与实例组之间的模式。
# heatmap
shap.plots.heatmap(shap_values)
图 5:SHAP 热图(来源:作者)
默认情况下,实例是使用层次聚类算法进行排序的。开发者表示,“这会将因相同原因得到相同模型输出的样本分组在一起”。我发现选择自己的实例排序对于发现模式更为有用。
热图排序
为此,我们传递一个instance_order参数。这必须是与数据集长度相同的整数数组(即 4,177)。这些值给出实例的顺序。在下面的代码中,我们将实例从预测值最低到最高排序。
# order by predictions
order = np.argsort(y_pred)
shap.plots.heatmap(shap_values, instance_order=order)
在图 6的输出中,我们看到了一些模式的出现。注意去壳重量的 SHAP 值有 3 个组。存在两个正值组——一个是当壳体重量的 SHAP 值既小又大时。一个潜在的交互作用?我们可以通过 SHAP 交互值进一步探索。
图 6:按预测值排序的 SHAP 热图(来源:作者)
另一种选择是按特征的值对实例进行排序。下面,我们使用壳体重量对它们进行排序。我们可以看到,预测的环数随着该特征的增加而增加。我们还可以看到该特征的 SHAP 值也有增加的趋势。换句话说,壳体重量值越大,预测的环数越高。
# order by feature's values
order = np.argsort(data['shell weight'])
shap.plots.heatmap(shap_values, instance_order=order)
图 7:按特征值排序的 SHAP 热图(来源:作者)
我们可以以任何我们想要的方式排序热图。这种灵活性可以帮助我们以其他图表无法提供的方式理解我们的模型。就个人而言,我很兴奋看到这些发展的出现。更多的特征和可视化选项将受到包的众多用户的赞赏。你希望在未来的更新中看到什么?
如果你想了解更多关于 SHAP 的信息,请查看下面的文章:
使用 SHAP Python 包识别和可视化数据中的交互作用
[towardsdatascience.com ## 从 Shapley 到 SHAP — 理解数学
关于 SHAP 特征贡献计算的概述
[towardsdatascience.com ## SHAP 的局限性
SHAP 如何受到特征依赖、因果推断和人为偏差的影响
[towardsdatascience.com
希望你喜欢这篇文章!你可以通过成为我的推荐会员来支持我**😃**
[## 通过我的推荐链接加入 Medium — Conor O’Sullivan
作为 Medium 会员,你的部分会员费用将用于你阅读的作家,同时你可以全面访问所有故事……
conorosullyds.medium.com](https://conorosullyds.medium.com/membership?source=post_page-----20f647313b64--------------------------------)
| Twitter | YouTube | Newsletter — 免费注册获取 Python SHAP 课程
参考资料
S. Lundberg SHAP****Python 包 github.com/slundberg/shap
S. Lundberg & S. Lee, 统一解释模型预测的方法 arxiv.org/pdf/1705.07874.pdf
SHAP 热图 shap.readthedocs.io/en/latest/example_notebooks/api_examples/plots/heatmap.html
SHAP 小提琴图总结 shap.readthedocs.io/en/latest/example_notebooks/api_examples/plots/violin.html
牛顿运动定律:最初的梯度下降
探索梯度下降和牛顿运动方程的共享语言
·
关注 发表在 Towards Data Science ·7 min read·2023 年 12 月 27 日
–
照片由 Luddmyla . 提供,来自 Unsplash
我记得在工程学院读本科时,作为物理学学生,我上了第一门机器学习课程。换句话说,我是个局外人。当教授通过梯度下降解释反向传播算法时,我脑海里有个模糊的问题:“梯度下降是一个随机算法吗?”在举手向教授提问之前,陌生的环境让我犹豫了一下;我稍微缩了回来。突然,答案闪现在我脑海里。
这是我想到的。
梯度下降
要说清楚什么是梯度下降,我们首先需要定义训练神经网络的问题,我们可以通过概述机器如何学习来做到这一点。
神经网络训练概述
在所有监督学习的神经网络任务中,我们有一个预测值和真实值。预测值与真实值之间的差异越大,说明我们的神经网络在预测值方面的表现越差。因此,我们创建了一个称为损失函数的函数,通常表示为L,它量化了实际值与预测值之间的差异。训练神经网络的任务就是更新权重和偏差(简称参数)以最小化损失函数。这就是训练神经网络的大致情况,而“学习”只是更新参数以最佳适应实际数据,即最小化损失函数。
通过梯度下降优化
梯度下降是一种用于计算这些新参数的优化技术。由于我们的任务是选择参数以最小化损失函数,我们需要一个选择标准。我们试图最小化的损失函数是神经网络输出的函数,因此在数学上我们将其表示为L = L(y_nn, y)。但神经网络输出y_nn也依赖于其参数,所以y_nn = y_nn(θ),其中θ是一个包含我们神经网络所有参数的向量。换句话说,损失函数本身是神经网络参数的一个函数。
借鉴了一些向量微积分的概念,我们知道要最小化一个函数,你需要逆着它的梯度方向前进,因为梯度指向函数最快增长的方向。为了获得一些直觉,我们来看一下图 1 中 L(θ)可能的样子。
图 1:显示L(w1,w2)作为 w1 和 w2 函数的表面。图像由作者提供。
在这里,我们对训练神经网络时什么是期望的、什么是不期望的有了清晰的直觉:我们希望损失函数的值更小,所以如果我们从使损失函数落在黄色/橙色区域的参数 w1 和 w2 开始,我们希望沿着紫色区域的方向滑动下降到表面。
这种“滑下”运动是通过梯度下降方法实现的。如果我们处在表面上最亮的区域,梯度将继续指向上方,因为这是增加最快的方向。然后,沿相反方向(因此是梯度下降)会产生一个向着最大减少区域的运动。
为了看到这一点,我们可以绘制梯度下降向量,如图 2 所示。在这个图中,我们有一个等高线图,显示了与图 1 相同的区域和函数,但损失函数的值现在编码为颜色:越亮,值越大。
图 2:显示指向梯度下降方向的向量的等高线图。图片由作者提供。
我们可以看到,如果我们选择一个位于黄色/橙色区域的点,梯度下降向量会指向最快到达紫色区域的方向。
一个很好的免责声明是,通常一个神经网络可能包含任意多的参数(GPT-3 有超过 1000 亿个参数!),这意味着这些漂亮的可视化在实际应用中完全不实用,神经网络中的参数优化通常是一个非常高维的问题。
从数学上讲,梯度下降算法可以表示为
在这里,θ(n+1) 是更新后的参数(即图 1 的表面滑下来的结果);θ(n) 是我们开始时的参数;ρ 被称为学习率(梯度下降指向的方向上的步长);∇L 是在初始点 θ_(n) 计算的损失函数的梯度。使这里的名字为下降的是前面的负号。
数学在这里非常关键,因为我们会看到牛顿的运动第二定律与梯度下降方程具有相同的数学公式。
牛顿第二定律
牛顿的运动第二定律可能是经典力学中最重要的概念之一,因为它说明了力、质量和加速度是如何联系在一起的。每个人都知道牛顿第二定律的高中公式:
其中,F 是力,m 是质量,a 是加速度。然而,牛顿原始的公式是基于一个更深层的量:动量。动量是物体的质量与速度的乘积:
并且可以解释为物体的运动量。牛顿第二定律背后的思想是,要改变一个物体的动量,你需要以某种方式干扰它,这种干扰被称为力。因此,牛顿第二定律的简洁公式是
这种公式适用于你能想到的每一种力,但我们希望在讨论中有更多的结构,为了获得结构,我们需要限制我们的可能性范围。让我们讨论保守力和势能。
保守力和势能
保守力是一种不耗散能量的力。这意味着,当我们处于仅涉及保守力的系统中时,总能量是常量。这听起来很严格,但实际上,自然界中最基本的力量都是保守的,如重力和电力。
对于每个保守力,我们关联一个称为势能的量。这个势能通过公式与力相关联。
在一维中。如果我们仔细查看最后两个公式,就会得到保守场的第二运动定律:
由于导数处理起来有点复杂,并且在计算机科学中我们反正将导数近似为有限差分,因此让我们用Δ替换 d:
我们知道Δ意味着“取更新值并减去当前值”。因此,我们可以将上述公式重新写成
这已经看起来很像上面几行中的梯度下降公式。为了使其更类似,我们只需在三维中查看它,梯度自然会出现:
我们可以清楚地看到梯度下降与上述公式之间的对应关系,这完全源于牛顿物理学。一个物体的动量(如果你愿意,也可以理解为速度)总是指向势能减少最快的方向,步长由Δt 给出。
结束语和要点总结
因此,我们可以将牛顿公式中的势能与机器学习中的损失函数相关联。动量向量类似于我们试图优化的参数向量,时间步长常数即学习率,即我们朝着损失函数最小值移动的速度。因此,类似的数学公式表明这些概念是联系在一起的,并且提供了一种很好的统一视角。
如果你想知道,我一开始的问题的答案是“不是”。梯度下降算法中没有随机性,因为它复制了自然每天所做的事情:粒子的物理轨迹总是试图在周围找到最低可能的势能。如果你让一个球从某个高度掉落,它总会有相同的轨迹,没有随机性。当你看到有人在滑板上滑下陡峭的坡道时,请记住:那实际上是自然在应用梯度下降算法。
我们看待问题的方式可能会影响其解决方案。在这篇文章中,我没有展示任何关于计算机科学或物理的新内容(实际上,这里的物理知识已有约 400 年历史),但改变视角和将(表面上)不相关的概念结合在一起,可能会创造出新的联系和对某一主题的直觉。
参考文献
[1] Robert Kwiatkowski, 梯度下降算法——深度探讨,2021 年。
[2] Nivaldo A. Lemos, 《解析力学》,剑桥大学出版社,2018 年。
创建快速、安全且兼容的数据结构的九条规则(第一部分)
来自 RangeSetBlaze 的经验教训
·
关注 发表在 Towards Data Science · 13 分钟阅读 · 2023 年 4 月 5 日
–
将数字存储在树中 — 来源:Stable Diffusion
今年,我开发了一个新的 Rust crate,名为 [range-set-blaze](https://crates.io/crates/range-set-blaze)
,它实现了范围集合数据结构。范围集合是一种有用(虽然较少见)的数据结构,它将整数集合存储为已排序且不相交的范围。例如,它存储以下三个范围:
100..=2_393, 20_303..=30_239_000, 501_000_013..=501_000_016
而不是 30220996 个单独的整数。除了潜在的内存节省,range-set-blaze
还提供了高效的集合操作,如并集、交集、补集、差集和对称差集。
在创建range-set-blaze
时,我学到了九条规则,这些规则可以帮助你在 Rust 中创建数据结构。除了数据结构,这些规则中的许多还可以帮助你提高任何 Rust 代码的性能和兼容性。
规则如下:
-
抄袭你的 API、文档,甚至代码——从标准库中抄袭。
-
设计构造函数以便于使用、兼容性和速度。
-
创建比预期更多的 Rust 迭代器。
-
使用 traits 使非法值不可表示。
-
定义具有保证属性和有用方法的通用迭代器。
在 第二部分**中讨论:
6. 定义运算符和快速操作。
7. 遵循“良好 API 设计的九条规则”,特别是“编写良好的文档”。
8. 使用代表性数据、Criterion Benchmarking 和性能分析来优化性能。
9. 测试覆盖率、文档、traits、编译器错误和正确性。
在查看前五条规则之前,让我们先看看range-set-blaze
可能的使用场景,它的集合操作是如何工作的,以及它与其他范围集 crate 的比较。
有用性: 想象一下在一个不可靠的集群上运行 100 亿个统计实验。集群上的每个任务运行几个实验。每个实验产生一行带有实验编号的输出。所以,一个任务可能会把这些放入一个文件中:
你会使用什么数据结构来查找哪些实验缺失并需要重新提交?一个选项是:将输出的实验编号存储在一个[BTreeSet](https://doc.rust-lang.org/std/collections/struct.BTreeSet.html)
中,然后进行线性扫描以查找间隙。
更快且内存效率更高的选项:使用范围集。八年前,我创建了[IntRangeSet](https://fastlmm.github.io/PySnpTools/#util-intrangeset)
,一个用 Python 编写的范围集来解决这个问题。现在,我会在 Rust 中使用range-set-blaze
(示例代码)。
集合操作:这是一个简单的并集运算符(|
)示例:
use range_set_blaze::RangeSetBlaze;
// a is the set of integers from 100 to 499 (inclusive) and 501 to 1000 (inclusive)
let a = RangeSetBlaze::from_iter([100..=499, 501..=999]);
// b is the set of integers -20 and the range 400 to 599 (inclusive)
let b = RangeSetBlaze::from_iter([-20..=-20, 400..=599]);
// c is the union of a and b, namely -20 and 100 to 999 (inclusive)
let c = a | b;
assert_eq!(c, RangeSetBlaze::from_iter([-20..=-20, 100..=999]));
附注:请参阅项目的
[README.md](https://github.com/CarlKCarlK/range-set-blaze)
以获取来自生物学的另一个集合运算符示例。该示例使用RangeSetBlaze
结构体从转录区域和外显子区域中查找基因的内含子区域。
与其他范围相关的 crate 的比较
好处: 尽管 Rust 的 crates.io 已经包含了几个范围集合的 crate,我希望我的版本能提供完整的集合操作,同时保持性能。通过各种优化措施,我相信它达到了这些目标(请参见 基准报告)。例如,它可以比最流行的范围集合 crate 快 75 倍来处理单个整数(因为其他 crate 没有对单个处理做特殊优化——但它可以轻松添加这种优化)。在另一个基准测试中,range-set-blaze
——使用混合算法——在合并两个集合时比其他 crate 快 30% 到 600%。
不足: 与其他范围相关的 crate 相比,range-set-blaze
有两个重要不足。首先,它仅适用于 Rust 整数类型。大多数其他 crate 处理任何可以排序的元素(日期、浮点数、IP 地址、字符串等)。其次,它仅提供集合功能。许多其他 crate 还处理映射。随着兴趣(以及可能的帮助),这些不足可能会在未来得到解决。
创建数据结构需要做出许多决策。根据我在 range-set-blaze
上的经验,以下是我推荐的决策。为了避免优柔寡断,我将这些建议表述为规则。当然,每个数据结构都不同,因此并非每条规则都适用于每个数据结构。
本文涵盖规则 1 到 5。第二部分 涵盖规则 6 到 9。
规则 1:抄袭 API、文档甚至代码——来自标准库
查找标准库中的类似数据结构,并逐行研究其文档。我选择了 BTreeSet
作为我的模型。它可以在缓存高效的平衡树中存储整数集合。
附带说明:稍后,在基准测试(规则 8)中,我们将看到
range_set_blaze::*RangeSetBlaze*
在某些“块状”整数集合上的速度可能比*BTreeSet*
快 1000 倍。
BTreeSet
提供了 28 个方法,例如,clear
和 is_subset
。它还实现了 18 个特性,例如,FromIterator<T>
。这是 BTreeSet
的 clear
文档和 RangeSetBlaze
的 clear
文档:
你可以看到我主要是直接复制的。我将“元素”改为“整数元素”,以提醒用户 RangeSetBlaze
支持什么。我删除了 where A: Clone
,因为所有整数必然是可克隆的。注意,Rust 文档包括一个“源”链接,这使得复制变得容易。
复制提供了这些优点:
-
它告诉你需要提供哪些方法。换句话说,它为你的 API(应用程序编程接口)提供了一个起点。这节省了设计时间。此外,用户会理解并期望这些方法。你甚至可以使你的数据结构成为标准数据结构的直接替代品。
-
几乎可以免费获得文档文本和文档测试。
-
你甚至可以复制代码。例如,这里是
BTreeSet
和RangeSetBlaze
的is_superset
代码:
#[must_use]
#[stable(feature = "rust1", since = "1.0.0")]
pub fn is_superset(&self, other: &BTreeSet<T, A>) -> bool
where
T: Ord,
{
other.is_subset(self)
}
#[must_use]
pub fn is_superset(&self, other: &RangeSetBlaze<T>) -> bool {
other.is_subset(self)
}
BTreeSet
代码让我想起了超集可以通过子集来定义,以及 #[must_use]
是一个存在且在这里适用的特性。
你可能 决定 不支持标准数据结构中的所有功能。例如,我跳过了 new_in
,这是一个实验性特性。同样,标准库支持映射(不仅仅是集合)、任何可排序的元素(不仅仅是整数)和 Serde 序列化。对我而言,这些是可能的未来特性。
你也可以 决定 以不同的方式支持某些内容。例如,BTreeSet::first
返回 Option<&T>
但 RangeSetBlaze::first
返回 Option<T>
。我知道 T
是一个便于克隆的整数,所以不需要是一个引用。
顺便提一下:Rust 没有一个通用的
*Set*
特性来告诉所有集合应该实现哪些方法,甚至提供一些默认实现(例如,*is_superset*
以*is_subset*
作为基础)吗?没有,但这个问题正在被 讨论。
你也可能 决定 支持比标准数据结构更多的方法。例如,RangeSetBlaze::len
和 BTreeSet::len
一样,返回集合中的元素数量。然而,RangeSetBlaze
还提供 ranges_len
,它返回集合中排序的、不相交的范围的数量。
规则 2:设计构造函数以提高易用性、兼容性和速度
如果有一个空版本的数据结构是有意义的,你会想定义一个 new
方法和一个 [Default::default](https://doc.rust-lang.org/std/default/trait.Default.html)
方法。
类似地,如果从迭代器填充数据结构是有意义的,你会想定义 [FromIterator::from_iter](https://doc.rust-lang.org/std/iter/trait.FromIterator.html)
方法。这些方法也会自动定义 collect
方法。像 BTreeSet
一样,RangeSetBlaze
接受整数的迭代器和对整数的引用。(接受引用很重要,因为许多 Rust 迭代器提供引用。)以下是 from_iter
和 collect
使用的示例:
let a0 = RangeSetBlaze::from_iter([3, 2, 1, 100, 1]);
let a1: RangeSetBlaze<i32> = [3, 2, 1, 100, 1].into_iter().collect();
assert!(a0 == a1 && a0.to_string() == "1..=3, 100..=100");
RangeSetBlaze
也接受包含范围的迭代器以及对这些范围的引用。它对输入范围没有限制。这些范围可以是无序的、重叠的、空的或重复的。
#[allow(clippy::reversed_empty_ranges)]
let a0 = RangeSetBlaze::from_iter([1..=2, 2..=2, -10..=-5, 1..=0]);
#[allow(clippy::reversed_empty_ranges)]
let a1: RangeSetBlaze<i32> = [1..=2, 2..=2, -10..=-5, 1..=0].into_iter().collect();
assert!(a0 == a1 && a0.to_string() == "-10..=-5, 1..=2");
最后,考虑定义额外的 From::from
方法。这些方法会自动定义 into
方法。例如,为了兼容 BTreeSet
,RangeSetBlaze
在数组上定义了一个 From::from
方法。
let a0 = RangeSetBlaze::from([3, 2, 1, 100, 1]);
let a1: RangeSetBlaze<i32> = [3, 2, 1, 100, 1].into();
assert!(a0 == a1 && a0.to_string() == "1..=3, 100..=100")
RangeSetBlaze
还定义了from_sorted_disjoint/into_range_set_blaze
,用于保证已排序且不相交的区间的迭代器。(我们将在规则 5 中看到,如何通过特殊特性和 Rust 编译器来强制执行这一保证。)
let a0 = RangeSetBlaze::from_sorted_disjoint(CheckSortedDisjoint::from([-10..=-5, 1..=2]));
let a1: RangeSetBlaze<i32> = CheckSortedDisjoint::from([-10..=-5, 1..=2]).into_range_set_blaze();
assert!(a0 == a1 && a0.to_string() == "-10..=-5, 1..=2");
附言:为什么使用
*from_sorted_disjoint*
/*into_range_set_blaze*
而不是*from_iter /collect*
或*from/into*
?请参见这个讨论和这个讨论。
对于你的每一个构造函数,考虑可能的加速和优化。RangeSetBlaze
在from_iter
中实现了这种优化:
-
将相邻(可能无序)整数/区间合并成不相交的区间,O(n₁)
-
按起始位置对不相交的区间进行排序,O(n₂ log n₂)
-
合并相邻的区间,O(n₂)
-
从现在排序且不相交的区间创建一个
BTreeMap
,O(n₃ log n₃)
其中 n₁ 是输入整数/区间的数量,n₂ 是不相交且无序的区间数量,n₃ 是最终排序且不相交的区间数量。
“块状”整数的影响是什么?如果 n₂ ≈ sqrt(n₁),则构建时间为 O(n₁)。(实际上,只要 n₂ ≤ n₁/ln(n₁),构建时间为 O(n₁)。)在基准测试中,这在块状整数迭代器上变成了比HashSet
和BTreeSet
快 700 倍。
规则 3:创建比你预期的更多的 Rust 迭代器
你猜测标准BTreeSet
定义了多少种不同的迭代器类型?
答案是八种:Iter
,IntoIter
,DrainFilter
,Range
,Difference
,SymmetricDifference
,Intersection
,和Union
。许多非 Rust 编程语言可以将任何方法变成迭代器/生成器,只需几个“yield”语句。然而,Rust 并不提供这种功能(但正在讨论中)。因此,几乎每个与迭代相关的方法都需要你定义一个新的迭代器结构类型。这些结构至少会实现一个next
方法,该方法返回Some(
值)
或None
。
RangeSetBlaze
及其相关类型定义了 13 个迭代器结构。让我们看两个。
首先,用户可以调用ranges
并将整数作为一系列排序的不相交区间进行迭代。(请记住,RangeSetBlaze
接受无序、重叠的区间,但存储排序的不相交区间。)
use range_set_blaze::RangeSetBlaze;
let set = RangeSetBlaze::from_iter([30..=40, 15..=25, 10..=20]);
let mut ranges = set.ranges();
assert_eq!(ranges.next(), Some(10..=25));
assert_eq!(ranges.next(), Some(30..=40));
assert_eq!(ranges.next(), None);
在内部,RangeSetBlaze
使用标准BTreeMap
来存储区间信息。因此,RangeSetBlaze::ranges
方法构造一个包含BTreeMap::Iter
的RangesIter
结构。然后我们让RangesIter::next
方法调用BTreeMap::Iter
的next
方法,并将结果转换成所需类型。这里是代码:
impl<T: Integer> RangeSetBlaze<T> {
pub fn ranges(&self) -> RangesIter<'_, T> {
RangesIter {
iter: self.btree_map.iter(),
}
}
}
#[derive(Clone, Debug)]
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct RangesIter<'a, T: Integer> {
pub(crate) iter: btree_map::Iter<'a, T, T>,
}
impl<'a, T: Integer> Iterator for RangesIter<'a, T> {
type Item = RangeInclusive<T>;
fn next(&mut self) -> Option<Self::Item> {
self.iter.next().map(|(start, end)| *start..=*end)
}
fn size_hint(&self) -> (usize, Option<usize>) {
self.iter.size_hint()
}
}
其次,用户可能希望调用iter
并逐个以排序顺序遍历整数。在这种情况下,我们将返回一个名为Iter
的结构体,它包含一个RangeIter
,然后逐个遍历范围内的整数。以下是Iter::next
的原始代码,之后是关注点的讨论。
impl<T: Integer, I> Iterator for Iter<T, I>
where
I: Iterator<Item = RangeInclusive<T>> + SortedDisjoint,
{
type Item = T;
fn next(&mut self) -> Option<T> {
loop {
if let Some(range) = self.option_range.clone() {
let (start, end) = range.into_inner();
debug_assert!(start <= end && end <= T::safe_max_value());
if start < end {
self.option_range = Some(start + T::one()..=end);
} else {
self.option_range = None;
}
return Some(start);
} else if let Some(range) = self.iter.next() {
self.option_range = Some(range);
continue;
} else {
return None;
}
}
}
SortedDisjoint
特征涉及到保证内部迭代器提供排序的、不相交的范围。我们将在规则 5 中讨论它。
option_range
字段保存我们当前返回整数的范围(如果有的话)。我们使用loop
和continue
来填充空的option_range
。这个循环最多只循环两次,因此我本可以使用递归。然而,其他一些迭代器的递归次数足以导致栈溢出。因此,…
尾递归优化在 Rust 中没有保证。我的政策是:在next
函数中从不使用递归。
附注:感谢 Michael Roth,当前版本的
Iter::next
现在更简短了。他的拉取请求在这里。
BTreeSet
和RangeSetBlaze
除了iter
方法外,还定义了一个into_iter
迭代器方法。同样,RangeSetBlaze
除了其ranges
方法外,还定义了一个into_ranges
迭代器方法。这些into
_whatever方法获取RangeSetBlaze
的所有权,这在某些情况下很有用。
规则 4:通过特征使非法值不可表示
我说过RangeSetBlaze
只适用于整数,但有什么阻止你将它应用于字符呢?
use range_set_blaze::RangeSetBlaze;
fn _some_fn() {
let _char_set = RangeSetBlaze::from_iter(['a', 'b', 'c', 'd']);
}
答案?编译器会阻止你。它返回这个错误消息:
let _char_set = RangeSetBlaze::from_iter(['a', 'b', 'c', 'd']);
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
| the trait `Integer` is not implemented for `char`
|
= help: the following other types implement trait `Integer`:
i128
i16
i32
i64
i8
isize
u128
u16
and $N others
为了实现这一点,RangeSetBlaze
定义了一个它称之为Integer
的特征。以下是该定义(以及我找到的所有超级特征):
pub trait Integer:
num_integer::Integer
+ FromStr
+ fmt::Display
+ fmt::Debug
+ std::iter::Sum
+ num_traits::NumAssignOps
+ FromStr
+ Copy
+ num_traits::Bounded
+ num_traits::NumCast
+ Send
+ Sync
+ OverflowingSub
+ SampleUniform
{
// associated type SafeLen definition not shown ...
fn safe_len(range: &RangeInclusive<Self>) -> <Self as Integer>::SafeLen;
fn safe_max_value() -> Self { Self::max_value() }
fn f64_to_safe_len(f: f64) -> Self::SafeLen;
fn safe_len_to_f64(len: Self::SafeLen) -> f64;
fn add_len_less_one(a: Self, b: Self::SafeLen) -> Self;
fn sub_len_less_one(a: Self, b: Self::SafeLen) -> Self;
}
接下来,我在所有感兴趣的整数类型(u8
到u128
包括usize
,i8
到i128
包括isize
)上实现了Integer
特征。例如,
impl Integer for i32 {
#[cfg(target_pointer_width = "64")]
type SafeLen = usize;
fn safe_len(r: &RangeInclusive<Self>) -> <Self as Integer>::SafeLen {
r.end().overflowing_sub(*r.start()).0 as u32 as <Self as Integer>::SafeLen + 1
}
fn safe_len_to_f64(len: Self::SafeLen) -> f64 {len as f64}
fn f64_to_safe_len(f: f64) -> Self::SafeLen {f as Self::SafeLen}
fn add_len_less_one(a: Self, b: Self::SafeLen) -> Self {a + (b - 1) as Self}
fn sub_len_less_one(a: Self, b: Self::SafeLen) -> Self {a - (b - 1) as Self}
}
有了这个,我可以使代码泛型化为<T: Integer>
,如规则 3 中的代码示例所示。
附注:为什么 Rust 没有提供一个标准的“整数”特征来做所有事情?这里是讨论。
规则 5:定义具有保证属性和有用方法的泛型迭代器
RangeSetBlaze
的from_sorted_disjoint
构造函数假设输入是排序好的不相交范围。这让RangeSetBlaze
避免了工作。但是如果这个假设错误了呢?例如,如果我们给它未排序的范围,会发生什么?
use range_set_blaze::RangeSetBlaze;
fn _some_fn() {
let not_guaranteed = [5..=6, 1..=3, 3..=4].into_iter();
let _range_set_int = RangeSetBlaze::from_sorted_disjoint(not_guaranteed);
}
与规则 4 一样,编译器会捕捉错误并返回有用的消息:
7 | let _range_set_int = RangeSetBlaze::from_sorted_disjoint(not_guaranteed);
| ----------------------------------- ^^^^^^^^^^^^^^
the trait `SortedDisjoint<_>` is not implemented for `std::array::IntoIter<RangeInclusive<{integer}>, 3>`
| |
| required by a bound introduced by this call
|
= help: the following other types implement trait `SortedDisjoint<T>`:
CheckSortedDisjoint<T, I> ...
为了实现这一点,RangeSetBlaze
定义了特征SortedDisjoint
。以下是相关定义:
pub trait SortedStarts<T: Integer>: Iterator<Item = RangeInclusive<T>> {}
pub trait SortedDisjoint<T: Integer>: SortedStarts<T> {
// methods not shown, yet
}
这说明 SortedDisjoint
是对整数的泛型,并且每个 SortedDisjoint
必须是 SortedStarts
。此外,所有 SortedStarts
都是整数范围的迭代器。
附注:我的项目需要两个新的特征,因为我需要保证两个不同的属性。需要保证一个属性的项目只需要一个新的特征。
那么,重点是什么呢?为什么要引入新的特征,而不是直接使用 Iterator<Item = RangeInclusive<T>
?正如我从 Rüdiger Klaehn 的精彩 sorted-iter crate 中学到的,我们可以使用这些新特征来强制执行保证。例如,这个构造函数使用 where
子句只接受保证的(排序且不重叠的)整数迭代器:
impl<T: Integer> RangeSetBlaze<T> {
pub fn from_sorted_disjoint<I>(iter: I) -> Self
where
I: SortedDisjoint<T>,
{
// ... code omitted ...
}
}
那么,保证的迭代器如何获得所需的 SortedDisjoint
特征?它们实现了这个特征!例如,我们知道 RangeSetBlaze::ranges
方法返回一个 RangesIter
迭代器,它由排序且不重叠的范围组成,因此我们让 RangesIter
实现 SortedDisjoint
特征,如下所示:
impl<T: Integer> SortedStarts for RangesIter<'_, T> {}
impl<T: Integer> SortedDisjoint for RangesIter<'_, T> {}
就这样。我们已经将 RangesIter
标记为 SortedDisjoint
。编译器会完成剩下的工作。
不保证到保证:我还标记了一个名为 CheckSortedDisjoint
的迭代器为 SortedDisjoint
。有趣的是,它遍历一个 不保证 的内部迭代器。这怎么可能呢?实际上,当它迭代时,它也会检查——如果发现任何未排序或重叠的范围则会引发恐慌。结果是一个保证的迭代器。
有时保证外部迭代器:那么有时是SortedDisjoint
而有时不是的迭代器呢?例如,流行的 [Itertools::tee](https://docs.rs/itertools/latest/itertools/trait.Itertools.html#method.tee)
方法将任何迭代器转换为两个具有相同内容的迭代器。如果其输入迭代器是SortedDisjoint
,那么其输出迭代器也将是:
impl<T: Integer, I: SortedDisjoint<T>> SortedDisjoint<T> for Tee<I> {}
定义方法:几乎可以说是额外的好处,我们可以在泛型 SortedDisjoint
迭代器上定义方法。例如,在这里我们定义了 complement
方法,该方法生成当前迭代器中 不 包含的每个排序且不重叠的整数范围。
pub trait SortedDisjoint<T: Integer>: SortedStarts<T> {
fn complement(self) -> NotIter<T, Self>
where
Self: Sized,
{
NotIter::new(self)
}
}
这是来自 complement
文档的一个使用示例:
use range_set_blaze::prelude::*;
let a = CheckSortedDisjoint::from([-10i16..=0, 1000..=2000]);
let complement = a.complement();
assert_eq!(complement.to_string(), "-32768..=-11, 1..=999, 2001..=32767");
complement
方法使用 NotIter
,另一个迭代器(见规则 3)。NotIter
也实现了 SortedDisjoint
。这个示例还使用了 to_string
,这是另一个 SortedDisjoint
方法。
要使用 complement
和 to_string
,用户必须
use range_set_blaze::SortedDisjoint;
或 use range_set_blaze::prelude::*;
。前导模块起作用是因为项目的 lib.rs
指定了
pub mod prelude;
和 prelude.rs
文件包含这个 pub use
语句,它包括 SortedDisjoint
:
pub use crate::{
intersection_dyn, union_dyn, CheckSortedDisjoint, DynSortedDisjoint, MultiwayRangeSetBlaze,
MultiwayRangeSetBlazeRef, MultiwaySortedDisjoint, RangeSetBlaze, SortedDisjoint,
};
这些是创建 Rust 数据结构的前五条规则。请参见 第二部分 了解规则 6 到 9。
附言:如果你对未来的文章感兴趣,请关注我的 Medium。我写关于 Rust 和 Python 的科学编程、机器学习和统计学的文章。我通常每个月写一篇文章。