2019 年火花与人工智能峰会
Picture from Spark and AI Summit 2019
我对 2019 年 4 月 24 日和 25 日在旧金山举办的最新一届 Spark 和 AI 峰会的回顾。
上周举办了最新一期的星火会议。这是我第一次参加会议。以下是会议不同方面的分析。
大新闻
会议的组织者和 Spark 的主要贡献者 Databricks 宣布了几个项目:
树袋熊
他们宣布了一个名为考拉的新项目,考拉是 Spark 的本地“熊猫”翻译。你现在可以自动将你的熊猫代码移植到 Spark 的分布式世界。这将是人们适应熊猫环境的一座奇妙的桥梁。许多在线课程/大学使用熊猫教授数据科学。现在,新的数据科学家将填补一点损失。
我不认为这只会对新的数据科学家有用。正如您可能知道的,数据科学是一个充斥着您公司周围的脚本的世界。人们使用各种框架在各种环境下创建脚本来完成各种任务。如果你的主环境是 Spark,你将会调整你的熊猫的执行环境,少一个要关心的。
考拉是一个免费的开源项目在这里。该项目仍处于预发布版本(0.1)
三角洲湖
data bricks(Spark 的付费版本)的主要组件之一 Delta 刚刚获得开源。对于使用标准版 Spark 的人来说,这是一个非常好的消息。
产品的所有细节都可以在 https://delta.io/的找到
MLFlow
from Databricks 的端到端生命周期模型管理将于 5 月升级到 1.0 版本。
以下组件将添加到现有产品中:
- MLFlow 工作流,允许在一个管道中打包多个步骤项目
- MLFlow Model Registery,Registery 发布模型、版本,查看谁在使用它
对于任何商业生产模型的人来说,这似乎是一个有趣的过程。
有趣的是,两年前,的一个同事也在做一个类似的内部项目>。所以我可以说,它确实符合行业的实际需求。
最佳谈话
以下是我个人参加过的最喜欢的演讲列表:
智能连接算法,用于应对大规模倾斜
作者:安德鲁·克莱格
这个关于如何处理大型数据集中的不对称的演讲是我最期待的演讲,实际上也是我想参加会议的原因之一…我没有失望。
Andrew 提出了一个非常简单但效率惊人的方法来处理偏斜。我已经可以看到在我的工作中应用这些知识的地方。TLDR:他建议通过在 ID 末尾添加一个随机整数,将你真正频繁的数据细分成更小的块,并在更小的表中创建所有可能的 newID。
更多详情,你可以查看他们的 his 幻灯片这里。
Apache Spark 数据验证
作者:Patrick Pisciuneri 和 Doug Balog(目标)
他们共享 Target 数据验证框架,该框架应该很快就会开源。该框架允许在生成后进行数据验证。
如果代码有单元测试,数据需要类似这样的东西。我们都知道,当您处理数据集时,您有一组假设,当您创建管道时,这些假设可能是正确的,但几个月后,数据“真相”可能会略有不同,然后您的管道可能会在数据上失败。更糟糕的是,它可能在你没有意识到的情况下处理它而没有失败。这样的框架将有助于保持数据的完整性。
框架在 Github 上可用。
美妙的触感
我真的很喜欢他们给 ML/AI 伦理学的聚光灯。他们在周四的主题演讲中安排了一个黄金时段来谈论道德问题。我觉得这个话题讨论的不够多,或者至少不够优先。
这一点值得称赞。
无形的
众所周知,会议有两件事,谈话和社交。本次大会似乎已经理解了这一点,并为实现联网做出了很多努力。他们基本上从早上 8 点到晚上 11 点全天都有内容/活动,让人们呆在现场。
我与来自不同行业的其他数据科学家进行了许多有趣的讨论。对我来说,这是会议的关键点。
结论
我真的很喜欢这次会议,推销很平衡。大多数技术会谈都是纯粹的行业火花,没有销售意图。人际关系网太棒了。技术含量高。祝贺组织者。
据我所知,他们会在自己的网站上发布一些演讲的视频:https://databricks.com/sparkaisummit/north-america
最初发布于 2019 年 4 月 29 日http://coffeeanddata . ca。
使用正则表达式的火花连接
这是一篇关于我如何在 SPARK 中使用自定义 UDF 将 2 个数据集与正则表达式有效连接的技术性更强的文章
语境
在过去的几个月里,我一直在努力解决这个小问题。我有一个正则表达式模式的列表,我想知道哪篇维基百科文章包含它们。
我想以一个包含以下各列的表格作为结束:
- 维基百科文章 ID
- 维基百科文章文本
- 匹配模式(如果没有触发模式,则为空)
我正则表达式列表大约有 500 个模式长。有些是简单的单词搜索,但有些是更复杂的正则表达式。我需要一种好的方法来搜索这些模式,并找到一种方法将它们转换成上述格式。某种类型的左外连接。
设置
因为维基上有很多文章。很显然,580 万其中。我决定使用一个能够并行研究的工具,我已经决定使用来自阿帕奇基金会的 Spark 。现在,你可以选择你最喜欢的云提供商,很有可能你只需点击一下就能得到一个集群,我就是这样做的。
现在我有了自己的工作设置,我开始在网上寻找如何做到这一点。
我在网上发现的第一件事是用 rlike 做一个左外连接。
这看起来很有希望,在一个小场景中,它做的正是我想要的。
在我天真地启动了所有数据集的连接之后…我等了很长时间。实际上,只要我意识到我的集群内存不足。从那时起,我在网上找了好几个小时,尝试了很多解决方案,但我找不到任何关于如何解决这个问题的相关信息。
由于这不是什么真正重要的事情,我有几个月没有碰它,有一天在工作中,我和我的一个同事谈论它,他发现了这个问题。
问题
他是这样说的:
“首先,也是最重要的,非等价连接在 Spark 中性能很差,因为它们只能使用广播嵌套循环连接或交叉连接进行评估。
让我们假设文章包含 1,000,000 行,而模式包含 500 行。较小的数据集(本例中为粒子)将被广播。为了评估这个连接,粒子将被有效地扫描 1,000,000 次,连接谓词将被评估 500,000,000 次。更糟糕的是,该模式将被编译 5 亿次。
这基本上意味着灾难。”—迈克尔
解决方案
然后,在发现问题后,几天后他带着一件定制的 UDF 回来找我
要使用它,只需像这样查询它:
结果
我试了一下,几个小时后我就有了我想要的东西。
我决定做一个小的基准。在 Spark 中加载和缓存了这两个数据集之后,我只选择了 20 000 篇文章来尝试这两种方法。通过使用完全相同的样本,结果如下:
- 第一种技术, rlike :每秒 3 篇文章
- 第二种技术, UDF :每秒大约 5 000 篇文章
我知道这种延迟主要是由于缓存、读取和其他内存管理造成的,但这仍然是一个很大的区别。
结论
我通常不写这种类型的博客帖子,但因为我寻找解决方案,但找不到任何东西,我认为它值得分享回来。
我要感谢迈克尔·斯泰尔斯对这个项目的帮助。
这篇博客是我自己的博客咖啡和数据的转贴
spark Joy——用数据操作语法对您的事件日志说 Konmari
当你有一大堆事件日志要解析时,首选的武器应该是什么?在本文中,我将分享我尝试 spark/spark ryr 解决这个问题的经验。
在 Honestbee 捕获用户数据的事件日志🐝作为数据湖的一部分存储在 S3 自动气象站,每隔 40 分钟从段发送给我们。学习如何检索这些数据非常重要,数据(科学)团队使用这些日志来评估我们的机器学习模型(又名 canonical A B testing)的性能。
此外,我们还使用相同的日志来跟踪业务 KPI,如ClickTthroughRate、Con versionRate 和
在本文中,我将分享我们如何利用运行 Spark 的高内存集群来解析食物推荐系统生成的日志。
案例研究:食物推荐系统
每当一位诚实的顾客开始结账时,我们的 ML 模型都会尽力对你最有可能加入购物车的商品进行个性化预测,尤其是你错过的商品。
一个事后分析*,将要求我们查看日志,根据加权分布,查看用户被分配到哪个治疗组。*
现在,让我们开始吧。
首先,我们导入必要的库:
连接高记忆火花簇
接下来,我们需要连接主节点 Spark。我建议在转移到本地集群和远程集群之前,先从本地机器开始。
安装 Spark
如果它还没有安装,幸运的是sparklyr
有一个内置的功能来帮助安装
spark_install(
version = "2.4.0",
hadoop_version = "2.7"
)
注意,我们在这里也将 Hadoop 和 spark 一起安装,因为从 S3 文件系统读取文件所需的 jar 是附带的。
本地集群/单节点机箱
接下来,您将连接 spark cluster ie。建立一个SparkCconnection,通常缩写为 SC,会是你想要做的事情。
如果您连接到本地安装的 spark 集群/单节点箱,您将设置主参数为local
。
你可能需要调整内存消耗,我把它设置为 150 Gb,你的里程可能会根据你的机器而有所不同。
在笔记本电脑上运行 spark,或者作为一个更大的云实例上的单个节点安装都非常好。数据科学团队不拥有或管理集群,而是在单节点超大型 EC2 实例上运行 spark,主要用于原型开发和 EDA。然而,当任务变得太大时,您可能会考虑更繁重的任务,比如适当的集群。
远程集群
如果您正在运行 jupyterhub / Rstudio server,特别是如果您希望为每个数据科学家提供一个集群,那么与远程 spot 集群连接的选项可能会很有吸引力。
在这种情况下,python / R 进程不会在集群的同一个主节点上运行。像 Qubole 和 Databricks 这样的第三方 Spark 即服务提供商可以缓解这种情况。在 Honestbee,我们也选择了这个选项,集群由我们的 AWS 帐户下的 Qubole 提供。
PS。Qubole 是一个很好的抢断!
上面的要点建立了一个 spark 连接sc
,你将需要在大多数函数中使用这个对象。
另外,因为我们从 S3 读取,所以我们必须设置 S3 访问密钥和密码。这必须在执行spark_read_json
等功能之前进行设置
所以你会问各有什么利弊。本地集群通常适合 EDA,因为您将通过 REST API (LIVY)进行通信。
读取 JSON 日志
基本上有两种方法可以读取日志。第一种是把它们作为一个整体或者一个流来读——就像它们被倒进你的桶里一样。
有两个功能,spark_read_json
和stream_read_json
前者是批处理的,后者创建结构化的数据流。这也相当于读取您的拼花文件
成批的
应该用s3a
协议设置路径。s3a://segment_bucket/segment-logs/<source_id>/1550361600000.
json_input = spark_read_json(
sc = sc,
name= "logs",
path= s3,
overwrite=TRUE
)
下面是魔法开始的地方:
如你所见,这是一个简单的查询…
- 从
Food
垂直方向过滤所有Added to Cart
事件 - 选择以下列:
CartID
experiment_id
variant
(治疗组)和timestamp
3.删除未将用户分配给模型的事件
4.添加新列:1。fulltime
可读时间,2。一天中的某个小时
5.按照服务recommender
对日志进行分组,并计算行数
6.添加一个新列event
,其值为Added to Cart
7.按时间排序
Here’s a plot of the output so we see Model 2 is really getting more clicks than Model 1.
火花流
或者,您也可以将上述操作的结果写入结构化的火花流。
您可以使用耦合到glimpse
的tbl
函数预览这些来自流的结果。
sc %>%
tbl("data_stream") %>%
glimpse
就这样,伙计们!
此外,我想对存储模型元数据的现有选项做一个评论,特别是当您使用多个模型进行 A|B 测试时,每个模型都有多个版本。
老实说,很难知道发生了什么。
对于我的博士学位,我个人致力于使用图形数据库来存储具有复杂关系的数据,我们目前正在努力开发这样一个系统来存储与我们的模型相关的元数据。
例如:
- 它们与哪些 API 相关联
- 哪些气流/ Argo 作业与这些模型相关联
- 部署配置(舵图和地形)和其他部署元数据在哪里
- 当然还有元数据,比如成绩和分数。
原载于 2019 年 2 月 20 日ethe Leon . github . io。
窗户上的火花?入门指南。
Photo by Swaraj Tiwari on Unsplash
我们都读过这些关于大数据如何接管世界的文章。广泛用于这种大规模数据处理的工具之一是 Spark。Spark 是一个大数据分析代理,是整个行业使用的许多机器学习和数据科学的基础框架。
用你的 Jupyter 笔记本和熊猫来做数据分析项目是很好的,但是如果你想让它扩展,你需要做一点不同的设计。不幸的是,很难知道如何在您自己的工作站或笔记本电脑上实际安装螺母和螺栓,以便当您想要扩展时,它是完全相同的代码。我一直在设置我的本地 Windows 10 工作站来进行真正的数据科学工作,所以我想分享一下我的配方。有一堆脚本和演练来为 Linux 设置这些东西,所以我将在您的家庭 Windows 10 机器上设置这些令人敬畏的工具。不需要虚拟机。
先决条件
饭桶
下载并安装 Git for windows 。这将在你的开始菜单中给你 Git Bash。这将有助于拉下我为测试您的设置而创建的笔记本。除了“按原样检出,按原样提交”之外,使用安装的默认选项。这可能只是我,但我不喜欢 git 弄乱我的文件内容。
Java 和 Scala
Spark 需要运行 Java 和 Scala SBT(命令行版本),所以你需要下载并安装 Java 8+。Java 经历了一些许可变更,但是因为这是为了开发目的,所以你可以下载和使用。Scala 是一种运行在 Java 机器上的脚本语言,Spark 使用它来编写脚本。
7-Zip
如果你还没有安装 7-Zip ,它是处理各种压缩文件格式的优秀工具。
蟒蛇
Anaconda 是一个科学计算资源的包管理器,允许您轻松安装 Python、R 和 Jupyter 笔记本。在这里下载并选择 Python 3.7 64 位图形安装程序。下载并运行后,您应该会看到如下内容。如果尚未安装,请单击 Jupyter 笔记本的安装按钮。
火花
Spark 是计算集群框架。您可以下载一个. tgz 文件,并使用 7-zip 解压到一个临时位置。在 7-zip 中可能需要两个回合才能解开它,一个回合才能解开它。它应该给你留下一个 spark-2.4.3-bin-hadoop2.7,里面有一堆东西。将 spark-2.4.3-bin-hadoop2.7 文件夹移动到一个容易找到的位置,如 C:\spark-2.4.3-bin-hadoop2.7。
让我们做一些测试
检查它是否一切正常。打开一个新的 Windows 命令提示符(Win,搜索 cmd)并检查 java 是否安装正确。否则,您可能需要注销或重新启动才能使路径更新生效。
Java 语言(一种计算机语言,尤用于创建网站)
运行 java 命令,它应该会返回用法文本。
C:\Users\simon>java
Java should be located by the windows command prompt
火花
在命令提示符下导航到“C:\spark-2.4.3-bin-hadoop2.7”并运行 bin\spark-shell。这将验证 Spark、Java 和 Scala 都可以正常工作。一些警告和错误是正常的。使用“:quit”退出并返回到命令提示符。
现在,您可以运行一个圆周率的示例计算来检查它是否正常工作。
bin \ run-示例 SparkPi 10
饭桶
运行 git bash 应用程序来显示 bash 提示符。(赢,搜索 bash)
$ cd
$ mkdir 文档/开发
$ cd 文档/开发
$ git 克隆https://github.com/simonh10/SparkML.git
朱皮特
运行 Jupyter 笔记本应用程序(Win,搜索 Jupyter),这将启动 Jupyter 笔记本服务器并打开 web 浏览器。如果浏览器没有打开,请转到 http://localhost:8888 并导航到 Documents/Development/SparkML。你应该看看下面。
选择火花测试,它将打开笔记本。要运行测试,请单击“重启内核并运行全部> >”按钮(确认对话框)。这将安装 pyspark 和 findspark 模块(可能需要几分钟),并为运行集群作业创建 spark 上下文。Spark UI 链接将带您进入 Spark 管理 UI。
Click restart kernel and run all, after a few minutes the Spark UI will be available.
Spark Management UI
您现在可以在本地机器上对 Spark 集群运行 Python Jupyter 笔记本了!
接下来去哪里?
PySpark 是一种很棒的语言,可以进行大规模的探索性数据分析,构建机器学习管道,以及…
towardsdatascience.com](/a-brief-introduction-to-pyspark-ff4284701873) [## 基于 PySpark 的多类文本分类
Apache Spark 在头条新闻和现实世界中的采用率都在迅速上升,这主要是因为它能够…
towardsdatascience.com](/multi-class-text-classification-with-pyspark-7d78d022ed35) [## PySpark 备忘单:Python 中的 Spark
Apache Spark 通常被认为是一个快速、通用和开源的大数据处理引擎,内置了…
www.datacamp.com](https://www.datacamp.com/community/blog/pyspark-cheat-sheet-python)
Sparkify 流失预测,或时间序列数据的力量
在与客户打交道时,能够预测客户流失既是改善客户服务的机会,也是业务表现好坏的指标。
作为 Udacity 数据科学纳米学位顶点项目的焦点,我选择为一家名为 Sparkify 的音乐流媒体服务从事流失预测工作。我在这里提出了对我使用的数据和我得出的结论以及我选择实施的方法的见解。
项目背景
这项研究的目的是建立一个模型,能够预测音乐流媒体服务 Sparkify 的用户是否有可能流失。
一旦我们能够定义流失,我们就可以标记我们的数据,我们将要实现的机器学习模型是用于分类的监督模型。
我们有两个数据集可供使用:主数据集,大小约为 12GB,小型数据集,大小约为 128MB。
由于主数据集的大小,我们需要能够将我们的数据划分到几个集群中,我们将使用 Spark framework 而不是传统的 Pandas/Scikit-Learn Python 库来实现一切。
我们处理的数据是时间序列数据集。基本上,数据集的每一行都是用户的一个带有时间戳的操作。这使得我们试图解决的问题变得非常丰富:我们可以了解用户的行为如何随着时间的推移而变化,并且我们不仅仅将我们的预测建立在一个综合的观察上,比如平均值,或者最近发生的事件。
因此,我们模型的输入数据将包含提供用户行为变化相关信息的特征。
为了能够评估我们的预测有多准确,我们将看看 F1 的分数。事实上,这个分数提供了足够的关于预测有多好的信息,而不会对类别不平衡太敏感。
在本文的其余部分,我们将使用来自微型数据集的数据来说明我们的方法。完整数据集的研究将在另一篇独立的博客文章中进行探讨和展示。
数据清理和流失定义
我们分析的第一步当然是清理数据集。让我们快速看一下数据是什么样的。
Head of the mini-dataset
Mini-dataset schema
我们注意到一些重要的事情:
- 有两种类型的列:数值型和分类型
- 性别有空值->我们应该去掉相应的行,因为这些行很可能与任何用户都没有关联(我们希望基于用户的行为建立预测!)
- userId 有一个空字符串作为值->我们应该去掉相应的行
- 用户代理和方法似乎不是需要考虑的相关特性
- auth 有一个取消状态,这在以后查看流失率时会很有用
- 唯一用户的数量(226)和唯一位置的【115】倾向于表明人口广泛分布在多个城市,并且该特征对于预测流失可能没有很大的价值(我们将在更深入地探索数据集时确认这一点)
- 看起来数据集的绝大部分是由动作“NextSong”(大约 80%) 组成的。并且仅当页面列中的值是“NextSong”时,才填充艺术家和歌曲列。然而在剩下的 20%的行中(升级,首页,…)我们不期望找到艺术家或歌曲—这代表大约 57000 行。
现在让我们关注缺失或无效的数据。
Count of missing values per column of the mini-dataset
大约 8346 行不包含 userId 或 sessionID,因此对我们的预测目标没有用处。这个数量的行表示不到数据集的 3% ,所以我们可以简单地删除它们。
请注意,艺术家和歌曲这两列缺少更多的值。这是因为记录的性质不仅仅是播放一首歌曲,还包括登录服务、转到主页……正如我们看到的,大约 20%的行没有艺术家或歌曲值。
这个清洁步骤相当简单:
- 我们希望预测用户是否会流失,所以我们需要只保留与用户相关联的行>,我们删除 UserId 列中具有空值的行
- 我们希望能够观察用户行为的演变,因此时间信息是强制性的>我们将时间戳列转换为几个时间列(年、月、日、星期几、小时、周数)
Head of the cleaned mini-dataset
最后,我们必须定义流失意味着什么。这里有两种方法:
- 预测用户何时离开服务——使用给定用户的“取消确认”操作来定义
- 预测付费用户何时降级到免费服务——由“提交降级”定义
在本文的剩余部分,我们将同时研究这两种方法。因此,我们构建了一个包含两个新标签列的中间数据集,每个标签列对应一种流失类型。
数据探索
一旦我们有了一个干净的数据集,我们就可以开始熟悉数据所讲述的故事。
为了能够绘制这些图表,我们需要对数据进行一些转换。对于几乎每个绘图,我们都定义了几个熊猫数据帧(我们使用 Plotly 来呈现图形,并且我们需要一个熊猫数据帧作为输入)。
另请注意,当我们谈论平均值时,除非明确指定,否则它是在整个迷你数据集上计算的,因此对于该数据集中可用的两个月的数据。
以下是一些最相关的情节和我们从中得出的结论。
性别的影响
Impact of the gender on the churn probability of a user
乍一看,似乎男性取消服务的比例高于女性,并且性别对预测用户降级没有相关影响。总的来说,这个特性似乎没有带来太多的预测流失,我们将在以后决定我们是否要保留它或不训练我们的模型。
用户级别(付费/免费)的影响
Impact of the subscription level on the churn probability of a user
乍一看,付费用户取消服务的可能性略低。
用户位置的影响
Impact of the location (state) on the churn probability of a user
我们可以在这两个图中清楚地看到一个模式,其中一些州只有活跃用户,没有不稳定的用户。
每周日均收听歌曲数量的影响
Impact of the daily average number of songs per week on the churn probability of a user
看看每天平均的歌曲数量,看起来是这样的:
- 取消搅动用户的平均值低于活跃用户
- 降级搅动用户的平均值略高于活跃(付费)用户
这种趋势也可以通过每周拇指上下的次数来观察。
不同艺术家平均收听人数的影响
Impact of the number of distinct artists on the churn probability of a user
在这两种情况下,活跃用户似乎比不活跃的用户听更多种类的艺术家的音乐。
平均重复收听次数的影响
Impact of the repeat ratio on the churn probability of a user
在服务取消方面,活跃用户的歌曲重复率似乎高于不活跃用户。但是,看看服务降级,趋势就不那么明显了!
平均收听广告数量的影响
Impact of the number of ads on the churn probability of a user
我们可以观察到两种趋势:
- 与活跃用户相比,经常翻唱歌曲的用户平均会听更多的广告
- 免费用户听的广告比付费用户多 10 倍以上
登录次数和两次登录之间时间的影响
Impact of the number of logins per week on the churn probability of a user
看起来取消搅动的用户比活跃用户连接得少,尽管这种区别并不明显(趋势是大多数取消搅动的用户连接 1 到 13 次,而大多数活跃用户登录 1 到 22 次)。关于降级搅动用户,趋势相当不同——一些搅动用户实际上比活跃付费用户登录次数更多!
Impact of the time between two sessions on the churn probability of a user
从这张图表中可以看出,搅动的用户实际上比活跃的用户更经常地联系。
每次会话收听时间的影响
Impact of the listening time per session on the churn probability of a user
请注意,我们从这些图表中删除了异常值(时间戳记录中可能存在异常,因为我们让人们听一首歌超过 52 天!).
我们在这里观察到的是,大多数人每次听音乐的时间不到 20 分钟,而活跃用户往往听得更多。
活动时间的影响(动作计数)
Impact of the activity day of the month on the churn probability of a user
Impact of the activity day of the week on the churn probability of a user
看起来用户在工作日期间明显更活跃,并且平均而言,易激动用户的平均动作数量略高于活跃用户。
注册、升级和降级事件之间的时间影响
Impact of the time between registration and upgrade on the churn probability of a user
Impact of the time between registration and downgrade on the churn probability of a user
Impact of the time between upgrade and downgrade on the churn probability of a user
我们可以观察比平均水平:
- 服务流失的用户倾向于比活跃或服务降级流失的用户更早升级,这往往表明他们可能对平台有更多的期望
- 服务搅动用户倾向于在注册后比服务降级用户用更少的时间降级,但是在升级他们的服务后用更多的时间
这可以用这样一个事实来解释,即被搅动的用户可能对平台要求更高,尝试付费订阅,但不够满意并离开。而降级用户可能会对该平台如此热情,以至于他们在升级前会等待更长时间,甚至愿意在降级到免费订阅后继续使用它。
特征工程
数据探索让我们意识到了一些事情:
- 活跃用户比不活跃用户更倾向于出现在音乐服务上(更多的歌曲、重复、歌手……)
- 对于一些指标,降级和取消搅动的用户的行为并不严格相同。碰巧的是降级的用户似乎比活跃的付费用户更关注这项服务——他们如此喜欢这项服务,以至于他们实际上会继续使用它,不管作为免费用户的额外限制,比如听更多的广告(他们降级,属于免费活跃用户类别)。
- 注册的时间可以是用户类型的指标,也可以是人们活跃的时间(主要是在工作日)——因此用户的概况是不同的,可能是在工作时听音乐的用户,而喝醉的用户是更多休闲时间的用户,可能也更苛刻,因此更容易流失?
- 地点或性别似乎不是相关的标准
现在是时候利用这个探索阶段给出的灵感来设计特性了,来训练一个预测模型。在研究数据时,我们试图记住,每个用户的活动都应该单独考虑,并考虑时间间隔。在设计特征时,我们实际上要更进一步,查看指标随时间变化的方差,与较近期的事件(较高)相比,给予较旧事件(较低)不同的权重。
正如我们之前所讨论的,时间序列数据极其丰富,因为它允许我们评估随时间的变化,并给予较旧事件与最近事件不同的权重。
流程中这一步的重点是构建一个数据框架,我们可以将它作为输入传递给我们的分类模型。该输入数据帧将包含一组功能,这些功能是根据一段时间内的用户行为摘要构建的,具有一周聚合逻辑。基本上,每行将描述一个用户一周的行为,包括计算一周与下一周之间的偏差。因此,如果用户交互了 4 周,我们将在工程数据集中有 4 行,每个较新周的偏差特征都考虑了前几周的值。
这些偏差计算旨在观察用户行为的变化,与最近的事件相比,给予较旧的事件不同的权重。然而,由于我们最初不能说哪个权重更好,我们将考虑多种情况,将逻辑定义为能够调整这些权重(例如,我们可以尝试为最近/较旧的权重设置 0.8/0.2、0.5/0.5、0.6/0.4,这意味着最近的事件占值的 80%,较旧的事件占 20%)。
我们将关注的不同特性包括自注册以来的时间、歌曲数量、不同艺术家、重复次数、广告、登录次数、登录间隔时间、每次会话的收听时间,以及与前一周相比所有这些计数的变化。
Head of a summary using the 80/20 deviation ratio
为了处理这个逻辑,我们定义了:
- 一个类负责为特定的一周建立每个用户的摘要
- 一个类负责把新一周的总结与之前已经建立的总结整合在一起——每周一次添加一个,使用不同的偏差比率
- 函数编码分类特征和缩放数字特征,因为我们有非常不同比例的特征
请注意,每次更新摘要时,我们都会保存它。为了用新的一周更新它,我们加载已保存的摘要,计算偏差,并在再次保存它之前为每个用户向已加载的摘要添加一个新行。
我们不保存缩放后的模型,因为每次追加新行时,缩放步骤都需要对整个摘要进行重新计算。该步骤仅在需要时执行,在给模型供料之前。
在此过程的最后,我们保存了每个偏差比率的摘要(因此在微型数据集数据上总共保存了 3 个摘要)。
Dataframe that can be used to train a machine learning model
建模
我们正在处理一个分类问题。为此,我们将比较三种不同的模型:逻辑回归、随机森林和梯度提升树。
此分析阶段的目的是:
- 用偏离比来定义更有意义
- 超参数调整后选择性能最佳的模型
如果您还记得上面的内容,我们构建了两个数据框架:一个包含服务流失标签,另一个包含降级流失标签。让我们对这两种情况进行分析和预测。
服务流失预测
如果我们首先看一下偏差率的影响,这是我们在没有任何超参数调整的情况下得到的 F1 分数。
F1 score of three models trained with their default hyperparameters on three datasets with different deviation ratios
我们可以注意到:
- 一般来说,60/40 的比例没有其他两个比例的效果好
- 总的来说,80/20 的比率对于表现更好的两个模型(逻辑回归和梯度提升树)来说获得了最好的结果
在用 ParamGrid 运行 CrossValidator 之后,我们发现最佳组合是梯度提升树,带有{‘maxDepth’: 3,’ maxBins’: 50,’ maxIter’: 250,’ stepSize’: 0.1},在 70/30 偏差比率上,F1 得分为 87%。在验证集上运行后,我们得到了 83%的最终分数。
让我们在这里注意,在如此小的数据集上进行训练会对我们观察到的整体准确性产生影响!
降级流失预测
我们执行了完全相同的步骤,以下是我们得出的结论:
- 逻辑回归和梯度增强树的表现优于随机森林
- 两个偏差比率 80/20 和 70/30 也比 60/40 表现得更好
- 使用 CrossValidator 和 ParamGrid,我们得出结论,最佳组合是梯度提升树,具有{‘maxDepth’: 3,’ maxBins’: 50,’ maxIter’: 250,’ stepSize’: 0.1},在 70/30 偏差比率上,F1 得分为 87%。在验证集上运行后,我们得到了 83%的最终分数。
结论
该分析的目的是预测流媒体音乐服务的用户流失,我们有一个时间序列数据集,其中包括该服务用户的时间戳操作。
我们选择了两种不同的客户流失定义:服务客户流失,也就是用户离开服务;降级客户流失,也就是用户降级到免费订阅。
在数据清理的初始阶段之后,接着是探索性分析,我们设计了能够训练分类器的特征。利用时间序列数据的全部潜力,我们选择包括携带用户行为随时间变化的信息的特征。我们每周为每个用户构建一个摘要,一些特性是基于用户与服务交互的前一周的值来计算的。
使用一种简单的方法来衡量旧事件和新事件,我们尝试了几种偏差率(80/20、70/30、60/40)。
建模阶段的重点是测试模型的几种组合(逻辑回归、梯度推进树和随机森林)、偏差率和每个模型的超参数。为了评估我们的模型的性能,我们决定查看 F1 分数,因为这一指标对我们面临的阶级不平衡不太敏感(在我们处理的数据集中,有 20%以上的用户实际上发生了变化)。
我们的结论如下:
-对于服务流失,最佳组合似乎是梯度提升树,具有{‘maxDepth’: 3,’ maxBins’: 50,’ maxIter’: 250,’ stepSize’: 0.1},偏差比率为 70/30,F1 得分为 83%,最高为 87%。
-对于降级流失,最佳组合似乎是梯度提升树,具有{‘maxDepth’: 3,’ maxBins’: 50,’ maxIter’: 250,’ stepSize’: 0.1},偏差比率为 80/20,F1 得分为 81–82%。
从那时起,我们已经可以预见进一步的工作和改进:
- 改进我们衡量旧事件和新事件的方式
- 在更大的数据集(完整数据集)上测试模型,并观察性能和结果的变化
- 扩展工作以预测用户何时会流失(例如使用卡尔曼滤波器)
- 将代码转换成可以部署到任何 spark 环境的脚本
- 实施一个管道,利用新的每周数据自动重新计算预测,每周任务是重新计算摘要,重新训练模型,并根据最新日志报告潜在的变动
- 在 AWS 集群上部署这段代码
所有这些都是潜在的方向,将在其他帖子中进行描述!敬请期待!
下面的 Github 资源库 中有我所有的代码,跟着 Jupyter 笔记本中的方法走吧!
Sparkify 用户流失预测
音乐是我们生活中重要的一部分。众所周知的事实是,如果你在免费层,你将无法逃脱广告打断你的会话。这难道不令人沮丧吗?:)是…
客户参与、保持和流失一直是企业的重要课题。预测分析有助于企业采取主动行动,如提供优惠和折扣,以留住客户,并在某些情况下提高忠诚度。如今,我们产生了大量可用于此类分析的数据,数据科学对公司来说变得非常重要。
Sparkify 是由 Udacity 创建的虚拟音乐流媒体应用程序。对于这个项目,我们给出了小型、中型和大型的应用数据。我在 AWS EMR 上使用 Spark 处理过中等规模的数据。
github:https://github . com/elifinspace/spark ify/blob/master/readme . MD
探测
考虑到 ML,数据集相对较小,它有 543705 行和 18 列,包含客户信息(性别、姓名等。)和 API 事件(登录、播放下一首歌等。).
数据集从 2018 年 10 月 1 日到 2018 年 12 月 1 日。在这个数据集中,流失的客户数量是 99,客户总数是 449。相对于顾客总数来说,翻炒者的数量可能看起来很少,但当你想到这一点时,它就很大了。Sparkify 流失了 22%的客户!
以下是一些从初步调查中挑选出来的图片:
All customers event count free/paid distribution by gender
我们可以看到男性客户使用 Sparkify 活跃,付费客户比免费用户活跃。
从上面的图表我们可以得出结论,付费用户比免费用户流失更多。(标签=1 次流失)
有趣的是,状态也显示出对搅动的影响:
Churn count by States
特征工程
数据集包含了相当多的信息,我可以想到许多组合和计算来提取有用的信息。因此,我选择使用页面事件的每日和每月平均值(下一首歌,广告,…)、不同的会话、会话持续时间、项目以及项目的长度。我还添加了一些数字特征,比如注册后的时间。
A screenshot of features data frame
生成特征后,还要进行后处理,以便为建模准备数据。分类列需要编码,数字列必须在管道中组装。pyspark.ml 为我们提供了这些功能,您不需要自己进行热编码。
我使用 stringIndexer 为每个分类列创建索引(将标签的字符串列编码为标签索引列),VectorAssembler 是一个特征转换器,它将多个列组装(合并)为一个(特征)向量/列。[1]这些是作为阶段添加到管道中的,我们将对其进行数据拟合。代码和细节可以在开头提到的 Github 资源库中找到。
培训、指标和调整
我们的问题是预测哪些用户可能流失,哪些不会,所以本质上这是二元分类。
用正确的度量标准评估模型是很重要的。对于这项工作,当我选择正确的模型时,我选择以下指标:
f1 得分:解释模型的稳健性和精确性。
AUC:表示模型能够区分类别的程度,以及分类器将随机选择的正例排序高于随机选择的负例的概率。
对我们来说,对没有流失的客户进行正确分类很重要,否则我们可能会采取错误的行动,或者我们可能会采取不应该采取的行动,这可能会让客户感到困惑。
从简单到复杂的模型开始,对逻辑回归、随机森林分类和梯度推进分类器的 f1 得分和 AUC(ROC 曲线下面积)进行了比较:
logistic_regression
The F1 score on the test set is 79.83%
The areaUnderROC on the test set is 67.17%random_forest_classifier
The F1 score on the test set is 87.81%
The areaUnderROC on the test set is 95.08%gradient_boosting_classifier
The F1 score on the test set is 85.68%
The areaUnderROC on the test set is 88.83%
我选择了随机森林模型,并应用了基于 f1 分数的超参数调整:
网格搜索方式:
paramGrid = ParamGridBuilder() \
.addGrid(clf.numTrees, [20,75]) \
.addGrid(clf.maxDepth, [10,20]) \
.build()
结果:
The F1 score on the test set is 91.03%
The areaUnderROC on the test set is 93.25%
最佳参数:
maxDepth:10numTrees:75
因此,我们可以通过选择这些参数来改善我们的管道。
准确率和 f1 成绩都相当惊人。然而,我们不应该忘记,我们的数据集可能无法代表所有的客户群,我已经使用了中等规模的数据集。此外,在实际应用中,70%以上的精度被认为是良好的和可接受的。
最后,当我们检查功能的重要性,级别(免费/付费),注册以来的时间,每月平均滚动广告似乎是最重要的 3 个功能,这并没有让我感到惊讶。总是这些广告!… 😃
avg_daily_sessions : 0.04236219536508317
avg_monthly_sessions : 0.03951447819898451
avg_daily_session_duration : 0.02073779077811611
avg_monthly_session_duration : 0.016396278628309786
avg_daily_items : 0.020040353630460424
avg_monthly_items : 0.022926933384472603
avg_daily_RollAdvert : 0.019136923054740844
....
结论
Spark 的机器学习使我们能够处理大量数据,获得洞察力,并以可扩展的方式从结果中制定行动。使用大型数据集时,配置环境参数也非常重要,还应考虑模型的准确性、运行时内存使用和资源优化。在实际应用中,可行性和成本变得非常重要。这种模式可以每周或每月运行一次,具体取决于数据延迟和业务需求。应监控运营成本,并通过测试(A/B 测试)验证模型结果。应该跟踪实验结果(评估指标、KPI ),以便我们的模型和后续行动能够为业务带来价值。
【1】:https://spark.apache.org/docs/2.2.0/ml-features.html
使用 Pyspark 进行用户流失预测
预测本地机器和 AWS EMR 上的音乐流服务用户流失。
概观
用户流失(取消)预测是一个必要的预测工具。这个项目旨在为音乐流媒体服务 Sparkify 解决这个问题。通过探索 Sparkify 使用数据,该项目确定了模型学习的功能。出于计算效率的原因,一个极小的数据集(240Mb),即完整数据集(12Gb)的样本,用于在本地机器上进行初始数据探索、特征工程和建模实验。
对微小数据集的初始工作将为完整数据集确定最合适的模型和超参数,以训练最终模型。一旦确定了特征和模型,它们将用于在 AWS EMR 上建立完整数据集的模型。假设样本数据集代表总体数据集,则由样本数据集调整的超参数将很好地一般化,并且也适用于在完整数据集上建模。我们将会看到这种假设是否有助于节省一些计算资源,而不必在大数据上进行网格搜索。
从流失预测中获得的可操作的洞察力将识别可能流失的用户,并向他们发送优惠,希望阻止他们点击取消确认。
探索性数据分析
下面的数据集模式显示了可用于要素工程的数据集结构和列。从 EDA 中,我们将确定需要设计哪些特性。
root
|-- artist: string (nullable = true)
|-- auth: string (nullable = true)
|-- firstName: string (nullable = true)
|-- gender: string (nullable = true)
|-- itemInSession: long (nullable = true)
|-- lastName: string (nullable = true)
|-- length: double (nullable = true)
|-- level: string (nullable = true)
|-- location: string (nullable = true)
|-- method: string (nullable = true)
|-- page: string (nullable = true)
|-- registration: long (nullable = true)
|-- sessionId: long (nullable = true)
|-- song: string (nullable = true)
|-- status: long (nullable = true)
|-- ts: long (nullable = true)
|-- userAgent: string (nullable = true)
|-- userId: string (nullable = true)
“流失”标签是通过识别确认其订阅取消的用户从数据集生成的。一旦识别出被搅动的用户,我们就可以看到它如何与数据集中的其他列一起工作:
Figure 1a, 1b and 1c
从上面的柱状图中,3 个主要数据列(客户流失状态、订阅水平和性别)进行了不同的排序,我们可以看到数据向一端倾斜,因此分布相当不均匀。非流失用户比流失用户多得多,性别和订阅水平也是如此。这是模型学习和度量选择需要注意的重要一点。
Figure 2a and 2c
图 2a 和 2c 显示用户升级比用户降级分布更均匀。
Figure 3
用户位置分布很广,几乎在所有位置都很稀疏。除非位置可以被分组到地理位置的类别中,否则这可能对建模没有帮助。直观地说,我们可以从位置的最后两个字符中提取州代码来创建州的分类特征。
Figure 4a
Figure 4b
页面访问也严重偏离了大多数互动集中的“NextSong”。去掉这个,用户页面交互分布本来可以更明显。捕捉用户行为的页面可能是“添加朋友”、“添加到播放列表”、“取消确认”(用于“流失”标签生成)、“降级”、“滚动广告”(有助于广告显示洞察)、“提交降级”、“提交升级”、“拇指向下”和“拇指向上”。其中,“添加朋友”和“添加到播放列表”将被用于功能工程。其他列可能有助于预测,而不是流失。
Figure 5
按照一天中的小时来累计用户数量,流媒体服务看起来会在深夜到午夜之后有更多的用户。然而,用户流失在一天中的不同时段看起来没有明显的趋势。
Figure 6
按一天中的小时排列的页面访问与其用户计数相似,具有更显著的趋势。
Figure 7
Figure 8
一周中的一天图显示了工作日期间更多的用户参与。
Figure 9
Figure 10
Figure 11
除了按天或小时汇总数据,我们还可以查看趋势如何在数据集期间演变。虽然这个微小的数据集只有两个月的数据可供探索,但这里揭示的见解相当有趣。它的周期性趋势与用户在周中更多参与的工作日图一致。它还显示了用户行为在两个月内的变化:更少的流失和更多的升级。我们需要更多的信息来了解现有用户和新用户的升级数量及其原因。
如果计算资源允许,在整个数据集中观察这种趋势如何演变将会更加有趣。
特征工程
经过探索性的数据分析,10 个假设在决定用户流失中发挥作用的特征被设计出来。接下来,作为模型训练结果的特征重要性将决定全数据集建模采用什么模型和特征。
因此,Spark 数据帧由 10 个特征和 1 个标签组成:
- 性别(二元)
- 付费或免费(二进制)
- 收听的歌曲总数(数字)
- 收听的艺术家总数(数字)
- 用户添加的播放列表中的歌曲数量(数字)
- 用户添加的好友数量(数字)
- 收听时间的总长度(数字)
- 每次播放的平均歌曲数量(数字)
- 每次会话的平均时间(数字)
- 每个用户的会话数(数字)
- 标签:每个用户的会话数(二进制)
数据帧被分成 80%用于训练,20%用于测试。
建模
通过对训练集的三重交叉验证和对四个分类器的最佳超参数的网格搜索来完成建模。通过评估结果和每个分类器的特征重要性来识别最佳模型:
- 逻辑回归
Figure 12
Evaluation result (With GridSearch):
+---------+------+------+--------+
|precision|recall| f1|accuracy|
+---------+------+------+--------+
| 0.6684|0.7606|0.6753| 0.7606|
+---------+------+------+--------+
Training time 19.93 minutesEvaluation result (Without GridSearch):
+---------+------+------+--------+
|precision|recall| f1|accuracy|
+---------+------+------+--------+
| 0.6684|0.7606|0.6753| 0.7606|
+---------+------+------+--------+
Training time 10.31 minutes
2.决策图表
Figure 13
Evaluation result (With GridSearch):
+---------+------+------+--------+
|precision|recall| f1|accuracy|
+---------+------+------+--------+
| 0.738|0.7746|0.7351| 0.7746|
+---------+------+------+--------+
Training time 11.05 minutesEvaluation result (Without GridSearch):
+---------+------+------+--------+
|precision|recall| f1|accuracy|
+---------+------+------+--------+
| 0.7355|0.7746|0.7134| 0.7746|
+---------+------+------+--------+
Training time 8.49 minutes
3.梯度增强树
Figure 14
Evaluation result: (With GridSearch)
+---------+------+------+--------+
|precision|recall| f1|accuracy|
+---------+------+------+--------+
| 0.8611|0.8662|0.8622| 0.8662|
+---------+------+------+--------+
Training time 187.47 minutesEvaluation result (Without GridSearch):
+---------+------+------+--------+
|precision|recall| f1|accuracy|
+---------+------+------+--------+
| 0.8346| 0.838|0.8158| 0.838|
+---------+------+------+--------+
Training time 11.74 minutes
4.随机森林
Figure 15
Evaluation result: (With GridSearch)
+---------+------+------+--------+
|precision|recall| f1|accuracy|
+---------+------+------+--------+
| 0.8387|0.7958|0.7277| 0.7958|
+---------+------+------+--------+
Training time 11.36 minutesEvaluation result (Without GridSearch):
+---------+------+-----+--------+
|precision|recall| f1|accuracy|
+---------+------+-----+--------+
| 0.8258|0.7746|0.683| 0.7746|
+---------+------+-----+--------+
Training time 9.23 minutes
采用 F1 分数作为模型选择的标准,因为它同时考虑了召回率和精确度。数据集具有不成比例的标注和二进制要素分布。F1 分数将确保我们的模型不会被错误分类的混淆矩阵混淆:对具有 95%真标签的数据集进行 100 %真分类以实现 95%的准确率,这与 5%的假阳性无关。
F1 得分为 0.8622,梯度推进树被选择用于使用以下超参数对大型数据集进行建模:
最佳 _ 模型。_java_obj.getMaxDepth() = 8
最佳 _ 模型。_java_obj.getMaxIter() =30
最佳 _ 模型。_java_obj.getMaxBins() = 40
通过 3 个 AWS EMR m3.xlarge 实例训练大型数据集花费了大约 160 分钟,并在其测试集上产生了以下评估结果:
Evaluation result:
+------+------+---------+--------+
| f1|recall|precision|accuracy|
+------+------+---------+--------+
|0.7254|0.7868| 0.7444| 0.7908|
+------+------+---------+--------+
微小数据与大数据
计算复杂度
由于不同学习、优化甚至数据争论算法的计算复杂性的差异,在小数据集和大数据集上执行任务会在时间和资源上造成显著差异。因此,有必要对各种算法的计算复杂性有一个很好的认识,以便在任务分配中有一个好的判断(本地与分布式计算)。
在云中执行探索性数据分析(EDA)的成本很高。像这样的大型数据集上的综合 EDA 将很容易花费每回合 25 美元的 EC2 和 EMR 费用。因此,在使用云服务进行建模之前,通过执行本地 EDA,决定大型数据集的哪些内容应该假设,哪些内容不应该假设,从而很好地理解数据集是非常重要的。
超参数调谐
在大型数据集上执行网格搜索是一项计算量很大的任务。要在一个大数据集上调优正确的超参数,如果不是 O (n 次方)时间复杂度的问题,很容易就是一个 O( n 次方)的问题。对于某些性能权衡,随机搜索可能是更快、更便宜的替代方案。超参数调整的一个更聪明的实现是结合随机搜索和网格搜索:
- 使用大型超参数网格进行随机搜索
- 使用随机搜索的结果,围绕最佳性能超参数值构建一个集中的超参数网格。
- 在缩减的超参数网格上运行网格搜索。
- 在更集中的网格上重复网格搜索,直到超过最大计算/时间预算。
具有和不具有网格搜索比较的超参数调整的比较,上述建模评估结果表明,对于实验的四个分类器,模型性能和计算资源之间的折衷是混合的。在没有网格搜索的情况下,我们节省了几乎一半的逻辑回归训练时间,因为在模型性能改进方面没有差异。对于梯度推进树(GBT ),网格搜索的性能改善是显著的:F1 值为 0.8622(使用网格搜索)对 0.8158(不使用网格搜索)。然而它的计算用网格搜索多花了将近 3 个小时。GBT 无疑是一种高计算复杂度算法,随着数据集大小的增加而呈指数增长。网格搜索是在 240Mb 的数据集上进行的,想象一下如果在 12Gb 的数据集上进行!
相比之下,决策树和随机森林的折衷更加线性,这是在计算需求略有增加的情况下的一个微小的性能改进。因此,通过了解学习算法的计算复杂性,我们可以做出更明智的决定,以实现最佳性能与计算的权衡。
特征和标签的统计比较
微小数据集和大数据集生成的特征之间肯定存在一些统计差异(统计比较见下文),因此评估结果也有很大不同:0.862(微小数据集)对 0.725(完整数据集)。下面提供了可能解释模型性能差异的两个数据集:
- 性别(二元)
Tiny dataset >
+-------+-----------------+------------------+
|summary| userId| gender|
+-------+-----------------+------------------+
| count| 448| 448|
| mean|67520.34821428571|0.4419642857142857|
| stddev|105964.5842848519|0.4971756279079038|
| min| 10| 0|
| max| 99| 1|
+-------+-----------------+------------------+
+-------------------+-------------------+
| skewness(gender)| kurtosis(gender)|
+-------------------+-------------------+
|0.23372261898725685|-1.9453737373737383|
+-------------------+-------------------+Full dataset >
+-------+------------------+-------------------+
|summary| userId| gender|
+-------+------------------+-------------------+
| count| 22278| 22278|
| mean|1498782.9615764432|0.47697279827632644|
| stddev| 288851.8472659188| 0.4994806768184825|
| min| 1000025| 0|
| max| 1999996| 1|
+-------+------------------+-------------------+
+-------------------+-------------------+
| skewness(gender)| kurtosis(gender)|
+-------------------+-------------------+
|0.09220664431939639|-1.9914979347433575|
+-------------------+-------------------+
2.付费或免费(二进制)
Tiny dataset >
+-------+------------------+------------------+
|summary| userId| level|
+-------+------------------+------------------+
| count| 691| 691|
| mean| 67259.47033285093|0.4645441389290883|
| stddev|106161.02506630082|0.4991025734716588|
| min| 10| 0|
| max| 99| 1|
+-------+------------------+------------------+
+-------------------+-------------------+
| skewness(level)| kurtosis(level)|
+-------------------+-------------------+
|0.14218137235745473|-1.9797844573545489|
+-------------------+-------------------+Full dataset >
+-------+------------------+-------------------+
|summary| userId| level|
+-------+------------------+-------------------+
| count| 22278| 22278|
| mean|1498782.9615764432| 0.5992010054762547|
| stddev| 288851.8472659188|0.49007136327332323|
| min| 1000025| 0|
| max| 1999996| 1|
+-------+------------------+-------------------+
+-------------------+-------------------+
| skewness(level)| kurtosis(level)|
+-------------------+-------------------+
|-0.4025645802550173|-1.8379417587241018|
+-------------------+-------------------+
3.收听的歌曲总数
Tiny dataset >
+-------+-----------------+------------------+
|summary| userID| num_song|
+-------+-----------------+------------------+
| count| 448| 448|
| mean|67520.34821428571|1178.5825892857142|
| stddev|105964.5842848519|1380.6279647524054|
| min| 10| 3|
| max| 99| 9767|
+-------+-----------------+------------------+
+------------------+------------------+
|skewness(num_song)|kurtosis(num_song)|
+------------------+------------------+
|2.2993135439060457| 6.814549363048505|
+------------------+------------------+Full dataset >
+-------+------------------+------------------+
|summary| userID| num_song|
+-------+------------------+------------------+
| count| 22278| 22278|
| mean|1498782.9615764432|1178.7054044348686|
| stddev|288851.84726591856| 5372.95993988227|
| min| 1000025| 1|
| max| 1999996| 778479|
+-------+------------------+------------------+
+------------------+------------------+
|skewness(num_song)|kurtosis(num_song)|
+------------------+------------------+
|135.95349045633083|19660.671563405533|
+------------------+------------------+
4.收听的艺术家总数
Tiny dataset >
+-------+-----------------+-----------------+
|summary| userId| num_artist|
+-------+-----------------+-----------------+
| count| 448| 448|
| mean|67520.34821428571|658.9799107142857|
| stddev|105964.5842848519| 625.882698051957|
| min| 10| 1|
| max| 99| 3507|
+-------+-----------------+-----------------+
+--------------------+--------------------+
|skewness(num_artist)|kurtosis(num_artist)|
+--------------------+--------------------+
| 1.4675037203781365| 2.2190317071776393|
+--------------------+--------------------+Full dataset >
+-------+------------------+-----------------+
|summary| userId| num_artist|
+-------+------------------+-----------------+
| count| 22261| 22261|
| mean|1498833.2082116706|645.0307263824626|
| stddev| 288882.1163228876|602.2479741901458|
| min| 1000025| 1|
| max| 1999996| 4368|
+-------+------------------+-----------------+
+--------------------+--------------------+
|skewness(num_artist)|kurtosis(num_artist)|
+--------------------+--------------------+
| 1.5260667285754526| 2.656182841474317|
+--------------------+--------------------+
5.用户添加的播放列表中的歌曲数量
Tiny dataset >
+-------+------------------+------------------+
|summary| userID| num_playlist_song|
+-------+------------------+------------------+
| count| 428| 428|
| mean| 65764.93457943926|28.852803738317757|
| stddev|105363.38578382804|33.913090694566286|
| min| 10| 1|
| max| 99| 248|
+-------+------------------+------------------+
+---------------------------+---------------------------+
|skewness(num_playlist_song)|kurtosis(num_playlist_song)|
+---------------------------+---------------------------+
| 2.4048605508981393| 7.736784413479519|
+---------------------------+---------------------------+Full dataset >
+-------+------------------+-----------------+
|summary| userID|num_playlist_song|
+-------+------------------+-----------------+
| count| 21260| 21260|
| mean|1498898.9698494826|28.12422389463782|
| stddev|289180.40429718536|32.27499039023108|
| min| 1000025| 1|
| max| 1999996| 340|
+-------+------------------+-----------------+
+---------------------------+---------------------------+
|skewness(num_playlist_song)|kurtosis(num_playlist_song)|
+---------------------------+---------------------------+
| 2.3914875986095625| 8.073009618134558|
+---------------------------+---------------------------+
6.用户添加的好友数量
Tiny dataset >
+-------+------------------+------------------+
|summary| userID| num_friend|
+-------+------------------+------------------+
| count| 409| 409|
| mean| 66368.3716381418|19.772616136919314|
| stddev|106064.01609030597| 22.49443576627283|
| min| 10| 1|
| max| 99| 158|
+-------+------------------+------------------+
+--------------------+--------------------+
|skewness(num_friend)|kurtosis(num_friend)|
+--------------------+--------------------+
| 2.4002649609428586| 7.678714779071942|
+--------------------+--------------------+Full dataset >
+-------+------------------+------------------+
|summary| userID| num_friend|
+-------+------------------+------------------+
| count| 20305| 20305|
| mean| 1499371.503718296| 18.79655257325782|
| stddev|288830.59626148926|20.747704116295065|
| min| 1000025| 1|
| max| 1999996| 222|
+-------+------------------+------------------+
+--------------------+--------------------+
|skewness(num_friend)|kurtosis(num_friend)|
+--------------------+--------------------+
| 2.3834675795984976| 8.182711524378096|
+--------------------+--------------------+
7.收听时间的总长度
Tiny dataset >
+-------+-----------------+------------------+
|summary| userID| time_listen|
+-------+-----------------+------------------+
| count| 448| 448|
| mean|67520.34821428571|240270.49760906256|
| stddev|105964.5842848519| 286257.952604531|
| min| 10| 131.00363|
| max| 99| 2019435.10394|
+-------+-----------------+------------------+
+---------------------+---------------------+
|skewness(time_listen)|kurtosis(time_listen)|
+---------------------+---------------------+
| 2.303088337172893| 6.791766280616427|
+---------------------+---------------------+Full dataset >
+-------+------------------+------------------+
|summary| userID| time_listen|
+-------+------------------+------------------+
| count| 22278| 22261|
| mean|1498782.9615764432|232963.16116480672|
| stddev| 288851.8472659186|273559.41985437507|
| min| 1000025| 78.49751|
| max| 1999996| 2807182.33115|
+-------+------------------+------------------+
+---------------------+---------------------+
|skewness(time_listen)|kurtosis(time_listen)|
+---------------------+---------------------+
| 2.439989611449916| 8.466901311178267|
+---------------------+---------------------+
8.每次会话的平均歌曲数量
Tiny dataset >
+-------+-----------------+-----------------+
|summary| userId| avg_songs|
+-------+-----------------+-----------------+
| count| 448| 448|
| mean|67520.34821428571|65.61168409976814|
| stddev|105964.5842848519|39.49496566617267|
| min| 10| 1.0|
| max| 99| 360.0|
+-------+-----------------+-----------------+
+-------------------+-------------------+
|skewness(avg_songs)|kurtosis(avg_songs)|
+-------------------+-------------------+
| 1.5731533253340253| 6.713093146603137|
+-------------------+-------------------+Full dataset >
+-------+------------------+-----------------+
|summary| userId| avg_songs|
+-------+------------------+-----------------+
| count| 22261| 22261|
| mean|1498833.2082116706|67.28930119633611|
| stddev| 288882.1163228875|42.00146132153544|
| min| 1000025| 1.0|
| max| 1999996| 579.0|
+-------+------------------+-----------------+
+-------------------+-------------------+
|skewness(avg_songs)|kurtosis(avg_songs)|
+-------------------+-------------------+
| 1.736381217684329| 8.096277081017185|
+-------------------+-------------------+
9.每次会话的平均时间
Tiny dataset >
+-------+-----------------+------------------+
|summary| userId| avgSessTime|
+-------+-----------------+------------------+
| count| 448| 448|
| mean|67520.34821428571| 267.78884543676|
| stddev|105964.5842848519| 164.3624010530248|
| min| 10|13.166666666666666|
| max| 99| 1502.4|
+-------+-----------------+------------------+
+---------------------+---------------------+
|skewness(avgSessTime)|kurtosis(avgSessTime)|
+---------------------+---------------------+
| 1.6431669833905285| 7.153986718436954|
+---------------------+---------------------+Full dataset >
+-------+------------------+------------------+
|summary| userId| avgSessTime|
+-------+------------------+------------------+
| count| 22278| 22278|
| mean|1498782.9615764432| 276.5377760334103|
| stddev| 288851.8472659185|180.68117321920786|
| min| 1000025| 0.0|
| max| 1999996| 5453.363730301772|
+-------+------------------+------------------+
+---------------------+---------------------+
|skewness(avgSessTime)|kurtosis(avgSessTime)|
+---------------------+---------------------+
| 2.8166509927927756| 38.509616923162206|
+---------------------+---------------------+
10.每个用户的会话数
Tiny dataset >
+-------+-----------------+------------------+
|summary| userId| session|
+-------+-----------------+------------------+
| count| 448| 448|
| mean|67520.34821428571|13.571428571428571|
| stddev|105964.5842848519| 13.17102391180226|
| min| 10| 1|
| max| 99| 92|
+-------+-----------------+------------------+
+------------------+-----------------+
| skewness(session)|kurtosis(session)|
+------------------+-----------------+
|2.3492597105561734|7.726622405284306|
+------------------+-----------------+Full dataset >
+-------+------------------+------------------+
|summary| userId| session|
+-------+------------------+------------------+
| count| 22278| 22278|
| mean|1498782.9615764432|20.431726366819284|
| stddev|288851.84726591856|1059.3297847404108|
| min| 1000025| 1|
| max| 1999996| 158115|
+-------+------------------+------------------+
+------------------+-----------------+
| skewness(session)|kurtosis(session)|
+------------------+-----------------+
|149.21426452603328|22266.26341038374|
+------------------+-----------------+
11.标签(流失)
Tiny dataset >
+-------+-----------------+-------------------+
|summary| userId| label|
+-------+-----------------+-------------------+
| count| 448| 448|
| mean|67520.34821428571|0.22098214285714285|
| stddev|105964.5842848519| 0.4153723104396363|
| min| 10| 0|
| max| 99| 1|
+-------+-----------------+-------------------+
+------------------+--------------------+
| skewness(label)| kurtosis(label)|
+------------------+--------------------+
|1.3449610206355533|-0.19107985297097096|
+------------------+--------------------+Full dataset >
+-------+------------------+-------------------+
|summary| userId| label|
+-------+------------------+-------------------+
| count| 22278| 22278|
| mean|1498782.9615764432|0.22457132597181076|
| stddev| 288851.8472659188| 0.4173090731235619|
| min| 1000025| 0|
| max| 1999996| 1|
+-------+------------------+-------------------+
+------------------+--------------------+
| skewness(label)| kurtosis(label)|
+------------------+--------------------+
|1.3200520841972045|-0.25746249500661467|
+------------------+--------------------+
结论
事后看来,当谈到对大型数据集做出的假设时,我可能会做出正确和错误的混合决定:设计什么功能,在训练中调整什么模型和超参数。回顾过去,对于训练更好的大数据预测模型应该有改进的空间,该模型可以用最少的计算时间和资源产生更有希望的评估结果。其中包括:
- 在本地机器上进行更多的数据探索和实验,以在整个数据集上获得更精确的建模方法。
- 在大型数据集中有相当大的统计差异,对整个数据集的超参数调整仍然是必要的。通过平衡网格和随机搜索进行更智能的调优。
- 测试工程要素之间的共线性,移除共线要素以节省计算资源
对于本项目之外的研究,除了流失预测之外,数据集还可用于识别更多导致(如果不能改善的话)图 9 和图 10 所示的有希望趋势的因素。
这个项目的源代码可以从我的 Github 库中获得。
关于客户流失的更多信息:
- 使用机器学习预测客户流失:主要方法和模型altex soft。
- 消除流失是增长黑客 2.0
- 如何通过预测客户流失来改善基于订阅的业务Neil Patel
AWS 上的 SparkML,6 个简单步骤
我最近完成了一个机器学习项目,由于数据集的大小和分析的计算复杂性,需要使用 Spark 和分布式计算环境。为了完成这个项目,我选择使用 Amazon Web Services(AWS)Elastic MapReduce(EMR)。因为这是我第一次尝试使用 AWS 和 EMR,所以学习曲线很陡。此外,我发现网上提供的说明和指导只有一点点帮助。为了解决 AWS EMR 实施缺乏清晰、简明和有用的说明的问题,我提供了以下六个简单的步骤来启动和运行 AWS EMR。
通过遵循以下说明,您将能够:
- 设置一个运行 SparkML 的 EMR 集群,以及
- 创建一个 EMR 笔记本(Jupyter)来执行代码
请注意,您需要一个 AWS 帐户才能使用 AWS EMR。设置 AWS 账户的说明可以在这里找到。此外,请注意,使用 EMR 和其他 AWS 服务(例如,S3 存储和数据传输)需要付费。亚马逊的收费结构可以在这里找到。
步骤 1:创建集群
首先登录 AWS 控制台。登录后,搜索 EMR。
GIF demonstrating how to search for EMR on AWS Management Console
在 EMR 主页上,单击“创建集群”按钮,然后单击“转到高级选项”
GIF showing the Create Cluster button and Go To Advanced Options link
步骤 2:选择软件配置
在软件配置页面上,您需要调整默认设置。通过仅选择 Hadoop、JupyterHub、Spark 和 Livy 来调整设置,然后单击屏幕底部的“下一步”按钮。
GIF demonstrating how to change the default software configuration
步骤 3:选择硬件配置
现在是时候为您的集群选择硬件配置了。您可以选择主实例的类型以及核心和任务实例的类型和数量。集群的每个组件都是一个弹性计算云(EC2)实例。您可以在这里找到 EC2 实例类型的详细信息。除了选择实例类型,您还可以指定分配给集群中每个实例的弹性块存储(EBS)存储量。选择所需设置后,单击页面底部的“下一步”按钮。
Screenshot of AWS EMR hardware configuration
步骤 4:选择常规选项
在“常规选项”设置中,为集群指定一个名称,并选择要记录日志的简单存储服务(S3)存储桶。如果你不熟悉 S3, Amazon 提供了创建存储桶的说明。或者,您可以使用默认设置,这将为您创建一个新的存储桶。准备好后,单击“下一步”按钮。
Screenshot of AWS EMR general options set-up page
步骤 5:设置安全选项并创建集群
创建集群的最后一步是设置安全选项。对于个人使用,默认设置应该可以。如果您计划使用 SSH 来访问集群,那么您需要分配一个 EC2 密钥对。使用页面右下方蓝色框中的链接,可以获得创建密钥对的说明。创建后,可以使用页面顶部的下拉菜单将密钥对分配给集群。准备就绪后,单击“创建集群”
群集启动需要几分钟时间。在您等待的时候,继续执行步骤 6 以创建一个笔记本实例。
Screenshot of AWS EMR security options set-up page
步骤 6:创建笔记本
在步骤 5 中单击“创建集群”后,您将进入如下所示的屏幕。点击屏幕左侧菜单中的“笔记本”。在下一个屏幕上,单击“创建笔记本”按钮。
GIF demonstrating notebook creation steps
为您的笔记本命名,并选择您希望用来运行笔记本的群集。完成后,单击“创建笔记本”按钮。
GIF demonstrating final notebook creation steps
就是这样!一旦集群启动并运行,您就可以打开笔记本了。在第一个单元格中键入“spark ”,然后运行该单元格以启动 Spark 会话。现在,您可以使用 SparkML 在 AWS EMR 集群上运行机器学习算法了。完成后,记得终止集群以避免产生额外费用。
Pytorch 中的稀疏矩阵
预定义稀疏度
第 1 部分:CPU 运行时
这是分析 Pytorch 中稀疏矩阵及其密集矩阵的执行时间的系列文章的第 1 部分。第 1 部分处理 CPU 执行时间,而第 2 部分扩展到 GPU。在深入讨论之前,让我先简单介绍一下概念。
Pytorch 是用 Python 编程语言编写的深度学习库。深度学习是科学的一个分支,近年来由于它为自动驾驶汽车、语音识别等“智能”技术提供了动力,因此越来越受到重视。深度学习的核心是大量的矩阵乘法,这非常耗时,也是深度学习系统需要大量计算能力才能变好的主要原因。不足为奇的是,研究的一个关键领域是简化这些系统,以便它们可以快速部署。简化它们的一种方法是使矩阵稀疏,这样它们的大部分元素都是 0,在计算时可以忽略。例如,这里有一个稀疏矩阵,我们称之为 S :
您可能想知道这种矩阵在哪里以及如何出现。矩阵通常用于描述实体之间的交互。例如,S 的行可能表示不同的人,而列可能表示不同的地方。这些数字表明每个人在上周去过每个地方多少次。有几个 0 是可以解释的,因为每个人只去了一两个特定的地方。稀疏矩阵的密度是它的非零元素的分数,比如 s 中的 1/3,现在的问题是,有没有更好的方法来存储稀疏矩阵以避免所有的 0?
有几种稀疏格式,Pytorch 使用的一种叫做首席运营官坐标格式。它在一个稀疏矩阵中存储非零元素(nnz)的索引、值、大小和数量。下面是在 Pytorch 中构造 S 的一种方法(输出以粗体显示,注释以斜体显示):
S = torch.sparse_coo_tensor(indices = torch.tensor([[0,0,1,2],[2,3,0,3]]), values = torch.tensor([1,2,1,3]), size=[3,4])
#*indices has x and y values separately along the 2 rows*print(S)
**tensor(indices=tensor([[0, 0, 1, 2],
[2, 3, 0, 3]]),
values=tensor([1, 2, 1, 3]),
size=(3, 4), nnz=4, layout=torch.sparse_coo)**print(S.to_dense()) #*this shows S in the regular (dense) format*
**tensor([[0, 0, 1, 2],
[1, 0, 0, 0],
[0, 0, 0, 3]])**
Pytorch 有处理稀疏矩阵的torch.sparse
API。这包括一些与常规数学函数相同的函数,例如用于将稀疏矩阵与密集矩阵相乘的mm
:
D = torch.ones(3,4, dtype=torch.int64)torch.sparse.mm(S,D) #*sparse by dense multiplication*
**tensor([[3, 3],
[1, 1],
[3, 3]])**torch.mm(S.to_dense(),D) #*dense by dense multiplication*
**tensor([[3, 3],
[1, 1],
[3, 3]])**
现在我们来看这篇文章的要点。Pytorch 中使用稀疏矩阵和函数是否节省时间?换句话说,torch.sparse API 到底有多好?答案取决于 a)矩阵大小,和 b)密度。我用来测量运行时间的 CPU 是我的【2014 年年中 Macbook Pro,配有 2.2 GHz 英特尔酷睿 i7 处理器和 16 GB 内存。所以,让我们开始吧!
大小和密度都不同
对角矩阵是稀疏的,因为它只包含沿对角线的非零元素。密度将总是 1/ n ,其中 n 是行数(或列数)。这是我的两个实验案例:
- 稀疏:稀疏格式的对角矩阵乘以密集的方阵
- 密集:使用
to_dense()
将相同的对角矩阵转换为密集格式,然后用乘以相同的密集方阵
所有元素均取自随机正态分布。输入torch.randn
就可以得到这个。以下是 n 随 2 的幂变化的运行时间:
Left: Complete size range from 2²=4 to 2¹³=8192. Right: Zoomed in on the x-axis up to 2⁸=256
密集情况下的计算时间增长大约为 O( n )。这并不奇怪,因为矩阵乘法是 O( n )。计算稀疏情况下的增长顺序更加棘手,因为我们将 2 个矩阵乘以不同的元素增长顺序。每次 n 翻倍,密集矩阵的非零元素的数量翻两番,但是稀疏矩阵的非零元素的数量翻倍。这给出了在 O( n )和 O( n )之间的顺序的总计算时间。
从右边的图中,我们看到稀疏情况下的初始增长很慢。这是因为访问开销在实际计算中占主导地位。然而,超过 n=64(即密度≤ 1.5%)标志的是稀疏矩阵比密集矩阵计算速度更快的时候。
密度固定,尺寸变化
请记住,稀疏对角矩阵的密度随着大小的增长而下降,因为密度= 1/ n 。更公平的比较是保持密度不变。下面的图再次比较了两种情况,只是现在稀疏矩阵的密度固定为 1/8,即 12.5%。因此,例如, n =2 稀疏情况将有 2 x 2 /8 = 2 个元素。
Left: Complete size range from 2²=4 to 2¹³=8192. Right: Zoom in on the x-axis up to 2⁷=128
这一次,当 n 增加一倍时,稀疏矩阵和密集矩阵的元素数量都增加了四倍。首席运营官格式需要一些时间来访问基于独立索引-值对的元素。这就是为什么稀疏矩阵计算时间以大于 O(n ) 的速度增长,导致稀疏计算时间总是比密集计算时间差。
尺寸固定,密度变化
最后,让我们研究在保持大小固定在不同值时,密度对稀疏矩阵计算时间的影响。这里的伪代码是:
cases = [2**4=16, 2**7=128, 2**10=1024, 2**13=8192]
for n in cases:
Form a dense random square matrix of size n
for nnz = powers of 2 in range(1 to n**2):
Form sparse square matrix of size n with nnz non-zero values
Compute time to multiply these matrices
我将 n 固定在 4 种不同的情况下——16、128、1024 和 8192,并绘制了每种情况下的计算时间与密度的关系。密度可以用 nnz 除以 n 得到。我将之前的两个实验——对角线矩阵和 12.5%密度——标记为每个图中的垂直虚线。将 2 个密集矩阵相乘的时间是红色水平虚线。
n=16 的情况有点不同,因为它足够小,访问开销足以支配计算时间。对于其他 3 种情况,计算时间随着 nnz 加倍而加倍,即 O(nnz)。这并不奇怪,因为矩阵大小是相同的,所以唯一的增长来自 nnz。
主要结论是,2 个密集矩阵总是比稀疏和密集矩阵相乘更快,除非稀疏矩阵具有非常低的密度。‘很低’好像是 1.5%及以下。
所以你有它。在 Pytorch 中使用当前状态的稀疏库并不会带来太多好处,除非您正在处理非常稀疏的情况(比如大小大于 100 的对角矩阵)。我的研究涉及神经网络中预定义的稀疏性( IEEE , arXiv ),其中权重矩阵是稀疏的,输入矩阵是密集的。然而,尽管预定义的稀疏度在低至 20%的密度下给出了有希望的结果,但在 1.5%及以下的密度下性能确实会下降。所以不幸的是 Pytorch 稀疏的库目前并不适合。话虽如此,Pytorch sparse API 仍处于试验阶段,正在积极开发中,因此希望新的 pull 请求能够提高稀疏库的性能。
本文的代码和图片可以在 Github 这里找到。
附录:存储稀疏矩阵
Pytorch 中的张量可以使用torch.save()
保存。结果文件的大小是单个元素的大小乘以元素的数量。张量的dtype
给出了单个元素的位数。例如,数据类型为 float32 的密集 1000x1000 矩阵的大小为(32 位 x 1000 x 1000) = 4 MB。(回想一下,8 位=1 字节)
不幸的是,稀疏张量不支持.save()
特性。有两种方法可以保存它们——(a)转换为密集的并存储它们,或者(b)将indices()
、values()
和size()
存储在单独的文件中,并从这些文件中重建稀疏张量。例如,假设spmat
是一个大小为 1000x1000 的稀疏对角矩阵,即它有 1000 个非零元素。假设数据类型为 float32 。使用(a),存储的矩阵的文件大小= (32 位 x 1000 x 1000) = 4 MB。使用(b),indices()
是数据类型 int64 的整数,有 2000 个索引(1000 个非零元素的每一个有 1 行 1 列)。这 1000 个非零values()
都是浮动 32 。size()
是一种叫做torch.Size
的特殊数据类型,它是一个由两个整数组成的元组。因此,总文件大小大约为=(64 x 2000)+(32 x 1000)+(64 x 2)= 20.2 KB。这远远小于 4 MB。更一般地,文件大小对于(a)增长为 O( n ),对于(b)增长为 O(nnz)。但是你每次从(b)加载的时候都需要重构稀疏张量。
Sourya Dey 正在南加州大学攻读博士学位。他的研究涉及探索深度学习中的复杂性降低。你可以在他的网站上读到更多关于他的信息。
Pytorch 中的稀疏矩阵
预定义稀疏度
第 2 部分:GPU 运行时
在第 1 部分的中,我分析了 Pytorch 中稀疏矩阵乘法在 CPU 上的执行时间。这里有一个快速回顾:
- 稀疏矩阵中有很多零,所以可以用不同于常规(密集)矩阵的方式存储和操作
- Pytorch 是一个用于深度学习的 Python 库,相当容易使用,但给了用户很多控制权。Pytorch 以坐标格式存储稀疏矩阵,并有一个名为
torch.sparse
的独立 API 来处理它们。 - 我使用的 CPU 是我自己的 Macbook Pro——2014 年年中,配有 2.2 GHz 英特尔酷睿 i7 处理器和 16 GB 内存。
- 第 1 部分的主要发现是: 2 密集矩阵总是比稀疏矩阵和密集矩阵增长得更快,除非稀疏矩阵的密度非常低。‘很低’好像是 1.5%及以下。
Pytorch 和 Keras 等深度学习框架的一个关键卖点是它们在GPU上的可部署性,这大大加快了计算速度。不过这也有它的局限性, Pytorch 只支持兼容 CUDA 的 GPU。CUDA 是由 Nvidia 创建的计算框架,它利用兼容 GPU 上的并行处理。并非所有的图形处理器都是如此,例如,我的 Macbook 配备了英特尔 Iris Pro 显卡,不幸的是,它与 CUDA 不兼容。所以我经常使用亚马逊网络服务(AWS)——它提供云计算资源。对于本文中的实验,我使用了一个 AWSp 3.2x large实例,它有 1 个 Nvidia V100 GPU,非常强大(例如,在 CIFAR-100 上训练 11 层深度 CNN 在 p3 上比我的 Macbook 快 100 倍)。
我在 p3 实例上重复了第 1 部分中相同的 3 个实验,做了一些小的改动,并将这些结果与 CPU 结果进行了比较。让我们开始吧!
大小和密度都不同
和第 1 部分一样,所有矩阵都是正方形的。稀疏情况是对角矩阵乘以密集矩阵。这意味着稀疏矩阵的密度是 1/n(n= #行)。密集的情况是两个密集矩阵相乘,但是,两个密集矩阵都是随机生成的,不像第 1 部分,其中一个密集矩阵只是稀疏矩阵的“密集”版本。这不会在质量上影响结果;我这样做只是因为它使计时的代码更容易设置(稍后会详细介绍)。结果如下:
访问开销在插图中占主导地位,插图在放大的 y 轴上显示了较低的 n 值。除此之外,趋势很明显——a)GPU 比 CPU 快得多,b)稀疏矩阵的执行速度比密集矩阵快,因为对角矩阵的密度极低。
密度固定,尺寸变化
稀疏矩阵现在具有 12.5%的固定密度,而密集情况与之前相同。结果如下:
虽然 GPU 的整体性能再次提高,但这一次稀疏情况比密集情况需要更多的时间。这将我们带到最后一个实验——密度对稀疏矩阵计算时间的影响。作为一个旁注,注意 GPU 上的稀疏矩阵似乎有非常大的开销,这在低 n 值时完全占主导地位。我不知道这是 GPU 的限制还是首席运营官格式造成的。
尺寸固定,密度变化
虽然低密度稀疏对角矩阵的运算速度比密集矩阵快,但对于更高密度的稀疏矩阵,这一结果是相反的。该实验改变给定大小的稀疏矩阵的密度水平,并测量执行时间。我使用的尺寸(T4 n T5 的值)和第一部分有点不同。这里是 64,256,1024 和 4096。密度是 2 的幂。与第 1 部分一样,红色虚线显示了密集情况下的计算时间,而黑色虚线显示了前两个实验的密度——对角矩阵和 12.5%。
在接近开始时, n = 64 和 256 的曲线表现有点奇怪,显示出非常低的密度值的线性增长。我能想到的唯一解释就是,当非零元素的个数非常低时,存储稀疏矩阵的内存的访问方式是不同的。似乎访问开销直到达到一个可观的密度才开始出现,然后它们在一段时间内占主导地位(平坦部分),然后让位于计算时间(指数部分)。
然而,我们再次注意到使用稀疏矩阵代替密集矩阵并没有实际的好处。对于 n = 64、256 和 1024,稀疏乘法总是比密集耗时,而对于 n = 4096,稀疏乘法比密集耗时只需要 1.5%左右的密度。
总之,这篇文章有两个主要的收获。首先,GPU(至少是现代的强大的)明显比 CPU快,如果可能的话应该一直用于线性代数。我说“如果可能”是因为 AWS 实例不是免费的。然而,如果你是一名受资助的研究人员,如果你要做线性代数密集型任务,如深度学习,说服你的主管为 AWS GPU 实例付费是非常值得的。其次,第 1 部分的主要发现也在 GPU 上得以恢复,即 2 个密集矩阵总是比稀疏和密集矩阵的乘法速度更快,除非稀疏矩阵的密度非常低(< 1.5%) 。
因此,在 Pytorch 中使用稀疏矩阵似乎没有太多好处。就我对预定义稀疏度的研究而言,1.5%及以下的密度太低,没有任何用处。其他 Pytorch 用户已经表达了稀疏矩阵的这些问题,所以这里希望开发人员能想出更有效的方法来处理稀疏矩阵。
测量执行时间
本文的代码和图片可以在 Github 这里找到。如果您浏览代码,您会发现我使用了 Python 提供的timeit
模块来测量代码片段的执行时间。为了获得最精确的时间测量,最好将setup
参数中的所有变量都设置为timeit.Timer()
。这与预先定义变量,然后让timeit.Timer()
访问全局名称空间相反,这很浪费时间。
作为timeit
模块的替代,Jupyter 笔记本提供了%timeit
魔法命令,这是我在第 1 部分中使用的。这两种方法的结果是一样的(最好是这样!),但是%timeit
更容易使用,因为它会自动计算多次运行的平均值。然而,在 AWS 上运行 Jupyter 笔记本,虽然可能,却很麻烦,而且很大程度上取决于服务器速度。因此,在第 2 部分中,我坚持使用传统的 Python 脚本。
暂时就这样吧!
Sourya Dey 正在南加州大学攻读博士学位。他的研究涉及探索深度学习中的复杂性降低。你可以在他的网站上了解更多关于他的信息。
面向外行的空间数据科学
从位置信息中提取洞察力的艺术和科学
一位老朋友上周发来一个很棒的问题:“对于我们这些门外汉来说,什么是空间数据科学(相对于常规数据科学)?为什么它意义重大?”
什么是空间数据科学?
空间数据科学是使用算法和分析技术从空间数据中提取洞察力的实践。
空间数据是由数据描述的观测值的相对位置的任何数据,并且可以在分析中用作维度。换句话说,空间数据有关于的信息,其中每个单独的数据是——因此,在那里观察是相互关联的。
A raster “hillshade” image of Colorado. Here think of each pixel as the observation — a recording of that point on Earth.
最直观的例子就是地理空间数据,它承载着地球上哪里发生事情的信息。地理空间数据可以描述自然或人类主题,如地形、政治边界、城市系统、天气和气候模式、道路网络、物种分布、消费习惯、航运和物流系统、人口统计等。
空间维度通常是对纬度(y 坐标)和经度(x 坐标)的测量,有时还有高度(z 坐标),这可以将一个点精确地放置在地球表面上、上方或下方。空间数据科学家经常使用 GIS 技能和工具(地理信息系统/科学)来操纵、分析和可视化地理空间数据。
值得注意的是,许多空间分析技术实际上与比例无关,因此我们可以将相同的算法应用于地球地图,就像我们可以应用于细胞地图或宇宙地图一样。空间数据科学技术可以应用于具有空间元素的更抽象的问题——例如,通过计算它们在一起出现的频率来分析单词的关联程度。
栅格和矢量
空间数据通常分为两类:栅格数据和矢量数据。两者都是描述空间和表现特征的方式,但是它们的工作方式完全不同。
栅格数据
光栅是一个由规则大小的像素组成的网格。通过给网格中的每个单元格分配一个值(或几个值),图像可以用数字描述为多维数组。
例如,以如下所示的 3x3 网格为例:
A 3x3 raster grid.
如果 1 表示黑色,0 表示白色,我们可以用数字表示如下:
img = [[ 1, 0, 1 ],
[ 0, 1, 0 ],
[ 1, 0, 1 ]]
栅格像元中的数字可能意味着很多事情-特定点的陆地高度或海洋深度、该点的冰或雪量、该像素内的居住人数等等。此外,可见光谱中的几乎任何颜色都可以通过代表红、绿、蓝(RGB)强度的三个数字的组合来描述——卫星图像是栅格数据结构。GeoTiff、jpg、png 和位图文件包含栅格数据。
A raster image of population in Africa, from http://www.ncgia.ucsb.edu/pubs/gdp/pop.html.
矢量数据
向量数据更抽象一些。在矢量数据集中,特征是数据集中的独立单元,每个特征通常代表一个点、线或*多边形。*这些特征用数学方法表示,通常用数字表示点的坐标或几何图形的顶点(角)。
Vector features from Saylor Academy.
点、线、多边形
举个简单的例子,下面是每种类型特征的简单数字表示:
point = [ 45.841616, 6.212074 ]line = [[ -0.131838, 51.52241 ],
[ -3.142085, 51.50190 ],
[ -3.175046, 55.96150 ]]polygon = [[ -43.06640, 17.47643 ],
[ -46.40625, 10.83330 ],
[ -37.26562, 11.52308 ],
[ -43.06640, 17.47643 ]]
// ^^ The first and last coordinate are the same
矢量要素通常包含一些描述要素的元数据,例如道路名称或一个州的人口数量。要素的这些额外的非空间元数据通常被称为“属性”,通常在“属性表”中表示。空间数据科学家通常会在分析中将空间维度(点的坐标或线和面的坐标数组)与非空间维度结合起来。乔森和。shp 文件通常包含矢量数据。
为什么这与常规的数据科学不同?
简短的回答是,它不是:空间数据科学是数据科学中的一个学科。但是,空间数据有一些特点,需要特殊处理。从编程/数据库的角度来看,数据的存储和处理方式是如此,从算法的角度来看,数据的分析方式也是如此。这意味着空间数据科学家必须学习一些概念-主要来自几何学,如在平面上表示 3D 形状-其他数据科学家可能永远不会处理这些概念。
空间数据科学工具
空间数据科学家试图理解这些空间数据集,以更好地理解他们正在研究的系统或现象。一些不可思议的(通常是免费的)软件工具使这成为可能。大多数编程语言,如 Python 、 R 和 Javascript 都有令人惊叹的空间分析库,如 geopandas 和 turf.js ,桌面程序如 QGIS 让不太懂技术的人也能访问可视化和分析空间数据。还有强大的在线工具,如 Mapbox 、 Carto 和 Google BigQuery 来帮助应对这些分析和可视化挑战。像 Leaflet 和 Mapbox GL JS 这样的 JavaScript 库使 web 开发人员能够在浏览器中创建交互式地图。
几个例子
空间数据科学家的任务可能是分析空间分布-查看点是聚集在一起、分散还是随机放置-例如,计算出建造新机场或零售中心的最佳位置,或者了解暴力或犯罪的模式。
Clustering ACLED conflict events in Yemen from a project I did at UCL.
这可能需要分析一段时间内的趋势——通过观察某个地区的投票结果如何演变,或者一个国家不同地区对某个问题的看法如何变化。
也许分析师正在分析卫星图像来绘制地图,以帮助更有效地提供紧急服务,或者计算出一个新的潜在建筑工地或自行车道有多阴暗。这可能意味着在给定当前交通状况的情况下,计算从 A 到 B 的最有效路线。尽管这个领域很小,但空间数据科学在几乎每个部门和领域都有广泛的应用。
通过使用这种专门的统计工具包来分析空间数据,空间数据科学家能够更好地理解空间关系,并可能找出为什么事情会在哪里发生,以及预测事情接下来会在哪里发生。
贷款偿还的预测空间建模
非洲有超过 3 . 5 亿人用不上电。菲尼克斯国际正在通过建造负担得起的租赁到拥有的太阳能家庭系统来解决这个问题。他们提供预先的,灵活的融资。基本的 ReadyPay 太阳能系统为照明和手机充电提供电力,并且可以升级为收音机、电视和炉灶供电。客户支付首付款,将设备带回家,并获得一周的免费电力。然后,他们以小额定期付款的方式还清贷款的剩余部分。如果客户选择不按时付款,系统将被远程锁定。一旦收到付款,客户就会收到解锁设备的代码。通过手机使用移动货币进行支付。一旦客户付清全部设备款项(通常为 24-30 个月),他们就拥有了设备,并且设备将永久解锁。
菲尼克斯国际拥有超过 50 万客户(250 万受益人),并发现贷款偿还因地理位置而异。在我作为Insight Data Science 的研究员期间,我咨询了菲尼克斯,以了解哪些地理位置可以支持他们的融资模式。他们要求:
- 使用公开可用数据预测特定地理位置还款的模型
- 预测还款模式的特征列表
他们将使用这种模式来帮助确定新的扩张领域的优先顺序。他们还将使用特征列表来通知他们的位置筛选。功能列表还将用于确定未来数据收集的优先级,并指导模型扩展。
为了预测还款,我重点关注了未付款率——在拥有设备的最初几个月内,由于未付款,设备被锁定的时间比例。通过专注于不支付率,我可以使用机器学习回归来解决提出的问题。
数据源
探究的数据类别:
气候
社会经济指标(和其他人口统计数据)
教育的
电力基础设施
首要任务是找到一些与贷款偿还预测相关的数据源。我收集了气候数据,因为它直接影响太阳能电池板的能源生产。更重要的是,气候对农业有重大影响,而农业是菲尼克斯客户的主要收入来源。我还从 Twaweza (通过人道主义数据交换)收集了各种社会经济、人口和教育数据。丰富的数据集包括资产所有权(电话、收音机、电视、自行车、牛)、水源、户主信息(教育、性别)以及每天吃饭的次数等信息。最后,脸书研究(通过 Energydata.info )有一个关于我探索的电网基础设施的优秀数据集。
特征工程
为了使回归模型适合不支付率,我需要构建一个表,其中每一行代表一个位置,每一列代表该位置的特定特性的值。
菲尼克斯以个人账户数据的形式向我提供了包括面积在内的数据。为了将帐户级别的数据转换为空间数据,我采用了一个区内每个帐户的平均不支付率。我放弃了不到十个客户的八个子账户,因为当账户数量如此之低时,不支付率是不可靠的。
为了准备与其他数据集合并的信息,我需要确定每个子县的纬度和经度。我使用了地理编码器来完成这个任务。
来自 WorldClim 的 气候数据以完全不同的方式呈现。WorldClim 提供了一百多个 geotiff 文件,每个文件都将特定变量的值编码为映射到地球上特定地理位置的像素。每个文件都编码了一个特定气候变量的信息,例如一月份的最低温度、七月份的降雨量、十一月份的太阳辐射以及生物学上有意义的变量,例如最干燥季度的降雨量。
我使用 rasterio 打开 WorldClim 文件。然后,对于菲尼克斯数据集中的每个子县,我使用 rasterio 查找该子县的经度和纬度的对应像素。然后,我可以将编码的气候信息与我的县级未支付率数据合并。
****通过人道主义数据交换获得的社会经济、人口和教育数据是来自个体户主的调查数据。我对每个县的数据进行了平均。我对分类数据进行了一次性编码,使其适合机器学习。我还必须删除一些会混淆机器学习模型的无关数据,例如子县的标识号。最后,我构建了一些新的特性,比如男性和女性的数量、总人口以及户主的比例。
由脸书研究所免费提供的 电力基础设施(EI)位置数据,仅仅是纬度和经度的列表。靠近电力基础设施表明替代能源选择、现有能源需求,或许还有区域社会经济。我使用 geopy 计算每个子区到最近 EI 的距离、到最近 EI 的平均距离以及十公里内 EI 位置的数量。
贷款偿还预测
构建要素并合并所有数据集后,我使用它们来预测拒付率。我选择了一个随机森林回归模型,因为它可以处理特征空间中的强非线性,对过度拟合具有鲁棒性,可以区分大量特征,并提供关于哪些因素有意义的推断。
The Random Forest model performs significantly better than the baseline model. The baseline is the naive model that does not discriminate loan repayment geographically.
为了量化模型的性能,我与基准模型进行了比较,基准模型是一个简单的模型,它的预测仅仅是每个案例中的平均不支付率。我的随机森林模型预测的不支付率比基线模型的均方差提高了 20%!菲尼克斯国际公司计划使用这种模式来指导其服务在新地点的推广。由于贷款偿还率是菲尼克斯收入的直接来源,这一改进的预测提供了重要的 价值。****
还款率的预测功能
最重要的特征是手机拥有率。一般来说,手机拥有率是一个地区社会经济的标志。更多的手机拥有者与更高的还款率相关。更多的手机意味着更多的电力需求,这种电力需求增加了支付太阳能账单的动力。事实上,菲尼克斯太阳能系统的主要用途之一就是手机充电。更多的手机用户也意味着更多的人可以打电话。手机也是主要的支付方式,所以更多的手机拥有者意味着更多的支付机制。
The most important characteristics for the prediction of loan repayment. The features are color-coded by category.
顶级要素列表包含社会经济、气候、教育、人口统计和电力基础设施数据。社会经济因素,如资产所有权和每天吃饭的次数有很强的预测性。与农业相关的气候因素也是如此,如生长季节中期的降水和生物气候变量。教育也是一个重要因素,尤其是母亲的学校教育和英语水平。
结论
空间模型可以改善贷款偿还的预测。菲尼克斯国际将使用我建立的模型来帮助其未来的扩张,因为它的表现明显优于基线。
The model will continue to improve with the addition of more data.
我已经列出了最关键的特性和数据类型。基本因素是社会经济、气候、教育和人口统计。手机和其他资产的所有权是预测信息。与农业相关的天气和教育指标也是如此。
未来方向
通过添加新数据,可以立即改进模型;随着测试集性能显著低于训练集性能,学习曲线仍在增长。只有气候的模型可以用免费获得的数据立即扩展。模型的各个方面将需要新的数据集,但更多的数据可以从我使用的相同资源中获得。
随着新数据源的加入,该模型可能会得到改进。经济数据,尤其是与农业相关的数据,前景看好。手机使用数据也是如此,如无线网络可用性和移动货币供应商的位置。
空间建模花絮:蜂巢还是渔网?
为什么我们在 场所 都喜欢六边形网格!
Source: Wikipedia
介绍
如果你是一个像优步一样的两度市场,你要迎合数百万请求搭车的用户,通过你的司机伙伴接受并满足这些请求。对于像 Swiggy 这样的三级市场,还添加了另一个静态组件(如餐馆或商店),交付合作伙伴在那里收取订单。
借用一句名言,“任何事情都发生在某个地方”——所有这些描述的事件和行为都发生在一个特定的地点!
通常,公司最终没有利用数据中的纬度/经度部分,而是在城市层面进行分析。但是,城市太大,地理位置多样,参数变化太大!
区域级别的多边形更加实用,但仍然很宽泛。它们没有统一的形状或大小,而且经常变化。即使是由运营团队的本地知识绘制的区域或集群也需要更新,并且具有任意的边。
Source: Standford Gaming Principles
理解您的空间数据并获得精确的见解需要这些分析变得更加精细和统一。
网格系统将这种细粒度带到了桌面上。它能出色地将你所有的纬度放进“细胞”里。这些像元还可以进行聚类以表示特定的邻域或区域,并且可以在不同的级别进行聚合。
Source: Uber
因此,该系统对于处理大型空间数据集以匹配城市中分散的供需变得至关重要。
网格是什么意思?
在空间数据科学中,我们使用规则多边形网格在表面上重复,边对边地覆盖任何空间,没有重叠和间隙——这种现象称为 镶嵌 **。**每个单元可分配一个唯一的 id 用于空间索引(聚合该单元内的点)。
已经提出了各种不同形状的网格,包括正方形、矩形、三角形、六边形或菱形。全球网格系统覆盖整个地球表面。
如果你是一家超本地化、随需应变的公司,网格可以小到 1 平方英尺。km 对于运行与位置相关的模型和在实时运行的模型非常有用。 示例包括高需求地区的激增定价、低需求地区的促销&地面送货人员的配送模式。
有哪些不同的网格类型?
只有三种类型的网格可以镶嵌:正方形、等边三角形和六边形。
Source: Wikipedia
1.方形网格:
方形格网最常见的应用出现在栅格数据集和地理哈希中。就本文而言,我们将重点关注 geohashes。
Geohash 是一种分层数据结构,用于将 2D 空间点(lat & long)转换为字母和数字的短字符串。他们将世界划分为 4 行 8 列 32 个单元的网格。
你可以把每个单元格分成 32 个单元格。因此,geohash 的字符串越长,精确度就越高!如果 geohashes 有一个共同的前缀,您也可以轻松识别它们是否靠得很近。所以,共同前缀越长,它们越接近。
例如,丹麦日德兰半岛顶端附近的坐标对 (57.64911,10.40744) 产生了略短的哈希u 4 prudqqvj。[1]
2.三角形网格:
三角形网格不是很常用。此外,他们不熟悉,其中一个原因是他们周长很大,面积很小,这意味着很难在地图上把他们拼凑在一起。
另一个原因是,每个三角形只连接到三个相邻的三角形,这限制了移动和建立连接的选项数量。(查看下图)
此外,对于六边形和正方形来说,总是有两个面相互平行,而对于三角形来说,有两个方向的直线平行于运动轴的中心。因此,在某种程度上,不存在完全对称。【2】
3.六边形网格:
除了看起来吸引人之外,六边形比 geohashes 更对称。它们在形状上非常接近圆形,以提供更精确的采样。【3】
因此,这个系统越来越多地被像优步这样的公司所采用。
Source: Wikipedia
**有趣的事实:**一个六边形网格和三角形网格是彼此的对偶——在每个六边形的中心放一个点&将它们连接到所有相邻的六边形,你得到一个三角形网格,反之亦然!【4】
在六边形网格的顶点上移动相当于在三角形网格的空间中玩耍。另一方面,正方形网格本身就是一个对偶。
为什么是六边形?
在场所,我们从客户那里得到很多的一个问题是,“为什么我们使用 hexbins 而不是 geohashes?”
嗯,选择取决于您的具体用例,无论您使用什么,您都必须做出一些权衡。所以,我们来取一些参数,深潜一下。
与最近像元的距离:
Source: Uber
此图显示了三角形、正方形和六边形的中心到其邻居的距离。
三角形有三种距离(穿过边、顶点和穿过边的中心),正方形有两种距离(穿过边和对角线),六边形只有一种距离——这是三角形不受欢迎的另一个原因。
六边形的这一属性使得执行分析变得非常容易,并且当您的分析包括连通性或运动方面时,这是首选。【4】
六边形中的所有邻居围绕它形成一个半径相等的环。 kRing 函数提供原点索引“k”距离内的格网单元。在下图中,这是阴影六边形和正方形的第次克里金。
Source: ESRI
在曲面上安装:
六边形是镶嵌中填充圆形并减少边缘效应的最密集方式。(圆的周长与面积之比最小,但不能形成连续的网格)。
多边形与圆越相似,靠近边界区域的点就越靠近中心。因此,与等面积的正方形或三角形相比,六边形内的任何一点都更靠近其中心。
现在,当大面积开始起作用并且地球的曲率是重要的考虑因素时,六边形因此更适合曲率并且遭受较少的扭曲。【5】
数据中的显式模式:
六边形允许数据中图案的任何曲率容易且明确地显示,因为它们分解了线条。
对于像正方形和长方形这样的线性图形,这就变得棘手了。这些形状将我们的注意力吸引到阻碍数据中存在的模式的直的、连续的和平行的线上。请参考下图。【6】
Source: GIS Exchange
为什么是 geohashes?
这让我想到了我的下一个问题,“什么时候人们会使用渔网或方形网格?”
细胞的聚集/分裂:
不同类型的模型需要不同的粒度,这就是聚合和划分变得重要的地方。
如果需要增加一个正方形网格的空间分辨率,只需要把它分成 4 份就可以了。类似地,要聚合,您需要将四个网格合并为一个。
对于六边形,聚合和分割在不同的比例下是不一致的,如下图所示。更精细的单元仅大致包含在父单元中。【7】
Source: Wikipedia
对于等级分析,正方形优于六边形。组合方形网格相当简单。合并构建在同一模板上的多个格网不需要任何空间操作,您可以使用矩阵代数。
非常直观和熟悉:
我们也用“方格来思考。上、下、左、右简单易懂。我们在正方形和长方形上建造了城市和文明。由于我们的主坐标系是平方的,人们发现很难在其他系统上工作。
它们有时也用于连通性分析,因为它们有八个邻居(包括对角线)。
一些现实生活中的例子
六边形在自然界中广泛存在。例如蜂窝、石墨、苯、硅烯等。中国跳棋是在一个六边形的格子上玩的,国际象棋的几个变种也被发明在一个六边形的棋盘上。
六边形网格是许多战争游戏发行商和一些其他游戏(如卡坦的定居者)的一个显著特征!【8】
附:如果你对此感兴趣,请查看我们的 博客 获取更多与地图、位置智能和空间建模相关的文章。
P.P.S .我们在locale . ai招聘!联系我aditi @ locale . ai!
参考资料:
如果您希望阅读更多内容,请查看链接进行深入研究:
**【2】https://board games . stack exchange . com/questions/633/why-are-there-less-board-games-with-a-a-triangular-grid 【3】https://www.redblobgames.com/grids/hexagons/ 【4】**
带 Kaldi 的扬声器二进制化
随着语音生物识别和语音识别系统的兴起,处理多个说话者的音频的能力变得至关重要。本文是使用 Kaldi X-Vectors(一种最先进的技术)完成这一过程的基础教程。
在大多数真实世界的场景中,语音不会出现在只有一个说话者的明确定义的音频片段中。在我们的算法需要处理的大多数对话中,人们会相互打断,而切断句子之间的音频将不是一项简单的任务。
除此之外,在许多应用中,我们希望在一次对话中识别多个发言者,例如在编写会议协议时。对于这样的场合,识别不同的说话者并连接同一说话者下的不同句子是一项关键任务。
说话人二进制化 就是针对这些问题的解决方案。通过这一过程,我们可以根据说话者的身份将输入音频分割成片段。这个问题可谓是谁在什么时候讲的?"在一段音频中。
Attributing different sentences to different people is a crucial part of understanding a conversation. Photo by rawpixel on Unsplash
历史
第一个基于 ML 的扬声器二进制化工作开始于 2006 年左右,但直到 2012 年左右才开始有重大改进( Xavier,2012 ),当时这被认为是一项极其困难的任务。那时的大多数方法都是基于 GMM 或嗯的(比如 JFA ),不涉及任何神经网络。
一个真正的重大突破发生在 LIUM 的发布上,这是一个用 Java 编写的致力于说话者二进制化的开源软件。第一次有了一种自由分布的算法,它可以以合理的精度执行这项任务。LIUM 核心中的算法是一种复杂的机制,它将 GMM 和 I-Vectors 结合在一起,这种方法曾经在说话人识别任务中取得了最先进的结果。
The entire process in the LIUM toolkit. An repetitive multi-part process with a lot of combined models.
今天,这种复杂的多部分算法系统正在被许多不同领域的神经网络所取代,如图像分割甚至语音识别。
x 向量
最近的一项突破是由 D. Snyder,D. Garcia-Romero,D. Povey 和 S. Khudanpur 在一篇名为“ 用于文本无关说话人验证的深度神经网络嵌入 ”的文章中发表的,该文章提出了一个模型,该模型后来被命名为“X-Vectors”。
A diagram of the proposed neural network, The different parts of the networks are highlighted on the right. From The original article.
在该方法中,网络的输入是以 MFCC 形式的原始音频。这些特征被输入到一个神经网络中,该网络可以分为四个部分:
- 帧级层——这些层本质上是一个 TDNN (时间延迟神经网络)。TDNN 是在神经网络日益普及之前的 90 年代发明的一种架构,然后在 2015 年被“重新发现”为语音识别系统的一个关键部分。这个网络本质上是一个全连接的神经网络,它考虑了样本的时间滑动窗口。它被认为比 LSTM 快得多。
- 统计池 -因为每一帧给我们一个向量,我们需要以某种方式对这些向量求和。在这个实现中,我们取所有向量的平均值和标准偏差,并将它们连接成一个代表整个片段的向量。
- 全连接层——向量被送入两个全连接层(分别有 512 和 300 个神经元),我们稍后会用到。第二层将具有 ReLU 非线性。
- Softmax 分类器 -一个简单的 Softmax 分类器,它在 ReLU 之后获取输出,并将片段分类到不同的说话者之一。
A visualization of a TDNN, The first part of the X-Vectors System.
X-Vectors 的真正力量不在于(仅仅)对不同的说话者进行分类,还在于它使用两个完全连接的层作为整个片段的嵌入式表示。在文章中,他们使用这些表示对一个完全不同于他们训练数据集的数据集进行分类。他们首先为每个新的音频样本创建嵌入,然后用 PLDA 后端相似性度量对它们进行分类。
X 向量二分化
在我们理解了我们可以使用这些嵌入作为每个音频样本中的说话者的表示之后,我们可以看到该表示可以如何用于分割音频样本的子部分。那种方法在“ 二进制化很难:首届迪哈德挑战赛 中的一些经验教训”一文中有所描述。DIHARD 挑战特别困难,因为它包含了从电视节目到电话到儿童对话的 10 个不同的音频域,此外还有 2 个域只出现在验证集中。
在文章中,他们描述了许多实践,这些实践将他们的二进制化算法带到了当前的艺术水平。虽然使用不同的技术(如变分贝叶斯)极大地提高了模型的准确性,但它本质上是基于相同的 X 向量嵌入和 PLDA 后端。
From the DIHARD Callenge article. You can see the major improvements of using X-Vectors. Previous works are in blue and the state of the art results are in red.
如何与卡尔迪合作
首先,如果你之前没有用过 Kaldi,我强烈推荐你阅读我的第一篇关于使用 Kaldi 的文章。没有语音识别系统的经验,很难开始使用该系统。
其次,你不需要重新训练 X-Vectors 网络或 PLDA 后端,你可以直接从官方网站下载。如果你仍然想从头开始进行完整的训练,你可以遵循卡尔迪项目中的 call_home_v2 recipe 。
现在你有了一个模型,不管你是创建了这个模型还是对它进行了预训练,我都将经历二进制化过程的不同部分。本演练改编自 GitHub 上的不同评论,主要是大卫的评论和文档。
准备数据
你首先需要有一个普通的 wav.scp 和 segments 文件,方式与 ASR 项目中的相同。如果您想要一种简单的方法来创建这样的文件,您可以始终使用compute _ VAD _ decision . sh脚本,然后在输出中使用VAD _ to _ segments . sh脚本。如果您不想分割音频,只需从头到尾将片段映射到话语。
接下来,您需要创建一个 utt2spk 文件,该文件将片段映射到话语。您可以在 Linux 中通过运行命令awk ‘{$1, $2}’ segments > utt2spk
简单地做到这一点。接下来,要创建其他必要的文件,只需将所有文件放在一个文件夹中,然后运行 fix_data_dir.sh 脚本。
创建特征
现在,您需要为音频创建一些特征,这些特征稍后将成为 X 向量提取器的输入。
我们将以与 ASR 项目中相同的方式开始创建 MFCC&CMVN。请注意,您需要有一个与您接受的培训相匹配的 mfcc.conf 文件。如果您使用预训练模型,请使用这些文件。
对于 MFCC 创建,运行以下命令:
**steps/make_mfcc.sh --mfcc-config conf/mfcc.conf --nj 60 \
--cmd "$train_cmd_intel" --write-utt2num-frames true \
**$data_dir** exp/make_mfcc **$mfccdir****
然后对 CMVN 运行这个命令:
**local/nnet3/xvector/prepare_feats.sh — nj 60 — cmd \ "$train_cmd_intel" **$data_dir** **$cmn_dir** **$cmn_dir****
完成数据后,使用utils/fix_data_dir.sh **$data_dir**
修复数据目录,然后使用cp **$data_dir**/segments **$cmn_dir**/
将段文件移动到 CMVN 目录,之后使用utils/fix_data_dir.sh **$cmn_dir**
再次修复 CMVN 目录。
创建 X 向量
下一步是为你的数据创建 X 向量。我这里指的是导出文件夹,其中有 X-Vectors 作为**$nnet_dir**
,如果你是从 Kaldi 网站下载的,使用路径"exp/X Vectors _ sre _ combined"然后运行该命令:
**diarization/nnet3/xvector/extract_xvectors.sh --cmd \ "$train_cmd_intel --mem 5G" \
--nj 60 --window 1.5 --period 0.75 --apply-cmn false \
--min-segment 0.5 **$nnet_dir** \
**$cmn_dir** **$nnet_dir**/exp/xvectors**
注意,在这个例子中,我们使用 1.5 秒的窗口,每个窗口有 0.75 秒的偏移。降低偏移可能有助于捕捉更多细节。
用 PLDA 评分 X 向量
现在你需要对 X 向量和 PLDA 后端之间的成对相似性进行评分。使用以下命令执行此操作:
**diarization/nnet3/xvector/score_plda.sh \
--cmd "$train_cmd_intel --mem 4G" \
--target-energy 0.9 --nj 20 **$nnet_dir**/xvectors_sre_combined/ \
**$nnet_dir**/xvectors **$nnet_dir**/xvectors/plda_scores**
二化
最后一部分是对你创造的 PLDA 分数进行聚类。幸运的是,这也有一个脚本。但是,你可以通过两种方式做到这一点,有监督的方式和无监督的方式。
在监督的方式下,你需要说出每句话中有多少人说话。当你在打一个只有两个发言人的电话时,或者在开一个有已知数量发言人的会议时,这尤其容易。要以监督的方式对分数进行聚类,您首先需要创建一个文件,将来自 wav.scp 文件的话语映射到该话语中的发言者数量。该文件应该被命名为 reco2num_spk ,看起来应该像这样:
**rec1 2
rec2 2
rec3 3
rec4 1**
一个重要的注意事项是,您需要根据说话者的数量来映射每个话语,而不是每个片段。创建了 reco2num_spk 文件后,您可以运行以下命令:
**diarization/cluster.sh --cmd "$train_cmd_intel --mem 4G" --nj 20 \
--reco2num-spk **$data_dir**/reco2num_spk \
**$nnet_dir**/xvectors/plda_scores \
**$nnet_dir**/xvectors/plda_scores_speakers**
如果你不知道每句话有多少个说话者,你总是可以以一种无人监督的方式运行聚类,并尝试在脚本中调整阈值。一个好的起始值是 0.5。要以无人监督的方式进行聚类,请使用相同的脚本,但使用以下方式:
**diarization/cluster.sh --cmd "$train_cmd_intel --mem 4G" --nj 40 \
--threshold **$threshold** \
**$nnet_dir**/xvectors/plda_scores \
**$nnet_dir**/xvectors/plda_scores_speakers**
结果
在集群化之后,您将在**$nnet_dir**/xvectors/plda_scores_speakers
目录中拥有一个名为 rttm 的输出文件。该文件将类似于以下内容:
**SPEAKER rec1 0 86.200 16.400 <NA> <NA> 1 <NA> <NA>`
SPEAKER rec1 0 103.050 5.830 <NA> <NA> 1 <NA> <NA>`
SPEAKER rec1 0 109.230 4.270 <NA> <NA> 1 <NA> <NA>`
SPEAKER rec1 0 113.760 8.625 <NA> <NA> 1 <NA> <NA>`
SPEAKER rec2 0 122.385 4.525 <NA> <NA> 2 <NA> <NA>`
SPEAKER rec2 0 127.230 6.230 <NA> <NA> 2 <NA> <NA>`
SPEAKER rec2 0 133.820 0.850 <NA> <NA> 2 <NA> <NA>`**
在该文件中,第 2 列是来自 wav.scp 文件的记录 id,第 4 列是当前片段的开始时间,第 5 列是当前片段的大小,第 8 列是该片段中发言者的 ID。
至此,我们完成了二化过程!我们现在可以尝试使用语音识别技术来确定每个说话者说了什么,或者使用说话者验证技术来验证我们是否知道任何不同的说话者。
如果你喜欢你读到的内容,你可以随时关注我的 推特 或者在这里给我留言。我还写了另一篇关于 Kaldi 的文章并且我还有一个 GitHub repo 充满了关于 Kaldi 的有用链接,请随意投稿!
谱聚类
其工作原理背后的直觉和数学!
Photo by Alexandre Chambon on Unsplash
什么是集群?
聚类是一种广泛使用的无监督学习方法。分组是这样的,一个聚类中的点彼此相似,而与其他聚类中的点不太相似。因此,找到数据中的模式并为我们分组是由算法决定的,根据所使用的算法,我们可能会得到不同的聚类。
聚类有两种大致的方法:
1 .紧密度-彼此靠近的点落在同一个聚类中,并且围绕聚类中心紧密。密切程度可以通过观察之间的距离来衡量。例如:K-Means 聚类
2。连通性 —相互连接或紧邻的点被放在同一个群集中。即使 2 个点之间的距离更小,如果它们不相连,它们也不会聚集在一起。谱聚类是一种遵循这种方法的技术。
这两者之间的差异可以通过下图很容易地显示出来:
Figure 1
谱聚类是如何工作的?
在谱聚类中,数据点被视为图的节点。因此,聚类被视为一个图划分问题。然后,节点被映射到一个低维空间,该空间可以很容易地被分离以形成集群。需要注意的重要一点是,没有对集群的形状/形式进行假设。
谱聚类的步骤是什么?
谱聚类包括 3 个步骤:
1 .计算相似度图
2。将数据投影到低维空间
3。创建集群
步骤 1——计算相似度图:
我们先创建一个无向图 G = (V,E),顶点集 V = { v1,v2,…,vn } = 1,2,…,n 个数据中的观测值。这可以由邻接矩阵来表示,该矩阵以每个顶点之间的相似性作为其元素。为此,我们可以计算:
**1)ε-邻域图:**这里我们连接所有成对距离小于ε的点。由于所有连接点之间的距离大致相同(最多为ε),对边进行加权不会将更多的数据信息合并到图表中。因此,ε-邻域图通常被认为是一个不加权的图。
**2) KNN 图:**这里我们用 K 个最近邻连接顶点 vi 和顶点 vj 如果 vj 在 vi 的 K 个最近邻中。
但是如果最近的邻居不是对称的,我们可能会有一个问题,即如果有一个顶点 vi 以 vj 作为最近的邻居,那么 vi 不一定是 vj 的最近邻居。因此,我们最终得到一个有向图,这是一个问题,因为我们不知道在这种情况下两点之间的相似性意味着什么。有两种方法可以使这个图没有方向。
第一种方法是简单地忽略边的方向,即如果 vi 是 vj 的 k 个最近邻居之一,或者如果 vj 是 vi 的 k 个最近邻居之一,我们用一条无向边连接 vi 和 vj 。得到的图就是通常所说的 k 近邻图。
第二个选择是连接顶点 vi 和 vj ,如果 vi 都是 vj 的 k 个最近邻,并且 vj 是 vi 的 k 个最近邻。得到的图称为互 k 近邻图。
在这两种情况下,在连接适当的顶点后,我们通过相邻点的相似性对边进行加权。
3)全连通图:为了构建这个图,我们简单地将所有的点相互连接起来,我们通过相似度 sij 对所有的边进行加权。该图应该模拟局部邻域关系,因此使用相似性函数,例如高斯相似性函数。
这里,参数σ控制邻域的宽度,类似于ε-邻域图中的参数ε。
因此,当我们为这些图中的任何一个创建邻接矩阵时,当这些点靠近时, Aij ~ 1,如果这些点远离,则 Aij → 0。
考虑以下具有节点 1 至 4、权重(或相似度) wij 及其邻接矩阵的图:
L: Graph, R: n x n symmetric adjacency matrix
步骤 2 —将数据投影到低维空间:
正如我们在图 1 中看到的,同一聚类中的数据点也可能相距很远——甚至比不同聚类中的点更远 *。*我们的目标是转换空间,这样当两个点靠近时,它们总是在同一个簇中,当它们远离时,它们在不同的簇中。我们需要把我们的观察投射到一个低维空间。为此,我们计算了图拉普拉斯,这只是图的另一种矩阵表示,可以用于发现图的有趣属性。这可以计算为:
Computing Graph Laplacian
Graph Laplacian for our example above
计算图拉普拉斯 L 的全部目的是找到它的特征值和特征向量,以便将数据点嵌入到低维空间中。所以现在,我们可以继续寻找特征值。我们知道:
让我们考虑一个数字的例子:
然后我们计算 l 的特征值和特征向量。
步骤 3 —创建集群:
对于这一步,我们使用对应于第二特征值的特征向量来给每个节点赋值。经计算,第二特征值为 0.189,相应的特征向量 v2 = [0.41,0.44,0.37,-0.4,-0.45,-0.37]。
为了获得二分聚类(两个不同的聚类),我们首先将 v2 的每个元素分配给节点,使得 {node1:0.41,node2:0.44,… node6: -0.37} 。然后,我们分割节点,使得值为> 0 的所有节点都在一个集群中,而所有其他节点都在另一个集群中。因此,在这种情况下,我们在一个集群中得到节点 1,2 & 3,在第二个集群中得到 4,5 & 6。
值得注意的是,第二个特征值表示图中节点的连接紧密程度。对于好的、干净的划分,第二特征值越低,聚类越好。
Eigenvector v2 gives us bipartite clustering.
对于 k 簇,我们必须修改我们的拉普拉斯算子以使其规范化。
因此我们得到:
Normalized Laplacian — Ng, Jordan, Weiss
谱聚类的优缺点
优势:
- 不对聚类的统计数据做出强有力的假设-聚类技术(如 K-Means 聚类)假设分配给聚类的点是关于聚类中心的球形。这是一个很强的假设,并不总是相关的。在这种情况下,谱聚类有助于创建更准确的聚类。
- 易于实现并给出良好的聚类结果。它可以正确地对实际上属于同一类但由于维数减少而比其他类中的观测值更远的观测值进行聚类。
- 对于几千个元素的稀疏数据集来说相当快。
缺点:
- 在最后一步中使用 K-Means 聚类意味着聚类不总是相同的。它们可以根据初始质心的选择而变化。
- 对于大型数据集,计算成本很高—这是因为需要计算特征值和特征向量,然后我们必须对这些向量进行聚类。对于大型密集数据集,这可能会大大增加时间复杂度。
在这篇博客中,我解释了谱聚类背后的数学原理。欢迎任何反馈或建议!同时,一定要看看我的其他博客。
参考
- https://calculated content . com/2012/10/09/spectral-clustering/
- 【http://ai.stanford.edu/~ang/papers/nips01-spectral.pdf
- https://www.youtube.com/watch?v=zkgm0i77jQ8
关于我
一名数据科学家,目前正在保护 AWS 客户免受欺诈。以前的工作是为金融领域的企业构建预测和推荐算法。
领英:https://www.linkedin.com/in/neerja-doshi/
谱聚类
基础与应用
介绍
在这篇文章中,我们将讨论图表和其他数据的谱聚类的来龙去脉。聚类是无监督机器学习的主要任务之一。目标是将未标记的数据分配到组中,其中相似的数据点有希望被分配到同一个组中。
谱聚类是一种源于图论的技术,该方法用于根据连接节点的边来识别图中的节点社区。该方法是灵活的,并且允许我们对非图形数据进行聚类。
谱聚类使用从图或数据集构建的特殊矩阵的特征值(谱)的信息。我们将学习如何构建这些矩阵,解释它们的光谱,并使用特征向量将我们的数据分配给聚类。
特征向量和特征值
这个讨论的关键是特征值和特征向量的概念。对于矩阵 A,如果存在一个不全是 0 的向量 x 和一个标量λ,使得 Ax = λx,则称 x 是 A 的一个具有相应特征值λ的特征向量。
我们可以认为矩阵 A 是一个将向量映射到新向量的函数。当 A 作用于大多数向量时,它们最终会到达完全不同的地方,但是本征向量只改变了大小。如果你画一条穿过原点和特征向量的线,那么在映射之后,特征向量仍然会落在这条线上。向量沿直线缩放的量取决于λ。
使用 Python 中的 numpy,我们可以很容易地找到矩阵的特征值和特征向量:
Finding eigenvectors in Python.
特征向量是线性代数的重要组成部分,因为它们有助于描述由矩阵表示的系统的动力学。有许多应用利用特征向量,我们将在这里直接使用它们来执行谱聚类。
图表
图表是表示多种类型数据的自然方式。图是一组节点和连接这些节点的一组相应的边。这些边可以是有向的或无向的,甚至可以具有与其相关联的权重。
互联网上的路由器网络可以很容易地用图来表示。路由器是节点,边是路由器对之间的连接。有些路由器可能只允许一个方向的流量,因此可以通过定向边来表示流量可以流向哪个方向。边上的权重可以表示沿该边可用的带宽。有了这种设置,我们就可以查询该图,找到通过网络将数据从一台路由器传输到另一台路由器的有效路径。
让我们使用下面的无向图作为运行示例:
Graph with two disconnected components.
这个图有 10 个节点和 12 条边。它还有两个相连的分量{0,1,2,8,9}和{3,4,5,6,7}。连通分量是节点的最大子图,所有节点都有到子图中其余节点的路径。
如果我们的任务是将这些节点分配给社区或集群,那么连接的组件似乎很重要。一个简单的想法是让每个连接的组件成为自己的集群。对于我们的示例图来说,这似乎是合理的,但是整个图可能是连通的,或者连通的组件非常大。在一个连接的组件中也可能有更小的结构,它们是社区的良好候选。我们将很快看到这种连通分量思想对于谱聚类的重要性。
邻接矩阵
我们可以将示例图表示为邻接矩阵,其中行和列索引表示节点,条目表示节点之间边的存在与否。我们的示例图的邻接矩阵如下所示:
在矩阵中,我们看到第 0 行第 1 列的值为 1。这意味着存在连接节点 0 和节点 1 的边。如果边是加权的,那么边的权重将出现在这个矩阵中,而不是只有 1 和 0。因为我们的图是无向的,所以第 I 行第 j 列的条目将等于第 j 行第 I 列的条目。最后要注意的是,这个矩阵的对角线都是 0,因为我们的节点都没有自己的边。
次数矩阵
一个节点的度就是有多少条边连接到它。在有向图中,我们可以谈论入度和出度,但是在这个例子中,我们只有度,因为边是双向的。查看我们的图,我们看到节点 0 的度为 4,因为它有 4 条边。我们也可以通过对邻接矩阵中的节点行求和来得到度。
度矩阵是对角矩阵,其中条目(I,I)处的值是节点 I 的度。让我们为我们的示例找到度矩阵:
首先,我们对邻接矩阵的轴 1(行)求和,然后将这些值放入对角矩阵中。从度矩阵中我们不难看出,节点 0 和 5 有 4 条边,而其余节点只有 2 条。
图拉普拉斯算子
现在我们要计算拉普拉斯图。拉普拉斯算子只是图的另一种矩阵表示。它有几个漂亮的属性,我们将利用这些属性进行谱聚类。为了计算正常的拉普拉斯算子(有几种变体),我们只需从我们的度矩阵中减去邻接矩阵:
拉普拉斯的对角线是我们的节点的度,而非对角线是负的边权重。这是我们在执行谱聚类时所追求的表示。
图拉普拉斯的特征值
如前所述,拉普拉斯有一些美丽的属性。为了对此有所了解,让我们在向我们的图添加边时,检查与拉普拉斯相关联的特征值:
Eigenvalues of Graph Laplacian.
我们看到,当图完全不连通时,我们的十个特征值都是 0。当我们添加边时,我们的一些特征值增加。其实 0 特征值的个数对应的是我们图中连通分支的个数!
仔细观察添加的最终边,将两个组件连接成一个。当这种情况发生时,除了一个特征值之外,所有的特征值都被提升:
Number of 0-eigenvalues is number of connected components.
第一个特征值是 0,因为我们只有一个连通分量(整个图是连通的)。相应的特征向量将总是具有恒定值(在这个例子中,所有的值都接近 0.32)。
第一个非零特征值称为谱隙。光谱间隙给了我们一些图形密度的概念。如果这个图是密集连接的(10 个节点的所有对都有一条边),那么谱间隙将是 10。
第二个特征值叫做菲德勒值,对应的向量就是菲德勒向量。Fiedler 值近似于将图分成两个相连部分所需的最小图割。回想一下,如果我们的图已经是两个连通的部分,那么 Fiedler 值将是 0。Fiedler 向量中的每个值都为我们提供了关于该节点属于切割哪一侧的信息。让我们根据字段向量中的条目是否为正来给节点着色:
Nodes colored based on whether their entry in the Fiedler Vector is >0.
这个简单的技巧将我们的图形分成了两个集群!为什么会这样?记住零特征值代表连通分量。接近零的特征值告诉我们,两个分量几乎是分开的。这里我们有一条边,如果它不存在,我们会有两个独立的部分。所以第二特征值很小。
总结一下我们目前所知道的:第一个特征值是 0,因为我们有一个连通分支。第二个特征值接近于 0,因为我们离有两个连通分量还差一条边。我们还看到,与该值相关联的向量告诉我们如何将节点分成那些近似连接的组件。
你可能已经注意到接下来的两个特征值也很小。这告诉我们,我们“接近”拥有四个独立的连接组件。一般来说,我们经常寻找特征值之间的第一个大间隙,以便找到我们的数据中表示的聚类数。看到特征值四和特征值五的差距了吗?
在间隙之前具有四个特征值表明可能有四个集群。与前三个正特征值相关联的向量应该给我们关于在图中需要进行哪三个切割以将每个节点分配给四个近似分量之一的信息。让我们从这三个向量构建一个矩阵,并执行 K 均值聚类来确定分配:
Spectral Clustering for 4 clusters.
该图被分割成四个象限,节点 0 和 5 被任意分配给它们相连的象限之一。这真的很酷,这就是谱聚类!
总而言之,我们首先用我们的图建立一个邻接矩阵。然后,我们通过从度矩阵中减去邻接矩阵来创建图拉普拉斯算子。拉普拉斯算子的特征值表明有四个聚类。与这些特征值相关联的向量包含关于如何分割节点的信息。最后,我们对这些向量进行 K-Means 运算,以获得节点的标签。接下来,我们将看到如何对任意数据执行此操作。
任意数据的谱聚类
看下面的数据。这些点是从添加了一些噪声的两个同心圆中画出的。我们希望有一种算法能够将这些点聚集成产生它们的两个圆。
Circles Data.
这些数据不是以图表的形式出现的。所以首先,让我们尝试一个类似 K-Means 的切饼算法。K-Means 将找到两个质心,并根据它们最接近的质心来标记这些点。结果如下:
K-Means Clustering of the Circles Data.
很明显,K-Means 没用。它基于欧几里得距离进行操作,并假设聚类大致是球形的。这些数据(通常是真实世界的数据)打破了这些假设。让我们尝试用谱聚类来解决这个问题。
最近邻图
有几种方法可以把我们的数据当作图表。最简单的方法是构造一个 k 近邻图。k-最近邻图将每个数据点视为图中的一个节点。然后,从每个节点到其在原始空间中的 k 个最近邻居画一条边。通常,该算法对 k 的选择不太敏感。较小的数字,如 5 或 10,通常工作得很好。
再次查看数据的图片,想象每个点都与其最近的 5 个邻居相连。外环中的任何一点都应该能够沿着环的路径前进,但是不会有任何路径进入内环。很容易看出这个图有两个相连的部分:外环和内环。
由于我们只将这些数据分成两个部分,我们应该能够使用之前的 Fiedler 向量技巧。下面是我用来对这些数据进行谱聚类的代码:
结果如下:
Spectral Clustering of the Circles Data.
其他方法
最近邻图是一种很好的方法,但是它依赖于“接近的”点应该属于同一个聚类的事实。根据你的数据,这可能不是真的。更一般的方法是构造一个亲和矩阵。相似矩阵就像邻接矩阵,除了一对点的值表示这些点彼此有多相似。如果成对的点非常不相似,那么相似度应该是 0。如果这些点是相同的,那么相似性可能是 1。通过这种方式,亲和度的作用就像是我们图中边的权重。
如何判断两个数据点相似意味着什么是机器学习中最重要的问题之一。通常领域知识是构建相似性度量的最佳方式。如果你能接触到领域专家,问他们这个问题。
还有专门学习如何直接从数据构建相似性度量的整个领域。例如,如果您有一些带标签的数据,您可以训练分类器根据两个输入是否具有相同的标签来预测它们是否相似。然后,该分类器可用于为成对的未标记点分配相似性。
结论
我们已经讨论了图和任意数据的谱聚类的理论和应用。当您的数据不满足其他常用算法的要求时,谱聚类是一种灵活的查找聚类的方法。
首先,我们在数据点之间形成了一个图表。图表的边捕捉了点之间的相似性。然后,可以使用图拉普拉斯的特征值来寻找最佳数量的聚类,并且可以使用特征向量来寻找实际的聚类标签。
我希望你喜欢这篇文章,并发现谱聚类在你的工作或探索中有用。
下次见!