TowardsDataScience 2023 博客中文翻译(二百二十六)

原文:TowardsDataScience

协议:CC BY-NC-SA 4.0

掌握天气预报:利用 LSTM 深度学习模型释放 AI 的力量以实现准确的温度预测

原文:towardsdatascience.com/mastering-weather-predictions-unleash-the-power-of-ai-with-lstm-deep-learning-models-for-accurate-cadd72ce221

使用 LSTM 的先进深度学习技术预测温度趋势

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Octavio Santiago

·发表于Towards Data Science ·7 分钟阅读·2023 年 3 月 27 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图片:使用 DALLE-2 生成的天气预测

天气预报是现代世界中最重要的工具之一,开发一个良好的温度预测模型可以为许多企业带来巨大的竞争优势。环境温度测量与农业、能源部门、交易、航空等多个业务领域直接相关。

测量温度是天气预报的重要方面,是气候条件中最基本且广泛使用的指标之一,因为它影响大气气体的行为以及空气和水的循环,这些都是气候系统的关键组成部分。

天气温度测量对许多企业的重要性体现在以下几个关键方面:

  • 农业:温度测量用于监测生长条件并预测潜在的作物产量,这有助于农民做出关于种植、收获和储存作物的明智决策。

  • 能源和公用事业:温度测量用于预测热浪、寒潮及其他温度极端情况,这有助于公用事业公司计划并应对能源需求的变化。一些公司进行能源交易,而能源需求预测对于获取良好的利润至关重要。

  • 运输:温度测量用于预测和监控极端温度,如热浪或寒潮,这些会影响交通系统(如道路和机场)的安全性和效率。

总体而言,温度测量是理解和预测天气及气候条件的基本工具,一个能够准确预测温度的模型对于许多行业的运作至关重要,并且对其他许多行业非常有利。

在这篇文章中,我们将学习如何构建 LSTM 深度学习模型来精确预测温度。

数据集

用于训练的数据集来自 INMET 网站,包含来自巴西圣保罗市 SAO PAULO — INTERLAGOS(A771)气象站的气象数据。温度数据的采样频率为每小时,训练数据从 2022 年 3 月 23 日 01:00 到 2023 年 3 月 23 日 12:00,共 1 年。

来源:mapas.inmet.gov.br/#

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图像:圣保罗地图

  • 目标变量:最高温度(ºC)

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图像:INMET 温度(最高、平均、最低)

LSTM

LSTM(长短期记忆)是一种递归神经网络(RNN)架构,特别适合处理序列数据,如时间序列、语音或文本。

与传统的前馈神经网络不同,LSTM 模型的基本构建块是 LSTM 单元,该单元设计用于长期记住和更新信息。每个 LSTM 单元具有一组“门”,控制信息的流入和流出,使网络在处理输入数据时能够选择性地存储或忘记信息。

LSTM 网络通常由多个 LSTM 单元按序列排列组成。输入数据一次通过一个时间步的单元,前一个单元的输出用作下一个单元的输入。这使得网络能够学习和建模数据中的复杂序列模式。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图像:LSTM 架构

训练和基准

训练时间序列模型的主要结果是如何正确地将数据集拆分为训练集和测试集。由于序列很重要,我们不能随意拆分数据集,为了正确拆分数据集,使用了 sklearn 函数“TimeSeriesSplit”。

LSTM Vanilla

LSTM vanilla(或“vanilla LSTM”)指的是一种长短期记忆(LSTM)神经网络架构,是 LSTM 模型的基本或标准版本。

训练

  • 18 天的采样时间

  • 数据集按小时分隔

  • 80% 训练数据集

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图像:训练 LSTM Vanilla

训练的回归指标:

  • MAE(平均绝对误差):0.53

  • MSE(均方误差):0.60

  • MAPE(平均绝对百分比误差):2.9%

问题:

  • 温度上升的延迟和峰值识别

测试

  • 数据集的 20%

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图像:测试 LSTM vanilla

训练的回归指标:

  • MAE(平均绝对误差):0.59

  • MSE(均方误差):0.82

  • MAPE(均绝对百分比误差):2.5%

预测

  • 48 小时的预测

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图像:预测 LSTM Vanilla

LSTM 堆叠

堆叠 LSTM(长短期记忆)是一种流行 LSTM 循环神经网络架构的变体。在标准 LSTM 中,使用单层 LSTM 单元来处理序列数据。在堆叠 LSTM 中,使用多个 LSTM 单元层,其中一层的输出作为下一层的输入。

堆叠 LSTM 架构用于提高网络学习和建模数据中复杂序列模式的能力。堆叠中的每一层可以学习不同层次的抽象,并从输入数据中提取特征,这使得网络能够捕捉数据中的更复杂关系和依赖性。

训练

  • 18 天的样本时间

  • 数据集按小时分隔

  • 80% 训练数据集

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图像:训练 LSTM 堆叠

训练的回归指标:

  • MAE(均绝对误差):0.52

  • MSE(均方误差):0.57

  • MAPE(均绝对百分比误差):2.8%

问题:

  • 温度上升和峰值识别的延迟

  • 对未来预测存在问题

测试

  • 数据集的 20%

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图像:测试 LSTM 堆叠

训练的回归指标:

  • MAE(均绝对误差):0.57

  • MSE(均方误差):0.82

  • MAPE(均绝对百分比误差):2.4%

预测

  • 48 小时的预测

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图像:预测 LSTM 堆叠

评估预测 — LSTM Vanilla

在对每个模型进行训练和测试后,比较它们效率的最一致方法是与新的实际数据进行对比。对 LSTM Vanilla 和堆叠模型进行了未来 48 小时的预测,并与实际数据(黄色)进行比较,如下所示:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图像:LSTM Vanilla、LSTM 堆叠和实际温度对比

LSTM Vanilla

  • 平均误差:2.04 ºC

  • 最大误差:4.44ºC

LSTM 堆叠

  • 平均误差:1.36 ºC

  • 最大误差:3.16ºC

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图像:MAE 比较

结论

在 48 小时的预测中,堆叠 LSTM 模型显示出较低的平均绝对误差和较低的最大温度误差,表明对这些 LSTM 模型及拟合数据来说,增加模型复杂性是有益的。

在 48 小时预测中,Stacked LSTM 模型显示出相较于 Vanilla 模型的平均绝对误差和最大温度误差均有所减少。这些发现表明,模型中加入的额外复杂性对 LSTM 架构是有利的。这些额外层的引入使得模型能够捕捉数据中的更复杂的模式和依赖性,从而提升其在预测任务中的性能。观察到的误差减少表明,增加的复杂性使得模型能够捕捉到更多的基础模式和温度数据的变异性,从而实现更准确的预测。

两个模型在预测未来 48 小时的温度值时均表现出较高的精确度,平均误差率约为 2%和 2ºC。然而,两个模型都遇到了准确识别与预期结果不同的次级温度谷的挑战。这一问题突显了模型在捕捉和解释温度数据中的复杂模式时的潜在限制,尤其是在出现意外波动或异常时。可能需要进一步研究以提高模型对这些不规则性的敏感性,并增强其在温度预测中的整体表现。

非常感谢您的阅读!如有任何问题或建议,请通过 LinkedIn 联系我:www.linkedin.com/in/octavio-b-santiago/

如果您想实现此解决方案或了解更多关于 LSTM 算法的内容,可以在我的 GitHub 库中找到完整的 Python 代码,链接如下:

代码

[## GitHub - octavio-santiago/temperature-forecasting: 温度预测模型与 AI - Python

温度预测模型与 AI - Python 本库包含了用于训练和评估 LSTM 深度学习的代码…

github.com

参考文献

数据来源:mapas.inmet.gov.br/# — INMET(国家气象研究所)公开数据集:CC0 1.0 通用许可证 (portal.inmet.gov.br/sobre)

现实世界中的数学:测试、模拟及更多

原文:towardsdatascience.com/math-in-the-real-world-tests-simulations-and-more-cf60b727cd86?source=collection_archive---------13-----------------------#2023-08-10

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 TDS Editors

·

关注 发布于 Towards Data Science · 发送为 Newsletter · 3 分钟阅读 · 2023 年 8 月 10 日

最佳的数学与统计学写作实现了一项艰难的壮举:它将高深的概念和复杂的公式与数据专业人士在日常工作中遇到的实际挑战联系起来。

一些数据科学家喜欢深入探讨新的数学话题,而另一些人则对这一领域持谨慎态度,甚至有些不情愿。无论你在这个范围的哪个位置,我们相信你会喜欢我们本周精选的文章。从 A/B 测试的内部机制到图论和统计实验,它们都轻松融合了理论与实践、抽象与具体。让我们深入了解吧。

  • 无论你是刚接触蒙特卡罗模拟还是需要一个扎实的复习,悉尼·奈的首篇 TDS 文章深入浅出地探讨了统计技术,“让我们在面对不确定性时进行战略性投注,从而使复杂的确定性问题变得概率化。”

  • 图论在机器学习研究中已经占据了核心位置,但对那些不在该领域的人来说,它可能仍然显得令人望而却步。亨尼·德·哈德提供了一份面向初学者的指南,介绍了图是什么,它们如何运作,以及数据科学家如何利用它们的力量来解决复杂的现实问题。

  • 如果自高中以来你对微分方程没有多加思考,这是你从全新角度重新审视它们的机会帅·郭的物理信息神经网络(PINN)系列文章回归,其中一篇专门讨论微分方程及其如何“提供对系统动态的洞察,并使我们能够预测系统未来的行为。”

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

Flo P拍摄,来源于Unsplash

  • 要深入了解置换检验及其如何取代更传统的基于公式的统计方法,请跟随潘·克里坦了解如何设计带有重抽样的实验。(那些从较少数学背景转到数据科学的读者会发现这篇文章尤其有用!)

  • 我们最后的每周亮点回到了我们开始时提到的蒙特卡罗模拟,但将其力量用于不同的目的。艾达·约翰逊博士分享了一个有用的 A/B 测试介绍:它清晰地定义了相关的统计概念,并聚焦于使用蒙特卡罗模拟评估测试性能的过程

本周我们推荐的其他阅读内容并不完全免于数学,但它们为关于其他重要主题的引人入胜的讨论打开了空间。

  • 在一项系统而及时的研究中,Yennie Jun探讨了大型语言模型内置历史知识中的性别偏见。

  • 错过了 ICML 2023?Michael Galkin在这里帮助我们通过详细回顾最新进展和新兴趋势来赶上。

  • 每个人都喜欢抱怨数据清理,但Vicky Yu的简明指南可以帮助你简化这个过程,使其变得不那么乏味。

  • Francesco Foscarin在其首次 TDS 文章中,将变压器与爵士和弦相结合,展示了基于数据的树状音乐分析方法。

  • Hans van Dam将移动应用程序开发与 LLM 结合,通过一个利用 GPT-4 功能导航应用程序图形用户界面(GUI)的实用教程。

感谢您支持我们的作者!如果您喜欢在 TDS 上阅读的文章,请考虑成为 Medium 会员 — 这将解锁我们的整个存档(以及 Medium 上的每篇其他文章)。

直到下一个变量,

TDS 编辑

Matplotlib 提示,以立即提升你的数据可视化——根据《数据故事讲述》

原文:towardsdatascience.com/matplotlib-tips-to-instantly-improve-your-data-visualizations-according-to-storytelling-with-8e75601b38cb

使用 Matplotlib 在 Python 中重现Cole Nussbaumer Knaflic书中的经验教训

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Leonie Monigatti

·发布于数据科学前沿 ·阅读时间 9 分钟·2023 年 6 月 19 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

能够有效地用数据进行沟通是一项适用于所有数据相关工作的技能——不仅仅是数据科学家和数据分析师。

我最喜欢的一本关于这个主题的书是Cole Nussbaumer Knaflic的《数据故事讲述》。这本书里充满了如何改进数据可视化的实用示例。

[## 数据故事讲述:面向商业专业人士的数据可视化指南

《数据故事讲述:面向商业专业人士的数据可视化指南》 [Nussbaumer Knaflic, Cole] 在 Amazon.com 上…

www.amazon.com

我认为这本书唯一不幸的地方是,其示例是使用 Microsoft Excel 创建的。

如果你知道一个喜欢在 Excel 中创建数据可视化的工程师,请举手——是的,我也没有。

“你可能是一个工程师,但理解你的图表不应该需要一个工程学学位。” ——Cole Nussbaumer Knaflic在《数据故事讲述》中

这就是为什么这篇文章将涵盖我自从阅读Nussbaumer Knaflic的《数据故事讲述》以来使用过的 Matplotlib 代码片段。

import matplotlib.pyplot as plt

这篇文章假设你已经掌握了 Matplotlib 和 Seaborn 的数据可视化基础知识,比如创建条形图、折线图或散点图,修改颜色调色板,并添加基本标签。文章还假设你知道何时使用哪种类型的图表。

这篇文章重点介绍一些不太常见的技巧,而不是 Matplotlib 的基础知识,例如:

  • 如何去除 Matplotlib 图的顶部和右侧边框

  • 如何从 Matplotlib 图中移除刻度线

  • 如何自定义 Matplotlib 图中每个条形的颜色

  • 如何更改 Matplotlib 图中 x 轴和 y 轴的颜色

  • 如何向 Matplotlib 图中添加文本注释

  • 如何在 Matplotlib 图中为条形图添加数值

  • 如何在 Matplotlib 注释中将整段或部分文本设为粗体

  • 如何在 Matplotlib 注释中为文本上色

我们从一个简单的例子开始。以下数据是虚构的,以便我们能够专注于数据可视化技术:

import pandas as pd

# Define fictional example dataframe
df = pd.DataFrame(
          {'feature 1' : ['cat 1', 'cat 2', 'cat 3', 'cat 4'],
           'feature 2' : [400, 300, 200, 100]
          })

让我们以一个简单的单色条形图为起点,使用 Seaborn 并添加标题:

import seaborn as sns

# Create a basic bar chart from the example dataframe
fig, ax = plt.subplots(1,1, figsize = (6, 4))
sns.barplot(data =  df, 
            x = 'feature 1', 
            y = 'feature 2', 
            color = 'tan')

# Add title
ax.set_title('Meaningful Title')

plt.show()

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

移除杂乱

在《杂乱是你的敌人!》一章中,Nussbaumer Knaflic 讨论了如何识别并消除数据可视化中的视觉杂乱——这一部分将展示如何在 Matplotlib 图中去除视觉杂乱。

“[……E]每一个元素都会增加观众的认知负担。”—— Cole Nussbaumer Knaflic 在《数据讲故事》中

如何去除 Matplotlib 图的顶部和右侧边框

默认情况下,Matplotlib 图的边缘有一个所谓的脊线框。尤其是顶部和右侧的脊线会使数据可视化显得杂乱,因此应该被去除。

你可以通过以下代码片段简单地去除不相关的脊线:

# Remove top and right spines
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)

如果你还想去除其他脊线,可以使用 'bottom''left'。如果你想去除边框,包括完整的 x 轴和 y 轴,可以使用 ax.axis('off')

如何从 Matplotlib 图中移除刻度线

刻度线通常不会被认为是杂乱的。但在某些情况下,如此示例中,条形图的 x 轴刻度线是多余的。

# Remove ticks on x-axis
ax.tick_params(bottom = False)

如果你还想去除 y 轴的刻度线,可以使用 left = False

现在,去除杂乱后的例子如下所示:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

在 Matplotlib 图中去除杂乱前后的对比。

新图由于去除了视觉杂乱,难道没有给你一种更加平静的感觉吗?

减少强调

在《像设计师一样思考》一章中,Nussbaumer Knaflic 向我们展示了如何消除必要但相关的信息。这一部分展示了如何更改 Matplotlib 图中不重要部分的颜色。

“将必要但不会影响信息传递的项推到背景中。[……] 浅灰色在这方面效果很好。”—— Cole Nussbaumer Knaflic 在《数据讲故事》中

如何自定义 Matplotlib 图中每个条形的颜色

palette 参数替换 sns.barplot 方法中的 color 参数,以控制每个条形的颜色。通过这样做,你可以使用浅灰色来减少不重要的条形的强调,只用主要颜色来突出相关的条形。

# Define colors of individual bars
custom_colors = ['lightgrey', 'tan', 'lightgrey', 'lightgrey']

# De-emphasize less important bars
sns.barplot(data =  df, 
            x = 'feature 1', 
            y = 'feature 2', 
            palette = custom_colors) 

如何在 Matplotlib 绘图中更改 x 轴和 y 轴的颜色

接下来,我们还想降低 x 轴和 y 轴的颜色。为此,我们需要降低轴的脊、刻度和标签的颜色:

# Mute colors of spines
ax.spines['left'].set_color('grey')   
ax.spines['bottom'].set_color('grey')

# Mute colors of ticks
ax.tick_params(colors = 'grey')

# Mute colors of labels
ax.set_xlabel('feature 1', color = 'grey')
ax.set_ylabel('feature 2', color = 'grey')

现在,强调较不重要信息的示例如下所示:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

在 Matplotlib 图中强调和不强调不重要信息的前后对比。

文本是你的朋友:添加注释

Nussbaumer Knaflic 强调,您应该在数据可视化中添加文本以突出关键要点。在本节中,我们将探讨 ax.annotate() 方法,以在 Matplotlib 图中添加文本。

“如果您希望您的观众得出结论,请用文字说明。” — Cole Nussbaumer Knaflic 在《用数据讲故事》中

如何在 Matplotlib 图中添加文本注释

要在 Matplotlib 图形中添加文本,您可以使用 ax.annotate() 方法,该方法将文本及其在图中的位置作为参数。此外,您还可以指定水平 (ha) 或垂直对齐 (va) 以及字体大小等方面。

# Add text annotations
ax.annotate('Look at "cat 2". \nThis is important!',
             xy = (1.5, 360), 
              ha = 'center',
              fontsize = 11,
           )

如果您想添加额外的箭头指向某个位置,则需要使用以下参数:

  • xy:要注释的点——即箭头指向的位置

  • xytext:文本的位置(以及箭头的终点)

  • arrowprops = {'arrowstyle' : '->'}:箭头的样式

如何在 Matplotlib 图中给条形图添加数值

要为每个单独的条形添加数值,我们需要遍历 ax.patches。对于每个 bar,您可以使用 get_height()get_width()get_x() 方法来将数值放置在条形上方。

# Annotate bar chart with values
for bar in ax.patches:
    ax.annotate(int(bar.get_height()),
                xy = (bar.get_x() + bar.get_width() / 2, bar.get_height()), 
                ha = 'center', 
                va = 'center',
                xytext = (0, 8),
                textcoords = 'offset points'
                )

现在,添加了文本注释的示例如下所示:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

在 Matplotlib 图中添加文本注释的前后对比。

强调

在章节“聚焦您的观众注意力”中,Nussbaumer Knaflic 讨论了如何利用预注意属性引导观众的注意力到您希望他们看到的内容。在本节中,我们将讨论可以应用于 Matplotlib 图中文本注释的一些简单调整,以利用文本中的预注意属性。

“……如果我们战略性地使用预注意属性,它们可以帮助我们让观众在他们甚至不知道自己在看到之前,就看到我们希望他们看到的东西。” — Cole Nussbaumer Knaflic 在《用数据讲故事》中

如何在 Matplotlib 注释中使整段或部分文本变为 粗体

使用 粗体 文本可以帮助突出数据可视化中的重要部分。如果您只想突出注释中的一部分,可以在字符串中使用 $\\bf{}$ 并将要强调的文本放在花括号中。如果您想突出整个注释,只需添加参数 fontweight = 'bold'

# Make only part of text bold
ax.annotate('Look at "cat 2". \nThis is $\\bf{important}$!', 
            #...
           )

# Make all of the text bold
ax.annotate('Look at "cat 2". \nThis is important!', 
            #...
            fontweight='bold',
           )

如何在 Matplotlib 注释中给文本上色

要将特定文本与数据可视化中的特定元素关联起来,你可以利用相同颜色的关联性。要给文本注释上色,只需将 color 参数添加到 ax.annotate() 方法中。

# Remove ticks on x-axis
ax.tick_params(bottom = False)
# Add important take away to plot 
ax.annotate('Look at "cat 2". \nThis is $\\bf{important}$!', # Emphasize important terms
            xy = (1.5, 360), 
            ha = 'center',
            color = 'tan', 
            fontsize = 11,
           )

现在,强调重要信息的示例如下:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

在 Matplotlib 绘图中应用重要信息之前和之后。

下面是创建最终数据可视化的代码。

import matplotlib.pyplot as plt
import seaborn as sns

# Define color palette
highlight_color = 'tan'
muted_color = 'dimgrey'
muted_color2 = 'lightgrey'
custom_colors = [muted_color2, 'tan', muted_color2, muted_color2]

# Create a basic bar chart from the example dataframe
fig, ax = plt.subplots(1,1, figsize = (6, 4))

sns.barplot(data =  df, 
            x = 'feature 1', 
            y = 'feature 2', 
            palette = custom_colors) # De-emphasize less important bars

# Add title
ax.set_title('Meaningful Title')

# Mute colors of labels
ax.set_xlabel('feature 1', color = muted_color)
ax.set_ylabel('feature 2', color = muted_color)

# Remove unimportant spines and mute color of remaining spines
ax.spines['right'].set_visible(False)      # Remove top and right spines
ax.spines['top'].set_visible(False)        # Remove top and right spines 
ax.spines['left'].set_color(muted_color)   # Mute colors of spines
ax.spines['bottom'].set_color(muted_color) # Mute colors of spines

# Remove ticks on x-axis and mute colors of ticks
ax.tick_params(bottom = False,        # Remove ticks on x-axis
    colors = muted_color,             # Mute colors of ticks
)

# Annotate bar chart with values
for i, bar in enumerate(ax.patches):
    ax.annotate(int(bar.get_height()),
    xy = (bar.get_x() + bar.get_width() / 2, bar.get_height()), 
    ha = 'center', 
    va = 'center',
    xytext = (0, 8),
    textcoords = 'offset points',
    color = custom_colors[i])

# Add important take away to plot 
ax.annotate('Look at "cat 2". \nThis is $\\bf{important}$!', # Emphasize important terms
            xy = (1.5, 360), 
            ha = 'center',
            color = highlight_color, 
            fontsize = 11,
           )

plt.show()

科尔·努斯鲍默·克纳夫利克的《数据讲故事》是我最喜欢的数据可视化书籍之一。如果你对如何将数据可视化提升到更高水平感兴趣,我绝对推荐这本书。

如果你对更多的 Matplotlib 技巧感兴趣,在这个仓库中,安德烈·加斯科夫用 Python 和 Matplotlib 重现了书中的许多可视化:

[## GitHub - empathy87/storytelling-with-data: 来自《数据讲故事》的绘图…

来自《数据讲故事》的绘图实现,使用 Python 和 matplotlib - GitHub …

github.com

享受这个故事了吗?

免费订阅 以便在我发布新故事时获得通知。

[## 每当 Leonie Monigatti 发布新内容时获取电子邮件通知。

每当 Leonie Monigatti 发布新内容时获取电子邮件通知。通过注册,如果你还没有 Medium 账户,你将创建一个…

medium.com

LinkedInTwitter Kaggle上找到我!

参考文献

图片参考

如果没有其他说明,所有图片均由作者创作。

网络与文献

科尔·努斯鲍默·克纳夫利克. 《数据讲故事:商业专业人士的数据可视化指南》, Wiley, © 2015.

Matplotlib 教程:将你的国家地图提升到另一个水平

原文:towardsdatascience.com/matplotlib-tutorial-lets-take-your-country-maps-to-another-level-a6bd1f40fff

Matplotlib 教程

如何使用 Python 和 Matplotlib 绘制美丽的地图

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Oscar Leo

·发表于Towards Data Science ·阅读时间 10 分钟·2023 年 9 月 24 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

地图由作者创建

是的,我使用 Matplotlib 创建了上面的地图,我会在本教程中向你展示如何做。

这个想法是创建一个可重用且灵活的函数,使我能够立即绘制任何区域的美丽地图。

有了这样的功能,我可以节省大量时间,当我想创建带有地理组件的图表或信息图时。

我还想展示可重用代码的力量,因为许多从事数据可视化的程序员往往忽视了这样的最佳实践。

本教程包含了创建上述非洲地图所需的每一行代码。

让我们开始吧。

第一步:下载地理数据

在开始本教程之前,你需要做的唯一一件事是从这里下载地理数据:

datacatalog.worldbank.org/search/dataset/0038272/World-Bank-Official-Boundaries

我使用的是名为 World Boundaries GeoJSON — Very High Resolution 的数据。

这是来自世界银行的官方边界数据集,你可以按任何你想要的方式使用。

第二步:导入库

一如既往,我们从导入必要的库开始,我们不需要很多库。由于我们有地理数据,我们希望使用geopandas使绘图尽可能简单。

import numpy as np
import pandas as pd
import seaborn as sns
import geopandas as gpd
import matplotlib.pyplot as plt

import matplotlib.patheffects as PathEffects
from matplotlib.patches import Polygon

你可能还没见过的一个导入是PathEffe。我将使用它来稍后在国家标签周围创建边框。

第三步:创建 seaborn 风格

在绘图之前,我总是创建一个 seaborn 风格,以获得一致的外观。在这里,我只定义了一个background_colorfont_familytext_color。我将背景设置为浅蓝色,以代表海洋。

font_family = "sans"
background_color = "#D4F1F4"
text_color = "#040303"

sns.set_style({
    "axes.facecolor": background_color,
    "figure.facecolor": background_color,
    "font.family": font_family,
    "text.color": text_color,
})

你可以样式化其他图表方面,例如网格,但我更喜欢在绘制地图时使用plt.axis(“off”)隐藏大多数标准图表组件。

第 4 步:加载地理数据

现在是时候使用 geopandas 从世界银行加载地理数据。我正在更改塞舌尔的CONTINENT和赤道几内亚的INCOME_GRP

这看起来并不像它看起来那么奇怪,因为塞舌尔是非洲的一部分,根据世界银行最新的数据,赤道几内亚属于“中上收入”组别。

world = gpd.read_file("WB_Boundaries_GeoJSON_highres/WB_countries_Admin0.geojson")
world.loc[world.NAME_EN == "Seychelles", "CONTINENT"] = "Africa"
world.loc[world.NAME_EN == "Equatorial Guinea", "INCOME_GRP"] = "3\. Upper middle income"

africa = world[world.CONTINENT == "Africa"].reset_index(drop=True)

disputed_areas = gpd.read_file("WB_Boundaries_GeoJSON_highres/WB_Admin0_disputed_areas.geojson")
disputed_areas = disputed_areas[disputed_areas.CONTINENT == "Africa"]

接下来,我将非洲国家分开,因为我想将它们单独绘制,以作为我的图表的重点,对争议区域也进行同样的处理。

第 5 步:创建绘制地图函数

现在,我准备创建绘制地图的第一个版本函数。稍后我将扩展它以添加一些额外的函数。

这个第一个版本循环遍历一系列 geopandas 数据帧,并使用coloredgecolor列绘制它们。

该函数接受一些我还未使用的参数,但它们在教程的后续部分将会很有用。

def draw_map(
    maps_to_draw, 
    boundry_map_index=0,
    use_hatch_for_indexes=[],
    padding={},
    labels=[],
    figsize=(40, 40)
):

    assert "color" in maps_to_draw[0].columns, "Missing color column in map dataframe"
    assert "edgecolor" in maps_to_draw[0].columns, "Missing edgecolor column in map dataframe"

    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot()

    for map_index, map_to_draw in enumerate(maps_to_draw):
        map_to_draw.plot(
            ax=ax, color=map_to_draw.color, edgecolor=map_to_draw.edgecolor,
            hatch="//" if map_index in use_hatch_for_indexes else "",
        )

    # Additional functions below this comment

    return ax

我希望draw_map()返回ax对象,因为我经常希望根据使用案例扩展图表以添加更多信息。

要使用draw_map(),我必须为每个要绘制的 geopandas 数据帧定义coloredgecolor列。

我为world定义了颜色,以将每个国家绘制为淡背景。对于africa,我选择了使用数据中表示的三种收入组别的更显眼的颜色。

选择颜色时的一个提示是使用Coloring for Colorblindness检查你的颜色是否适合色盲人士。

world["color"] = "#f0f0f0"
world["edgecolor"] = "#c0c0c0"

africa["edgecolor"] = "#000000"
africa.loc[africa.INCOME_GRP == "5\. Low income", "color"] = "#dadada"
africa.loc[africa.INCOME_GRP == "4\. Lower middle income", "color"] = "#89bab2"
africa.loc[africa.INCOME_GRP == "3\. Upper middle income", "color"] = "#1B998B"

disputed_areas["color"] = "#FFD6D6"
disputed_areas["edgecolor"] = "#000000"

现在我已经创建了所需的列,我可以运行draw_map()

ax = draw_map(maps_to_draw=[world, africa, disputed_areas])

plt.axis("off")
plt.show()

这是我得到的结果。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

地图由作者创建

这是一个很好的开始,但我们还有很长的路要走。

第 6 步:设置边界

由于我想创建一张非洲地图,所以显示整个世界是没有意义的。

这就是为什么我创建了以下函数,它接收一个地理数据帧,并设置ax对象的范围。

def set_limits(ax, data, pad_left=0, pad_right=0, pad_top=0, pad_bottom=0):
    xmin_ = data.bounds.minx.min()
    ymin_ = data.bounds.miny.min()
    xmax_ = data.bounds.maxx.max()
    ymax_ = data.bounds.maxy.max()

    xmin = xmin_ - pad_left * (xmax_ - xmin_)
    xmax = xmax_ + pad_right * (xmax_ - xmin_)
    ymin = ymin_ - pad_bottom * (ymax_ - ymin_)
    ymax = ymax_ + pad_top * (ymax_ - ymin_)

    ax.set(xlim=(xmin, xmax), ylim=(ymin, ymax))

你可以在区域的每一侧添加填充,以便为额外的信息(如图例)留出空间。

让我们使用boundary_map_indexpadding参数将其添加到draw_map()函数中。

def draw_map(
    maps_to_draw, 
    boundry_map_index=0,
    use_hatch_for_indexes=[],
    padding={},
    labels=[],
    figsize=(40, 40)
):

    ...
    # Additional functions below this comment
    set_limits(ax, maps_to_draw[boundry_map_index], **padding)

    return ax

我将boundry_map_index设置为maps_to_draw列表中我们希望显示的 geopandas 数据帧的索引,并添加了一些padding

注意:我还传递了use_hatch_for_indexes=[2]以在争议区域绘制斜纹,以显示它们与其他区域的不同。

ax = draw_map(
    maps_to_draw=[world, africa, disputed_areas], boundry_map_index=1,
    padding={"pad_bottom": -0.08, "pad_top": 0.07, "pad_left": 0.07, "pad_right": 0.05},
    use_hatch_for_indexes=[2]
)

plt.axis("off")
plt.show()

现在我们得到了一个看起来不错的非洲地图,还包括了邻近国家的轮廓。如果你不想显示邻近国家,请从maps_to_draw中删除world

这是生成的地图。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

地图由作者创建

我们正在取得进展。

第 7 步:添加国家标签

由于大多数人不认识所有非洲国家,下一步是为每个国家添加标签。

我想在大国家的中间附近添加标签,对于较小的国家,则在国家和标签之间添加一条线。

def add_label(ax, label, fontsize=24, fontweight="bold", va="center", ha="center"):            
    annotation = plt.annotate(
        label["label"], 
        xy=label["xytext"] if "xypin" not in label.keys() else label["xypin"], 
        xytext=None if "xypin" not in label.keys() else label["xytext"], 
        xycoords="data", fontsize=fontsize, va=va, ha=ha,
        linespacing=1.3, color=label["color"], fontweight=fontweight, 
        arrowprops={
            "arrowstyle": "-",
            "linewidth": 2,
        })

    annotation.set_path_effects([PathEffects.withStroke(linewidth=6, foreground='w')])

定义国家标签的位置是这个教程中唯一繁琐的过程,但我已经为你完成了。

你可以尝试使用 geopandas 中的几何图形计算每个标签的位置,但这说起来容易做起来难,因为国家的形状各异。

有时候,最快的方法是卷起袖子直接硬编码值,这也是我在这里做的。

country_labels = [
    {"label": "Algeria", "color": "#040303", "xytext": (2.0, 27.5)},
    {"label": "Angola", "color": "#040303", "xytext": (17.7, -13.1)},
    {"label": "Benin", "color": "#040303", "xytext": (3.2, 5.4), "xypin": (2.3, 7.6)},
    {"label": "Botswana", "color": "#040303", "xytext": (24.4, -22.3)},
    {"label": "Burkina\nFaso", "color": "#040303", "xytext": (-1.4, 12.6)},
    {"label": "Burundi", "color": "#040303", "xytext": (43.3, -4.9), "xypin": (29.8, -3.6)},
    {"label": "Cameroon", "color": "#040303", "xytext": (12.5, 5.2)},
    {"label": "Cape Verde", "color": "#040303", "xytext": (-23.7, 19), "xypin": (-23.7, 16)},
    {"label": "Central African\nRepublic", "color": "#040303", "xytext": (21.1, 6.5)},
    {"label": "Chad", "color": "#040303", "xytext": (18.5, 16.0)},
    {"label": "Comoros", "color": "#040303", "xytext": (46.8, -9.6), "xypin": (43.3, -11.7)},
    {"label": "Cote\nd'Ivoire", "color": "#040303", "xytext": (-5.5, 8.5)},
    {"label": "Democratic\nRepublic of\nthe Congo", "color": "#040303", "xytext": (23.3, -2.7)},
    {"label": "Djibouti", "color": "#040303", "xytext": (47.0, 13.4), "xypin": (43.0, 12.2)},
    {"label": "Egypt", "color": "#040303", "xytext": (29.2, 26.6)},
    {"label": "Equatorial\nGuinea", "color": "#040303", "xytext": (5.9, -2.5), "xypin": (10.5, 1.6)},
    {"label": "Eritrea", "color": "#040303", "xytext": (43.0, 16.9), "xypin": (38.5, 16.2)},
    {"label": "Lesotho", "color": "#040303", "xytext": (35.0, -31.0), "xypin": (28.4, -29.5)},
    {"label": "Ethiopia", "color": "#040303", "xytext": (39.9, 8.5)},
    {"label": "Gabon", "color": "#040303", "xytext": (11.8, -0.7)},
    {"label": "Ghana", "color": "#040303", "xytext": (-1.3, 6.6)},
    {"label": "Guinea", "color": "#040303", "xytext": (-11.6, 11.0)},
    {"label": "Guinea-\nBissau", "color": "#040303", "xytext": (-20.3, 10.3), "xypin": (-14.5, 12.2)},
    {"label": "Kenya", "color": "#040303", "xytext": (37.9, 0.5)},
    {"label": "Eswantini", "color": "#040303", "xytext": (35.5, -29.3), "xypin": (31.5, -26.8)},
    {"label": "Liberia", "color": "#040303", "xytext": (-10.6, 3.6), "xypin": (-9.6, 6.7)},
    {"label": "Libya", "color": "#040303", "xytext": (17.5, 27.5)},
    {"label": "Madagascar", "color": "#040303", "xytext": (46.7, -19.6)},
    {"label": "Malawi", "color": "#040303", "xytext": (38.9, -21.3), "xypin": (35.0, -15.6)},
    {"label": "Mali", "color": "#040303", "xytext": (-1.9, 17.8)},
    {"label": "Mauritania", "color": "#040303", "xytext": (-11.1, 19.6)},
    {"label": "Morocco", "color": "#040303", "xytext": (-6.9, 31.3)},
    {"label": "Mozambique", "color": "#040303", "xytext": (40.8, -15.2)},
    {"label": "Namibia", "color": "#040303", "xytext": (17.3, -20.7)},
    {"label": "Niger", "color": "#040303", "xytext": (9.8, 17.5)},
    {"label": "Nigera", "color": "#040303", "xytext": (7.8, 9.8)},
    {"label": "Republic of\nthe Congo", "color": "#040303", "xytext": (7.8, -7.2), "xypin": (12.0, -4.1)},
    {"label": "Rwanda", "color": "#040303", "xytext": (43.8, -3.6), "xypin": (30.1, -2.0)},
    {"label": "São Tomé and\nPríncipe", "color": "#040303", "xytext": (-0.9, 0.2), "xypin": (6.8, 0.2)},
    {"label": "Senegal", "color": "#040303", "xytext": (-15.0, 14.7)},
    {"label": "Seychelles", "color": "#040303", "xytext": (55.6, -2), "xypin": (55.6, -4.5)},
    {"label": "Sierra Leone", "color": "#040303", "xytext": (-16.4, 6.3), "xypin": (-12.0, 8.5)},
    {"label": "Somalia", "color": "#040303", "xytext": (45.7, 2.7)},
    {"label": "South\nAfrica", "color": "#040303", "xytext": (22.4, -31.0)},
    {"label": "South\nSudan", "color": "#040303", "xytext": (30.2, 7.0)},
    {"label": "Sudan", "color": "#040303", "xytext": (29.7, 16.0)},
    {"label": "Tanzania", "color": "#040303", "xytext": (35.0, -6.7)},
    {"label": "The\nGambia", "color": "#040303", "xytext": (-20.3, 13.6), "xypin": (-15.4, 13.6)},
    {"label": "Togo", "color": "#040303", "xytext": (1.0, 4.1), "xypin": (1.0, 7.5)},
    {"label": "Tunisia", "color": "#040303", "xytext": (9.3, 38.9), "xypin": (9.3, 35.7)},
    {"label": "Uganda", "color": "#040303", "xytext": (32.6, 0.9)},
    {"label": "Zambia", "color": "#040303", "xytext": (26.1, -14.9)},
    {"label": "Zimbawe", "color": "#040303", "xytext": (29.7, -19.1)},
]

我们将函数直接添加到set_limits下方。

def draw_map(
    maps_to_draw, 
    boundry_map_index=0,
    use_hatch_for_indexes=[],
    padding={},
    labels=[],
    figsize=(40, 40)
):

    ...
    # Additional functions below this comment
    set_limits(ax, maps_to_draw[boundry_map_index], **padding)

    for label in labels:
        add_label(ax, label)

    return ax

然后将其传递给draw_map()

ax = draw_map(
    maps_to_draw=[world, africa, disputed_areas], boundry_map_index=1,
    padding={"pad_bottom": -0.08, "pad_top": 0.07, "pad_left": 0.07, "pad_right": 0.05},
    use_hatch_for_indexes=[2]
    labels=country_labels
)

plt.axis("off")
plt.show()

这是我们得到的结果。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

地图由作者制作

太棒了!现在,我可以看到我正在查看哪些国家,最后要做的是解释颜色。

第 8 步:添加图例和标题

在 matplotlib 中添加图例很简单,所以为了让它更有趣,我想使用国家的形状,而不是标准的圆圈或方块。

这个步骤是一个例子,其中我使用draw_map()返回的 ax,而不是直接将其添加到函数中。

为了将一个国家的几何图形转移到另一个位置,我创建了以下函数,它接受一个row并将形状放置在(x_locy_loc)。

def draw_legend_geometry(ax, row, x_loc, y_loc, height):
    x = np.array(row.geometry.boundary.coords.xy[0])
    y = np.array(row.geometry.boundary.coords.xy[1])

    x = x - (row.geometry.centroid.x - x_loc)
    y = y - (row.geometry.centroid.y - y_loc)

    ratio = height / (y.max() - y.min())
    x = x * ratio + (x_loc - x_loc * ratio)
    y = y * ratio + (y_loc - y_loc * ratio)

    ax.add_artist(Polygon(np.stack([x, y], axis=1), facecolor=row.color, edgecolor=row.edgecolor, hatch=row.hatch))

除了改变位置,它还会改变几何图形的比例,以达到特定的height

为了使用draw_legend_geometry(),我创建了一个名为legend的数据框,其中选择了适当的国家来代表每种颜色。目前,它仅适用于具有单个Polygon的国家,而不适用于MultiPolygon

我添加INCOME_GRP到西撒哈拉的方式有点“hacky”,利用它来排序值,并将row.INCOME_GRP[3:]作为文本,但现在这样也没关系。

我还在绘制图例后添加了标题和数据来源。

ax = draw_map(
    maps_to_draw=[world, africa, disputed_areas], boundry_map_index=1,
    padding={"pad_bottom": -0.08, "pad_top": 0.07, "pad_left": 0.07, "pad_right": 0.05},
    use_hatch_for_indexes=[2]
    labels=country_labels,
)

legend = pd.concat([
    disputed_areas[disputed_areas.NAME_EN == "Western Sahara"],
    africa[africa.NAME_EN.isin(["Niger", "Senegal", "Botswana"])]
])

legend.loc[legend.NAME_EN == "Western Sahara", "INCOME_GRP"] = "6\. Disputed area"
legend = legend.sort_values("INCOME_GRP")
legend["hatch"] = ["", "", "", "//"]

for i, row in legend.reset_index().iterrows():
    draw_legend_geometry(ax, row, -25, -20 - 3.5*i, 2.5)
    ax.annotate(row.INCOME_GRP[3:], (-22, -20 - 3.5*i), fontsize=28, fontweight="bold", va="center")

fontstyles = {"fontweight": "bold", "ha": "left"}
plt.annotate("Data source:", xy=(0.05, 0.32), fontsize=24, xycoords="axes fraction", **fontstyles)
plt.annotate("The World Bank", xy=(0.133, 0.32), fontsize=24, xycoords="axes fraction", color="#1B998B", **fontstyles)
plt.title("Income Groups in Africa", x=0.05, y=0.29, fontsize=42, **fontstyles)

plt.axis("off")
plt.show()

如果我运行上述代码,我会得到以下地图,它与您在教程开始时看到的地图完全相同。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

地图由作者制作

这就是教程的最后一步。恭喜你完成了!

结论

你可以通过添加更多样式参数来使draw_map()函数更灵活(我这里有几个硬编码的值),但对我来说,这覆盖了 95%的使用场景。

我希望你喜欢这个教程,并学到了可以在项目中使用的东西。

如果你做到了,鼓掌,订阅,并分享,以便更多人能学会如何用 Python 和 Matplotlib 绘制美丽的地图。

你也应该看看我的其他教程:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

Oscar Leo

Matplotlib 教程

查看列表8 篇故事!外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

此外,订阅我的免费通讯,Data Wonder,并在 Twitter 上关注我 @oscarl3o

感谢阅读,下次见。

逻辑回归中的矩阵和向量运算

原文:towardsdatascience.com/matrix-and-vector-operations-in-logistic-regression-e35714c4810f?source=collection_archive---------8-----------------------#2023-07-07

向量化逻辑回归

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 穆拉利·卡沙博伊纳

·

关注 发布于 Towards Data Science ·10 分钟阅读·2023 年 7 月 7 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

照片由 fabio 提供,来源于 Unsplash

任何人工神经网络(ANN)算法背后的数学基础可能令人难以理解。此外,用于表示模型在批量训练过程中前向传播和反向传播计算的矩阵和向量操作会增加理解的难度。虽然简洁的矩阵和向量符号是有意义的,但深入这些符号以了解矩阵操作的细节会带来更多的清晰度。我意识到,理解这些微妙细节的最佳方法是考虑一个最基本的网络模型。我找不到比逻辑回归更好的算法来探索底层机制,因为它具备了 ANN 的所有特点,如多维输入、网络权重、偏差、前向传播操作、应用非线性函数的激活函数、损失函数和基于梯度的反向传播。我的博客意图是分享我对逻辑回归模型核心的矩阵和向量操作的笔记和发现。

逻辑回归简要概述

尽管名字里有“回归”,逻辑回归实际上是一种分类算法,而不是回归算法。它通常用于二分类任务,以预测某个实例属于两个类别之一的概率,例如,预测一封电子邮件是否是垃圾邮件。因此,在逻辑回归中,因变量或目标变量被视为分类变量。例如,垃圾邮件用 1 表示,而非垃圾邮件用 0 表示。逻辑回归模型的主要目标是建立输入变量(特征)与目标变量概率之间的关系。例如,给定一封电子邮件的特征作为输入特征集合,逻辑回归模型会找到这些特征与电子邮件是垃圾邮件的概率之间的关系。如果‘Y’表示输出类别,比如电子邮件是垃圾邮件,‘X’表示输入特征,则概率可以表示为 π = Pr( Y = 1 | X, βi),其中 βi 表示包括模型权重‘wi’和偏置参数‘b’在内的逻辑回归参数。实际上,逻辑回归预测给定输入特征和模型参数下 Y = 1 的概率。具体来说,概率 π 被建模为一个 S 形的逻辑函数,称为 Sigmoid 函数,公式为 π = e^z/(1 + e^z) 或等效地 π = 1/(1 + e^-z),其中 z = βi . X。Sigmoid 函数允许在 0 和 1 之间平滑曲线,非常适合于估计概率。本质上,逻辑回归模型在输入特征的线性组合上应用 Sigmoid 函数,以预测 0 和 1 之间的概率。确定实例输出类别的常见方法是对预测概率进行阈值处理。例如,如果预测概率大于或等于 0.5,则该实例被分类为类别 1;否则,分类为类别 0。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

逻辑回归模型示意图 — 由作者创建

逻辑回归模型通过将模型拟合到训练数据上来训练,然后通过最小化损失函数来调整模型参数。损失函数估计预测概率与实际概率之间的差异。用于训练逻辑回归模型的最常见损失函数是对数损失函数,也称为二元交叉熵损失函数。对数损失函数的公式如下:

L = — ( y * ln§ + (1 — y) * ln(1 — p) )

其中:

  • L 代表对数损失。

  • y 是实际的二元标签(0 或 1)。

  • p 是输出类别的预测概率。

逻辑回归模型通过使用梯度下降等技术来最小化损失函数,从而调整其参数。给定一批输入特征及其真实类别标签,模型的训练在多个迭代(称为 epoch)中进行。在每个 epoch 中,模型进行正向传播操作来估计损失,并进行反向传播操作以最小化损失函数并调整参数。所有这些操作都涉及矩阵和向量计算,如下一节所示。

矩阵和向量表示

请注意 我使用了 LaTeX 脚本来创建嵌入在此博客中的数学方程和矩阵/向量表示的图片。如果有人对 LaTeX 脚本感兴趣,请随时联系我;我很乐意分享。

如上图所示,使用二元逻辑回归分类器作为示例,以简化插图。如下所示,矩阵 X 表示‘m’个输入实例。每个输入实例包含’n’个特征,并表示为矩阵 X 中的一列,即输入特征向量,使其成为一个(n x m)大小的矩阵。上标(i)表示矩阵 X 中输入向量的序号。下标‘j’表示输入向量中特征的序号。大小为(1 x m)的矩阵 Y 捕捉了矩阵 X 中每个输入向量的真实标签。模型权重由大小为(n x 1)的列向量 W 表示,其中包含’n’个权重参数,对应于输入向量中的每个特征。虽然只有一个偏置参数‘b’,为了说明矩阵/向量操作,考虑一个大小为(1 x m)的矩阵 B,其中包含‘m’个相同的偏置 b 参数。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

正向传播

正向传播操作的第一步是计算模型参数和输入特征的线性组合。如下所示,此矩阵操作的符号表示一个新的矩阵 Z 的计算:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

注意权重矩阵 W 的转置使用。上述矩阵的扩展表示如下:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

上述矩阵运算的结果是计算出大小为(1 x m)的矩阵 Z,如下所示:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

下一步是通过对计算出的线性组合应用 sigmoid 函数来推导激活值,如以下矩阵操作所示。这会生成一个大小为(1 x m)的激活矩阵 A。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

反向传播

反向传播或称为反向传播是一种计算每个参数对最终预测错误或损失的贡献的技术。通过计算损失函数对每个模型参数的梯度来评估各个损失的贡献。函数的梯度或导数是该函数相对于一个参数的变化率或斜率,同时将其他参数视为常数。当在特定的参数值或点上进行评估时,梯度的符号指示函数增加的方向,梯度的大小指示斜率的陡峭程度。如下所示的对数损失函数是一个碗状的凸函数,具有一个全局最小点。因此,在大多数情况下,对数损失函数的梯度相对于参数指向全局最小值的相反方向。一旦评估了梯度,就使用参数的梯度更新每个参数值,通常使用称为梯度下降的技术。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

每个参数的梯度使用链式法则计算。链式法则使得能够计算由其他函数组成的函数的导数。在逻辑回归的情况下,对数损失 L 是激活‘a’和真实标签‘y’的函数,而‘a’本身是‘z’的 sigmoid 函数,‘z’是权重‘w’和偏置‘b’的线性函数,这意味着损失函数 L 是由其他函数组成的函数,如下所示。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

利用偏导数链式法则,权重和偏置参数的梯度可以如下计算:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

单个输入实例的梯度推导

在我们回顾作为更新参数的一部分的矩阵和向量表示之前,我们将首先使用单个输入实例推导梯度,以便更好地理解这些表示的基础。

假设‘a’和‘z’表示单个输入实例的计算值,并且真实标签为‘y’,则损失函数相对于‘a’的梯度可以推导如下。请注意,这个梯度是评估链式法则以推导参数梯度所需的第一个量。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

给定损失函数相对于‘a’的梯度,可以使用以下链式法则推导损失函数相对于‘z’的梯度:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

上述链式法则意味着必须推导出‘a’相对于‘z’的梯度。请注意,‘a’是通过对‘z’应用 sigmoid 函数计算得出的。因此,‘a’相对于‘z’的梯度可以通过如下所示的 sigmoid 函数表达式推导出来:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

上述推导以‘e’为基础,似乎需要额外的计算来评估‘a’相对于‘z’的梯度。我们知道‘a’是在前向传播过程中计算的。因此,为了消除任何额外的计算,上述导数可以完全用‘a’表示,如下:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

插入用‘a’表示的上述术语,‘a’相对于‘z’的梯度如下:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

现在我们有了损失函数对‘a’的梯度以及‘a’对‘z’的梯度,损失函数对‘z’的梯度可以如下评估:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

我们在评估损失函数对‘z’的梯度方面已经取得了很大进展。我们仍然需要评估损失函数对模型参数的梯度。我们知道‘z’是模型参数和输入实例‘x’特征的线性组合,如下所示:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

使用链式法则,损失函数对权重参数‘wi’的梯度被评估如下:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

同样,损失函数对‘b’的梯度被评估如下:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

使用梯度的参数更新的矩阵和向量表示

现在我们理解了使用单个输入实例导出的模型参数的梯度公式,我们可以将这些公式表示为矩阵和向量形式,以考虑整个训练批次。我们将首先对损失函数对‘z’的梯度进行向量化,其表达式如下:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

上述所有‘m’实例的向量形式是:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

同样,损失函数对每个权重‘wi’的梯度可以进行向量化。单个实例的损失函数对权重‘wi’的梯度由下式给出:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

上述所有权重在所有‘m’输入实例中的向量形式被计算为‘m’梯度的均值,如下:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

同样,损失函数对‘b’的梯度在所有‘m’输入实例中的结果是通过如下方式计算的各个实例梯度的均值:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

给定模型权重梯度向量和偏置的整体梯度,模型参数将按以下方式更新。如下所示的参数更新基于称为梯度下降的技术,其中使用了学习率。学习率是优化技术(如梯度下降)中使用的超参数,用于控制每次迭代时对模型参数进行调整的步长。有效地说,学习率充当缩放因子,影响优化算法的速度和收敛性。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

结论

从本博客中说明的矩阵和向量表示可以看出,逻辑回归使得一个基本的网络模型能够理解这些矩阵和向量操作的细微工作细节。大多数机器学习库封装了这些琐碎的数学细节,但却在更高层次上暴露了定义良好的编程接口,如前向传播或反向传播。虽然理解所有这些细微的细节可能不是使用这些库开发模型的必要条件,但这些细节确实揭示了这些算法背后的数学直觉。然而,这种理解肯定有助于将底层数学直觉应用到其他模型,如人工神经网络(ANN)、递归神经网络(RNN)、卷积神经网络(CNN)和生成对抗网络(GAN)。

数据流中的矩阵近似

原文:towardsdatascience.com/matrix-approximation-in-data-streams-7585720e8671

在没有所有行的情况下近似矩阵

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Mina Ghashami

·发表于 Towards Data Science ·13 分钟阅读·2023 年 9 月 17 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片来源:unsplash.com

矩阵近似是数据挖掘和机器学习中一个广泛研究的子领域。许多数据分析任务依赖于获得矩阵的低秩近似。例如,降维、异常检测、数据去噪、聚类和推荐系统。本文将深入探讨矩阵近似的问题,以及在数据不完全时如何计算它!

本文内容部分取自我在 斯坦福大学-CS246 课程讲座。希望对你有用。完整内容请见 此处

数据作为矩阵

大多数在网上生成的数据可以表示为矩阵,其中矩阵的每一行是一个数据点。例如,在路由器中,每个通过网络发送的包都是一个数据点,可以表示为所有数据点矩阵中的一行。在零售中,每次购买都是所有交易矩阵中的一行。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 1:数据作为矩阵 — 作者提供的图像

与此同时,几乎所有在网上生成的数据都是流式性质的;这意味着数据由外部源以我们无法控制的快速速率生成。想象一下用户每秒在 Google 搜索引擎上进行的所有搜索。我们称这种数据为流式数据;因为它像溪流一样源源不断地涌入。

一些典型流式网页规模数据的示例如下:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 2:典型的流式网页规模数据的大小 — 作者提供的图像

将流数据视为包含n 行、d 维空间中的矩阵A,其中通常 n >> d。通常 n 是以十亿为单位并不断增加的。

数据流模型

在流模型中,数据以高速到达,一次一行,算法必须快速处理这些项目,否则它们将永远丢失。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 3:数据流模型 — 图片由作者提供

在数据流模型中,算法只能对数据进行一次遍历,并且需要使用较小的内存进行处理。

秩-k 近似

矩阵A秩-k 近似是一个秩为k的较小矩阵B,使得BA进行准确的近似。图 2 展示了这一概念。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 4:从A获取更小的草图B — 图片由作者提供

B通常被称为A的草图。注意,在数据流模型中,B会比A小得多,以便适合内存。此外,rank(B) << rank(A)。例如,如果A是一个包含 100 亿文档和 100 万词的术语-文档矩阵,那么B可能是一个 1000×100 万的矩阵;即,少 1000 万行!

秩-k 近似必须“准确”地近似A。虽然准确是一个模糊的概念,但我们可以通过各种误差定义来量化它:

1️⃣ 协方差误差

协方差误差是矩阵 A 的协方差与矩阵 B 的协方差之间差异的 Frobenius 范数或 L2 范数。这个误差在数学上定义如下:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

协方差误差定义 — 图片由作者提供

2️⃣ 投影误差

投影误差是当A中的数据点被投影到B的子空间时的残差的范数。这个残差范数被测量为 L2 范数或 Frobenius 范数:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

投影误差定义 — 图片由作者提供

这些误差评估了近似的质量;它们越小,近似效果越好。但它们可以小到什么程度呢?

当我们计算这些误差时,我们必须有一个基准来进行比较。在矩阵草图领域,每个人使用的基准是由*奇异值分解(SVD)*创建的秩-k 近似!SVD 计算最佳的秩-k 近似!这意味着它在“协方差误差”和“投影误差”上造成的误差最小。

A的最佳秩-k 近似记作 Aₖ。因此,SVD 引起的最小误差是:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

最小秩 k 近似误差 — 图片由作者提供

SVD 将矩阵A分解为三个矩阵:

  • 左奇异矩阵 U

  • 奇异值矩阵 S

  • 右奇异矩阵 V

U 和 V 是正交的,意味着它们的列是单位范数且它们彼此正交;即 U 中每两列(V 中也是)之间的点积为零。矩阵 S 是对角矩阵;只有对角线上的条目是非零的,并且按降序排列。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 5:奇异值分解 — 图片来自作者

SVD 通过取 U 和 V 的前 k 列以及 S 的前 k 项来计算最佳的秩-k 近似:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 6:SVD 的秩 k 近似 — 图片来自作者

如前所述,以这种方式计算的 Aₖ 在任何秩为 k 或更低的矩阵 B 中具有最低的近似误差。然而,SVD 是一种非常耗时的方法,如果 A 是 n×d,则需要运行时间 O(nd²),并且不适用于数据流中的矩阵。此外,SVD 对稀疏矩阵效率不高;它在计算近似时没有利用矩阵的稀疏性。

❓现在的问题是我们如何以流式方式计算矩阵近似?

流式矩阵近似方法主要有三大类:

1️⃣ 基于行抽样

2️⃣ 随机投影方法

3️⃣ 迭代草图法

基于行抽样的方法

这些方法从相对于良好定义的概率分布的“重要”行中进行抽样。这些方法的不同之处在于它们如何定义“重要性”的概念。通用框架是它们按以下方式构建草图 B:

  1. 它们首先给流式矩阵A中的每一行分配一个概率

  2. 然后他们从A中抽取l行(通常是有放回的)来构建B

  3. 最后,它们将 B 适当缩放,使其成为 A 的无偏估计

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 7:有放回的行抽样以构建草图 B — 图片来自作者

注意,步骤 1 中分配给行的概率实际上是行的“重要性”。将“重要性”视为与项相关的权重,例如,对于文件记录,权重可以是文件的大小。或者对于 IP 地址,权重可以是 IP 地址发出请求的次数。

在矩阵中,每个项都是一个行向量,其权重是其范数的平方;也称为 L2 范数。有一种行抽样算法根据行的 L2 范数在数据的一次遍历中进行抽样。这个算法被称为“L2 范数行抽样”,其伪代码如下:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 8:L2 范数行抽样算法 — 图片来自作者

该算法以有放回的方式抽样 l = O(k/ε²) 行,并实现以下误差界限:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 9:L2 范数行抽样的误差保证 — 图片来自作者

注意,这是一个较弱的误差界限,因为它受限于矩阵 A 的 Frobenius 范数,总体来说可能是一个很大的数值!有一个改进的算法表现更好;我们来看看它。

扩展:有一种变体算法同时采样行和列!它被称为“CUR”算法,并且比“L2-范数行采样”方法表现更好。CUR方法通过从 A 中采样行和列来创建三个矩阵 C、U 和 R。它的工作原理如下:

步骤 1:CUR首先从A中采样几列,每列的采样概率与该列的范数成正比。这形成了矩阵C

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 10:CUR 算法步骤 1—— 图片由作者提供

步骤 2:然后CURA中随机抽取几行,每行的抽取概率与该行的范数成正比。这形成了矩阵R

步骤 3:CUR 然后计算 C 和 R 的交集的伪逆。这被称为矩阵 U。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 11:CUR 算法步骤 2,3—— 图片由作者提供

最终,这三个矩阵的乘积,即C.U.R,近似于A,并提供了一个低秩近似。该算法在采样*l = O(k log k/ε²)*行和列时达到了以下误差界限。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 12:CUR 误差保证—— 图片由作者提供

注意,与L2-范数行采样相比,这个界限要紧得多。

总结: 行采样方法家族(包括 CUR)通过采样行(和列)来形成低秩近似,因此它们非常直观并形成可解释的近似。

在下一部分,我们将看到另一类数据无关的方法。

基于随机投影的方法

这些方法组的关键思想是,如果将向量空间中的点投影到一个随机选择的适当高维子空间中,则点之间的距离大致保持不变*。*

Johnson-Lindenstrauss 变换(JLT)很好地描述了这一点:d个数据点在任何维度(例如,对于 n≫d 的 n 维空间)中可以被嵌入到大约log d**维的空间中,使得它们的成对距离在某种程度上得以保持。

JLT 的更精确和数学化的定义如下:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 13:JLT 定义—— 图片由作者提供

有许多方法可以构造一个矩阵 S,以保持成对距离。所有这些矩阵都称为具有JLT 属性。下图展示了一些创建这样的矩阵 S 的常见方法:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由作者提供

如上图所示,S的一个简单构造是从N(0,1)中抽取独立随机变量作为S的条目,然后将 S 按√(1/r)进行缩放:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 14:JLT 矩阵—— 图片由作者提供

这个矩阵具有 JLT 属性 [6]*,我们用它来设计随机投影方法如下:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 15:随机投影方法 — 作者提供的图片

注意第二步,它将数据点从高维空间投影到低维空间。很容易证明 [6] 该方法生成了无偏的草图:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 16:随机投影提供了无偏的近似 — 作者提供的图片

随机投影方法在设置 r = O(k/ε + k log k) 时能达到以下误差保证。请注意,它们的界限优于行采样方法。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 17:随机投影误差界限 — 作者提供的图片

有一类与随机投影类似的工作可以实现更好的时间界限。它被称为 哈希技术 [5]。这种方法采用一个每列只有一个非零条目的矩阵 S,而该条目是 1 或-1。它们计算近似值为 B = SA

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

哈希技术 — 作者提供的图片

总结:随机投影方法计算效率高,并且数据无关,因为其计算仅涉及一个随机矩阵 S。相比之下,行采样方法需要访问数据以形成草图。

迭代草图

这些方法在流 A=<a1,a2,…> 上工作,其中每个项目被读取一次,迅速处理且不再读取。读取每个项目时,它们更新草图 B。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 18:迭代草图方法 — 作者提供的图片

该组的最先进方法称为“频繁方向”,基于 Misra-Gries 算法 查找数据流中的频繁项。接下来,我们首先了解 Misra-Gries 算法如何查找频繁项,然后将其扩展到矩阵。

Misra-Gries 算法用于查找频繁项

假设有一个项目流,我们想找到每个项目的频率 f(i)

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 19:流中的频繁项计数 — 作者提供的图片

如果我们保持 d 个计数器,我们可以计算每个项的频率。但这不够好,因为在某些领域,如 IP 地址、查询等,唯一项的数量太多了。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 20:用于项频率估计的 d 个计数器 — 作者提供的图片

所以让我们保持 l 个计数器,其中 l≪d。如果流中到达的新项目在计数器中,我们将其计数加 1:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 21:增加项的计数器 — 作者提供的图片

如果新项目不在计数器中且我们有空间,我们为其创建一个计数器并将其设置为 1。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 22:为新项目设置计数器 — 作者提供的图像

但如果我们没有空间容纳新项目(这里的新项目是棕色盒子),我们获得中位数计数器,即位置为l/2的计数器:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 23:从每一个计数器中减去中间计数器。— 作者提供的图像

并从所有计数器中减去它。对于所有变成负值的计数器,我们将其重置为零。所以它变成如下:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 24:一半计数器为零 — 作者提供的图像

如我们所见,现在我们有空间容纳新项目,所以我们继续处理流🙂。

在流的任何时刻,项目的近似计数是我们迄今为止保留的计数,例如:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 25:估计项目计数 — 作者提供的图像

这种方法会低估计数,因此对于任何项目 i,其近似频率小于或等于其真实频率:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

与此同时,它的近似频率是下界的,因为每次我们减少时,最多减少l/2位置计数器的计数。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

在流中看到n个元素的任何点,我们有:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

因此,它提供的错误保证如下:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

Misra-Gries 错误界限 — 作者提供的图像

因此,Misra-Gries 对所有真实频率大于2n/l的项目生成一个非零近似频率。例如,要找到出现超过 20% 的项目,我们必须采取l = 10计数器并运行 Misra-Gries 算法。

频繁方向:Misra-Gries 的扩展

现在,让我们将 Misra-Gries 扩展到向量和矩阵。在矩阵的情况下,流中的项目是d维的行向量。在流中的任何时刻n,所有行一起形成一个n行的高矩阵A*。目标是找到A最重要方向。这些方向对应于A的前几个奇异向量。一个方向越重要,它在数据点中出现的频率就越高,这就是我们称下一个算法为频繁方向 [2,3]的原因。

频繁方向算法的伪代码如下:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 26:FrequentDirections — 作者提供的图像

如我们所见,该算法以矩阵A和草图大小l作为输入。然后,在第一步(即第一行突出显示的蓝色部分),算法将草图B初始化为空矩阵,具有l行。然后对于流中的每一行A,算法将其插入B中,直到B满为止:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 27:当草图 B 满时 — 作者提供的图像

然后我们计算BSVD;这将产生左奇异矩阵U、奇异值矩阵S和右奇异矩阵V。注意,UV提供了子空间的旋转,因为它们是正交矩阵。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 28: B 的 SVD — 作者提供的图像

然后我们通过从所有奇异值的平方中减去中间奇异值的平方来降低B的秩!注意,这一步类似于 Misra-Gries 中的部分操作,我们从所有计数器中减去中间计数器。从所有奇异值的平方中减去中间奇异值的平方使得一半的奇异值变为零。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 29: 降低 B 的秩 — 作者提供的图像

然后我们将S(奇异值矩阵)乘以V 转置(右奇异矩阵)并将其分配给B。换句话说,我们通过去掉左奇异矩阵U来重构B。这种操作的效果是得到一个新的矩阵B,其一半的行为空。这是好消息,因为它为流矩阵A中的下一行提供了空间。

误差保证:类似于频繁项的情况,该方法具有以下误差保证,其中l是草图大小,k是秩:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

FrequentDirections 协方差误差界限 — 作者提供的图像

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

FrequentDirections 投影误差界限 — 作者提供的图像

比较第二个误差界限与随机投影和行采样的误差界限。注意这是一个更紧凑和更好的误差界限。

实验:实验[2,4]表明,Frequent Directions 算法优于上述讨论的所有其他流式算法。以下是与协方差误差界限相关的实验:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 30: 协方差误差中的实验 — 来自 [2] 的图像

这是关于投影误差界限的实验:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 31: 投影误差中的实验 — 来自 [2] 的图像

本文到此结束。正如我们所见,Frequent Directions 不仅在近似误差中优于所有其他方法,而且使用的空间最少。换句话说,它在实现的误差界限方面是空间最优的。

总结

数据流中的低秩矩阵近似是计算低秩矩阵近似的问题,其中矩阵的行以流式方式到达。这意味着无法一次性访问矩阵的所有行,也不知道矩阵的大小。有三种主要的近似方法类别:行采样——随机投影——迭代草图。虽然第一组方法是最直观的,因为它们采样实际数据点,第二组方法在运行时效率最高。最先进的方法(SOTA)属于第三组,称为频繁方向。该方法基于频繁项估计的旧方法,并且在误差界限方面具有空间最优性。

如果你有任何问题或建议,请随时联系我:

电子邮件:mina.ghashami@gmail.com

LinkedIn:www.linkedin.com/in/minaghashami/

参考文献

  1. 快速蒙特卡罗矩阵算法 I:矩阵乘法的近似,P. Drineas 等,2006

  2. 频繁方向:简单且确定性的矩阵草图

  3. github.com/edoliberty/frequent-directions

  4. 具有保证的改进实用矩阵草图

  5. 输入稀疏时间下的低秩近似和回归

  6. 通过随机投影改进的大矩阵近似算法

GPU 上的矩阵乘法

原文:towardsdatascience.com/matrix-multiplication-on-the-gpu-e920e50207a8?source=collection_archive---------1-----------------------#2023-10-09

如何在 CUDA 中实现最先进的矩阵乘法性能。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Andy Lo

·

关注 发表在 Towards Data Science ·10 分钟阅读·2023 年 10 月 9 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

“从矩阵乘法中汲取灵感的极简艺术,风格为 vaporwave” —— DALLE-2

这篇博客源于我突然意识到自己对矩阵乘法在 GPU 上如何运作知之甚少。做了这么多机器学习项目,我觉得我应该了解这个在机器学习中最重要的操作是如何工作的:什么是“张量核心”?为什么每个人都说“数据移动是瓶颈”?GPU 实际上能有多快?

为了回答这些问题,我决定必须走出我的 PyTorch 领域,深入探索 CUDA 的深渊。我写了这篇博客来记录我所学到的一切,希望读到这篇文章的任何人都不必像我一样经历挖掘 CUDA 文档/代码的痛苦。

如果我在这段旅程中学到了什么,那就是并发矩阵乘法是困难的。高效的矩阵乘法在很大程度上依赖于你使用的具体硬件和你尝试解决的问题规模。没有一刀切的解决方案。

够了,让我们深入了解吧!

回顾 GPU 架构

让我们回顾一下(NVIDIA)GPU 的工作原理。GPU 通过运行许多线程来实现并行处理。每个线程在一个 CUDA 核心上执行,但在某一时刻,只有一部分线程是活动的,因此可能有比可用的 CUDA 核心更多的线程。每个线程,无论是否活动,都有自己的寄存器

一组 32 个线程称为warp。warp 中的所有线程必须一起执行(或一起处于非活动状态)。在大多数情况下,非活动 warp 的数量远多于活动 warp,而warp 调度器负责选择在特定时间执行哪些 warp。这使得 GPU 能够通过调度其他 warp 在 warp 等待数据时执行,从而隐藏内存访问的延迟。

一组 warp 称为线程块。所有线程块中的 warp 在同一个流处理器(SM)中执行。每个线程块有自己的共享内存,所有线程块中的线程都可以访问。

注意:较新的架构

从 Volta 架构开始,每个线程也有自己的程序计数器和调用栈等。这意味着 warp 中的每个线程可以同时执行不同的指令。

Volta 架构还引入了Tensor Cores,这些核心专门用于解决特定大小的矩阵乘法。每个活动 warp 可以访问一个 Tensor Core。

在最新的 Hopper 架构中,引入了线程块集群的概念,它表示一组线程块。它使用户能够更细粒度地控制线程块的调度,并允许一个线程块的共享内存在同一集群中的其他线程块访问。

并行化矩阵乘法

假设我们想计算:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

我们说在这种情况下问题的规模是(M, N, K)。为了并行化这个操作,我们可以将AB拆分成更小的矩阵,分别进行矩阵乘法,然后将结果连接起来形成C

具体来说,我们可以按行分割A(即,将M分成大小为M’的块)和按列分割B(即,将N分成大小为*N’*的块),得到:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

我们可以看到C中的每个子矩阵彼此独立,因此我们可以轻松地并行计算每个子矩阵。

实际上,K可能过大,无法直接加载到内存中进行计算。相反,典型的实现也会将K分割成大小为K’的块,迭代每个块,并对部分结果进行累加(求和)。这被称为串行-K归约。(与parallel-K reduction相对,见下节)。从数学上看,这样表示:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

注意:Padding

在任何问题大小不能被分区大小整除的情况下,我们需要添加padding。这通常在我们将分区输入(𝐴ᵢ,ₖ 和 𝐵ₖ,ⱼ)加载到低级内存时隐式完成,我们通过添加零确保加载的分区(𝐴ᵢ,ₖ的大小为 M’×K’,𝐵ₖ,ⱼ的大小为 K’×N’)总是“满”的。在将结果写回全局内存时需要特别小心,以避免越界错误。

从高层次看,三层嵌套分区用于在 GPU 上并行化矩阵乘法:

  1. 第一次分区发生在threadblock级别。每个线程块负责计算Cᵢ,ⱼ = Aᵢ Bⱼ

  2. 第二次分区发生在warp级别。线程块级别的问题Cᵢ,ⱼ 进一步被分区,每个 warp 负责计算Cᵢ,ⱼ⁽ᵐⁿ⁾ = Aᵢ⁽ᵐ⁾ Bⱼ⁽ⁿ⁾

  3. 第三次分区发生在instruction级别。有些指令需要特定大小的输入。例如,第二代 Tensor Cores 操作大小为(16, 8, 8)的fp16问题,而在 CUDA 核心上通过标量乘法直接实现则仅操作大小为(1, 1, 1)的问题。因此,warp 级别的问题被进一步分区,使得每个块有适合指令的大小:Cᵢ,ⱼ⁽ᵐⁿ⁾⁽ᵃᵇ⁾ = Aᵢ⁽ᵐ⁾⁽ᵃ⁾ Bⱼ⁽ⁿ⁾⁽ᵇ⁾

我们需要三个分区级别是有充分理由的,正如我们在下一节中将看到的。

数据冗余

矩阵乘法如果我们每次计算时都从全局内存重新获取数据,很容易变成内存瓶颈。关键观察是,许多子输入AᵢBⱼ在不同的子矩阵乘法中被重复使用。例如,Aᵢ需要用于Cᵢ,₁ , Cᵢ,₂ , … 和Bⱼ需要用于C₁*,ⱼ* , C₂*,ⱼ* , … 。如果我们能最小化冗余数据移动并尽可能多地重用加载的数据,就能获得最佳的吞吐量。

在 CUDA 中,有三种用户可访问的内存类型:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

下面是每种内存类型如何使用的高级视图:

  1. 每个线程块将首先从全局内存加载其所需输入到共享内存中。之后对这些数据的访问将由共享内存提供,而不是较慢的全局内存。

  2. 在每个线程块中,每个 warp 将首先从共享内存加载其所需输入到寄存器中。随后对这些数据的访问将直接由快速寄存器提供。

深入细节

线程块级别

在线程块级别,问题被划分为大小为 (M’, N’, K’) 的子问题。因此,每个线程块负责计算 C 的一个片段,记作:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

通过将子输入 Aᵢ,ₖBₖ,ⱼ 加载到共享内存中来最小化冗余的数据移动。当我们完成 Aᵢ,ₖ Bₖ,ⱼ 的计算后,下一个块 (Aᵢ,ₖ₊₁Bₖ₊₁,ⱼ) 将被加载。

warp 级别

在 warp 级别,子问题进一步划分为大小为 (M’’, N’’, K’’) 的子子问题。因此,每个 warp 负责计算 Cᵢ,ⱼ, 记作 Cᵢ,ⱼ⁽ᵐ ⁿ⁾

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

通过将子子输入 Aᵢ,ₖ⁽ᵐ ˡ⁾Bₖ,ⱼ⁽ˡ ⁿ⁾ 加载到寄存器中来最小化冗余的数据移动。任何对 Aᵢ,ₖ⁽ᵐ ˡ⁾Bₖ,ⱼ⁽ˡ ⁿ⁾ 的访问warp 内将由快速寄存器提供服务。

注意:在寄存器之间分配数据

值得注意的是,寄存器是线程级的。这意味着寄存器中的输入不能被 warp 中的其他线程访问。如何将 Aᵢ,ₖ⁽ᵐ ˡ⁾ 和 Bₖ,ⱼ⁽ˡ ⁿ⁾ 分配到每个线程的寄存器中,取决于使用的具体指令。NVIDIA 文档中的 Warp Level Matrix Multiply-Accumulate Instructions 对每条指令进行了详细描述。

张量核心级别

为了实际执行矩阵乘法,我们使用 GPU 上的张量核心。我的 GPU (RTX 2060) 具有第二代张量核心,专门解决大小为 (M’’’, N’’’, K’’’) = (16, 8, 8) 的 fp16 问题。因此,我们进一步将 Cᵢ,ⱼ⁽ᵐ ⁿ⁾ 划分为符合指令预期大小的片段:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

在这里,所有输入已经在寄存器中,因此数据移动开销最小。

注意:张量核心

张量核心操作是warp 级指令,意味着 warp 中的所有线程需要同时执行张量核心指令,协同准备要被一个张量核心消费的数据。

选择分区大小

所以,鉴于我们希望最小化数据移动,我们应该选择尽可能大的分区大小来利用所有的共享内存和寄存器,对吗? 其实并不是这样。

线程块分区大小

从渐近的角度来看,随着问题大小的增加,是的,我们确实希望尽可能使用更多的共享内存和寄存器。然而,对于小问题大小,我们可能会遇到两个问题:

  1. 大的分区大小意味着我们将有更少的线程块。由于每个线程块只能在一个 SM 上执行,这可能意味着我们不能利用所有的 SM。

  2. 对于不能被分区大小整除的问题大小,我们需要为输入添加更多的填充。这会降低效率,因为对有意义的输入计算较少。

典型的实现可能使用分区大小为 (M’, N’, K’) = (128, 256, 32)。

Warp 分区大小

通常,较大的 warp 分区大小意味着会有更少的冗余数据移动,但代价是拥有更少的 warps。拥有过少的 warps 意味着我们将无法隐藏内存访问的延迟(因为当当前 warp 等待数据时,我们可能没有其他 warp 来调度)。

典型的实现可能使用分区大小为 (M’’, N’’, K’’) = (64, 64, 32)。

指令分区大小

这完全取决于你的 GPU 支持什么指令。对于我的 RTX 2060,fp16 Tensor Core 矩阵乘法(带有 fp16 累积)的 ptx 指令是 mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16,它期望输入的大小为 (16, 8, 8)。

进一步优化

上述技术在问题规模较大时可以使我们接近 GPU 的理论峰值性能。然而,对于较小的问题规模,它们的效率不是很高。进一步提高矩阵乘法性能的两种常见技术是 并行-K 减少软件流水线

并行-K 减少

MN 较小时,我们可能只有少量的线程块。例如,如果 (M’, N’) = (128, 256) 且原始问题规模具有 M ≤ 128 和 N ≤ 256,我们将只有一个线程块,因此我们只利用了 GPU 计算能力的一小部分!(例如,我的 RTX 2060 有 30 个 SM,因此为了最大化利用率,我们希望至少有 30 个线程块。)

K 较大(尽管 MN 较小)的情况下,我们可以通过进行 并行-K 减少 来利用更多的并行性。回想一下,在串行-K 减少中,每个线程块遍历以下和:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

并将中间结果累积到 Cᵢ,ⱼ。在并行-K 减少中,我们将每个线程块分配为仅计算 一个和的元素(即 Aᵢ,ₖ Bₖ,ⱼ)。这使我们可以将线程块的数量增加 K/K’ 倍,从而利用更多的 SMs。

需要注意的是,现在我们需要 分配更多内存 来存储每个线程块的结果,并且 调用第二个内核 来对部分结果进行最终的归约以获得 Cᵢ,ⱼ

软件流水线

通常,CUDA 通过调度其他 warps 执行来隐藏内存访问的延迟,而当一个 warp 等待数据时。这要求我们拥有足够的 warps 来掩盖延迟。

然而,在进行 GEMM 时,warps 的数量通常相对较少。这是因为 warp 的数量受到“每个线程块的可用寄存器数除以每个 warp 需要的寄存器数”的限制。对于矩阵乘法,我们使用大量寄存器以最大化数据重用。因此,我们可能没有足够的 warps 来掩盖延迟。

“累加器元素通常占用至少一半的线程总寄存器预算。” — CUTLASS 文档

为了缓解这一效果,我们可以使用软件流水线。本质上,我们可以(手动)使用特殊指令异步预加载下一个迭代的输入。在输入被加载的同时,我们可以继续在当前迭代上进行计算。其总结如下图所示:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

来自CUTLASS的软件下载流水线

这得益于 GPU 的特性:它像任何现代 CPU 一样,可以在没有数据依赖关系的情况下流水化内存访问和算术操作。这被称为指令级并行

矩阵乘法的实际应用

如果你想了解这些概念如何在实际实现中结合起来,可以查看我用 CUDA 从零开始训练 MNIST 的实现。在那里,我使用 CUDA 训练了一个多层感知器,并在隐藏层大小为 128 时实现了比优化后的 PyTorch 快 6 倍

[## GitHub - andylolu2/cuda-mnist

通过在 GitHub 上创建账户来参与 andylolu2/cuda-mnist 的开发。

github.com

参考资料

1. CUTLASS 文档

2. CUDA 文档

3. CUTLASS 示例

通过选择最佳图表:网络图、热图还是桑基图来最大化你的洞察力?

原文:towardsdatascience.com/maximize-your-insights-by-choosing-the-best-chart-network-heatmap-or-sankey-d9b4165d7f16

美丽的可视化是很棒的,但为了最大化其可解释性,你需要仔细选择图表。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Erdogan Taskesen

·发表于 Towards Data Science ·阅读时间 9 分钟·2023 年 8 月 13 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片来源于 David PisnoyUnsplash

可视化是数据分析中的重要部分,因为它可以将数据转化为洞察力,并帮助你讲述故事。在这篇博客文章中,我将重点介绍网络图、热图和桑基图。这些图表使用相同的输入数据,但我们应该记住,它们是为特定目标设计的,因此其可解释性可能有所不同。 我将描述网络图、热图和桑基图之间的差异、应用,并通过实际示例展示它们的可解释性。 所有示例均使用 D3Blocks library 在 Python 中创建。

输入数据用于 热图和桑基图。

作为数据科学家,一个常见但重要的任务就是制作图表。有时这些图表作为理智检查,有时它们会出现在演示文稿中,成为故事的基础。尤其是在后者的情况下,我们的目标是将复杂的信息转化为逻辑的图形可视化。

创建图表就像摄影一样。你想捕捉到讲述故事的风景。

然而,决定使用哪种图表并不总是容易的,因为尽管图表可能有类似的输入,但它们设计用于描述场景的特定部分。三个图表的输入需要 ***source*** ***target****** 和*** ***weight*** 信息。下面展示了一个小示例。它描述了变量(或节点)之间的连接及其强度。换句话说,Penny 与 Leonard 之间的连接强度为 5。第二个节点的名称再次是 Penny,它也与 Amy 连接,但强度略低,值为 3,依此类推。

# Source node names
source = ['Penny', 'Penny', 'Amy', 'Bernadette', 'Bernadette', 'Sheldon', 'Sheldon', 'Sheldon', 'Rajesh']
# Target node names
target = ['Leonard', 'Amy', 'Bernadette', 'Rajesh', 'Howard', 'Howard', 'Leonard', 'Amy', 'Penny']
# Edge Weights
weight = [5, 3, 2, 2, 5, 2, 3, 5, 2]

节点即为 sourcetarget 的联合名称集。边表示 sourcetarget 之间的关系。图表可以处理有向或无向的源-目标值关系。weight 值描述了关系的强度。值得注意的是,source-target-weight 的值也可以是(稀疏)邻接矩阵的形式,其中列和索引是节点,值大于 0 的元素被认为是边。这种形式通常用于热力图的创建,但它本质上包含相同的信息。在下一节中,我将描述这些信息如何转换为图表。

# Install d3blocks for the following examples
pip install d3blocks

# Install cluster evalation (required for the heatmaps)
pip install clusteval
# Import
from d3blocks import D3Blocks
# Initialize
d3 = D3Blocks()
# Convert
adjmat = d3.vec2adjmat(source, target, weight)
# Print
print(adjmat)

# target      Amy  Bernadette  Howard  Leonard  Penny  Rajesh  Sheldon
# source
# Amy         0.0         2.0     0.0      0.0    0.0     0.0      0.0
# Bernadette  0.0         0.0     5.0      0.0    0.0     2.0      0.0
# Howard      0.0         0.0     0.0      0.0    0.0     0.0      0.0
# Leonard     0.0         0.0     0.0      0.0    0.0     0.0      0.0
# Penny       3.0         0.0     0.0      5.0    0.0     0.0      0.0
# Rajesh      0.0         0.0     0.0      0.0    2.0     0.0      0.0
# Sheldon     5.0         0.0     2.0      3.0    0.0     0.0      0.0

图表以不同的方式翻译数据。

网络图、桑基图和热力图 各有其特性,因此可以以不同的方式呈现相同的数据。简要总结如下:

  • 网络图 直观地展示实体之间的关系,其中节点代表实体,边代表它们之间的关系。优点:这种图表适用于理解复杂的行为,并且你也需要知道(一些)实体之间的确切关系。缺点 是当数据集较大时,图表会变得混乱且难以阅读。然而,通过使用不同的布局或按权重拆分网络,它可以再次变得有效。有关如何使用交互功能的更多详细信息,请阅读以下博客 [1]。

## 使用 Python 创建美观的独立交互式 D3 图表

应用于 D3 力导向网络图

towardsdatascience.com

  • 热力图 有效地可视化变量之间关系的强度或大小,其中值由(不同的)颜色表示。 优点: 这种类型的图表对识别具有多个变量的大数据集中的模式和趋势非常有用。当网络图变得复杂时,热力图可以提供结构化的见解。 缺点: 你很容易失去对个别关系的跟踪。然而,当你提供清晰的标签并对行和/或列进行聚类时,变量之间的关系可以更容易解读。

  • 桑基图 可以 直观地显示数据或资源在实体之间的流动,其中节点代表不同的阶段或实体,链接代表数据或资源在它们之间的流动。 优点: 对于理解复杂的过程或系统以及识别优化或改进的领域非常有用。 缺点 是当阶段或实体过多时可能会变得难以阅读。有关更多详细信息,请阅读以下博客[2]:

## 使用 Python 在 d3js 中创建美丽的桑基图的实践指南。

桑基图是一种出色的方法,通过查看个别项目在各状态之间的流动,可以发现最突出的贡献。

towardsdatascience.com

网络图、热力图和桑基图的应用

网络图、热力图和桑基图可以使用D3Blocks 库创建。有关 D3blocks 的更多详细信息,请参阅[3]:

## D3Blocks: 用于创建互动和独立的 D3js 图表的 Python 库

创建基于 d3 javascript(d3js)图形但可以用 Python 配置的交互式和独立图表。

towardsdatascience.com

网络图、热力图和桑基图的应用各不相同。网络图常用于可视化社交媒体网络,例如 Twitter 帖子或 Facebook,其中节点代表用户,边代表他们之间的关系。热力图用于许多数据点较多的应用场景,如股票价格、基因表达数据和气候数据等。桑基图用于可视化流量,例如客户旅程数据中不同阶段(例如,网站访问、注册、购买)。另一个例子是能源流动或供应链流动,涉及能源的不同来源和用途或供应链的不同阶段(例如,原材料、制造、分销)。

网络图、热力图和桑基图的实操比较

让我们加载能源数据集[4],并比较这三种图表的可解释性。能源数据集包含 48 个节点和 68 个加权(无向)关系,我们可以可视化能源流动。你会发现网络图使理解角色之间的确切关系变得容易。另一方面,热力图展示了所有关系的整体视图,而桑基图则显示了角色之间的流动。例如,在这个数据集中,John似乎是一个重要的角色,在网络图中占据中心点,并且有许多流动进出。你可以使用以下代码块重现这些结果:

# ######################
# Create network graph #
# ######################

# Load library
from d3blocks import D3Blocks
# Initialize
d3 = D3Blocks()
# Load energy data sets
df = d3.import_example(data='energy')

# Create the network graph
d3.d3graph(df, cmap='Set2')
# Extract the node colors from the network graph.
node_colors = d3.D3graph.node_properties

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

使用 D3Blocks 创建的 D3graph。交互式 HTML 版本可以在我的Github 页面上查看。

热力图的聚类使用clusteval 库[5]创建。该库使用聚类评估指标,如轮廓系数、DB 指数或 DBscan,来确定最优的聚类切割点。默认值可以如代码部分所示进行更改。数据已进行 z-score 标准化。

# ################
# Create Heatmap #
# ################

# Initialize
d3 = D3Blocks()
# Load Energy data sets
df = d3.import_example(data='energy')

# Create the default heatmap but do hide it. We will first adjust the colors based on the network colors.
d3.heatmap(df, showfig=False)

# Update the colors of the network graph to be consistent with the colors
for i, label in enumerate(d3.node_properties['label']):
    if node_colors.get(label) is not None:
        d3.node_properties['color'].iloc[i] = node_colors.get(label)['color']

# The colors in the dataframe are used in the chart.
print(d3.node_properties)

# Make the chart
d3.show(showfig=True, figsize=[600, 600], fontsize=8, scaler='zscore')

# You can make adjustments in the clustering:
d3.heatmap(df, cluster_params={'evaluate':'dbindex',
                               'metric':'hamming',
                               'linkage':'complete',
                               'normalize': False,
                               'min_clust': 3,
                               'max_clust': 15}) 

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

使用 D3Blocks 创建的热力图。交互式 HTML 版本可以在我的Github 页面上查看。

# ###############
# Create Sankey #
# ###############
# Initialize
d3 = D3Blocks()

# Create sankey graph
d3.sankey(df, showfig=True)

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

使用 D3Blocks 创建的桑基图。交互式 HTML 版本可以在我的Github 页面上查看。供应在左侧,需求在右侧。链接显示了能源在被消耗或丢失之前是如何转换或传输的。

我们还可以调整节点的颜色以与其他图表匹配。

# Initialize
d3 = D3Blocks(chart='Sankey', frame=True)
# Load data set
df = d3.import_example(data='energy')

# Set default node properties
d3.set_node_properties(df)

# Update the colors of the network graph to be consistent with the colors
for i, label in enumerate(d3.node_properties['label']):
    if node_colors.get(label) is not None:
        d3.node_properties['color'].iloc[i] = node_colors.get(label)['color']

# The colors in the dataframe are used in the chart.
print(d3.node_properties)
#   id                               label    color
#    0                  Agricultural_waste  #66c2a5
#    1                      Bio-conversion  #66c2a5
#    2                              Liquid  #e5c494
#    3                              Losses  #e78ac3
#    4                               Solid  #66c2a5
#    5                                 Gas  #fc8d62
#    ...

# Create edge properties
d3.set_edge_properties(df, color='target', opacity='target')
# Show the chart
d3.show()

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

使用 D3Blocks 创建的桑基图。节点颜色与其他两个图表匹配,边缘颜色自动设置。

通过交互式图表提高可解释性

使用交互式图表可以帮助增强解释和/或突出感兴趣的区域。一种方式是平移和缩放功能,这在 d3graph 图表中也得到了演示。另一种获得更多洞察的方法是使用自动创建的滑块根据边的强度来拆分网络。这使我们能更快地理解节点之间的关系。

使用堆叠方法构建你的故事

你可能已经从前面的部分注意到,没有一种最佳的图表适用于所有用例。通常,使用堆叠方法从不同的角度和/或深度描述数据是有益的。例如,你可以开始使用热力图来展示整体的弱关系和强关系。然后,选择一个感兴趣的集群或区域,使用网络图更深入地分析精确关系。最后,如果你现在需要描述节点之间的流动和依赖关系,可以使用 Sankey 图。

摘要

总之,选择正确的可视化技术对于有效洞察数据集至关重要。图表的选择取决于数据集的类型以及研究问题。在本博客中,我们比较了 3 种流行的可视化图表:网络图、热力图和 Sankey 图,并使用了实际示例。需要注意的是,创建图表是分析中的重要部分。如果你特别关注解释,故事情节可以更有效地传达给观众。

保持安全。保持冷静。

干杯,E.

如果你觉得这篇文章有帮助,欢迎 关注我 ,因为我会写更多关于可视化技术的文章。如果你考虑订阅 Medium 会员,可以通过我的 推荐链接支持我的工作。价格相当于一杯咖啡,但允许你每月无限制阅读文章!

让我们联系吧!

参考文献

  1. E. Taskesen, 使用 Python 创建美丽的独立交互式 D3 图表,Medium(Towards Data Science),2022 年 2 月

  2. E. Taskesen, 动手指南:用 Python 在 d3js 中创建美丽的 Sankey 图表,Medium(Towards Data Science),2022 年 10 月

  3. E. Taskesen, D3Blocks: 用于创建交互式和独立 D3js 图表的 Python 库,Medium(Towards Data Science),2022 年 9 月

  4. 能源与气候变化部汤姆·康塞尔开放政府许可证 v3.0

  5. E. Taskesen, 从数据到聚类:你的聚类何时足够好? Medium, 2023 年 4 月

随机变量参数的最大似然估计

原文:towardsdatascience.com/maximum-likelihood-estimation-4a1a866dfa70

通过观察数据的最高似然性来建模参数

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Roman Paolucci

·发表于 Towards Data Science ·11 分钟阅读·2023 年 12 月 10 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

Francesco Ungaro 摄影: www.pexels.com/photo/blue-and-white-abstract-painting-1912832/

概率和统计学中的概念由于高水平的数学、糟糕的符号表示以及随机变量和数据的交织而显得有些难以捉摸。本文阐明了在估计量、估计值、偏差和方差,以及最大似然估计方法的背景下,随机变量和数据之间的关系。

本文将分为以下几个部分。

  • 概率质量和分布函数的参数

  • 估计量

  • 估计值

  • 偏差和方差

  • 最大似然估计

  • 无耻的推广 Quant Guild

概率质量和分布函数的参数

本文不是关于常见随机变量的入门文章(这篇文章是为了这个目的)。我建议你阅读那篇文章,或在继续之前具备基本概率(公理、质量/分布函数等)的扎实基础。

让我们以医院接收病人的例子来讨论。

假设我们是一名医院风险管理人员、高级医生统计师、数据科学护士(我确实不知道谁会负责这个),我们想估计没有足够的病床来接纳病人的风险。将某一天检查入住的病人数建模为泊松随机变量是合理的。也就是说,我们假设每天检查入住医院的病人数服从泊松分布。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者生成的图像

X1、2、…、n天住院的病人数量。

我现在将简化问题空间,以免出现愤怒的评论:

  • 假设所有住院病人当天都会出院

  • 假设每天入住的病人彼此独立

这实际可行吗?可能不行——但它确实能帮助我们以简化的方式前进。这些是建模风险时需要考虑的事情,你的假设如何影响模型。

在这些假设下,我们很可能会低估风险。为什么?因为人们总是会在同一天出院吗?绝对不会,所以在这种情况下我们高估了床位的可用性。独立性呢?人们会互相传染吗?我在这里偏题了,因为这不是一篇关于假设和违背假设的文章。

好的,考虑到所有给定的假设,我们知道病人遵循一个泊松分布,其概率质量函数如下。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者生成的图像

因此,对于某一天的病人数量 a,我们可以找出看到该数量病人的相关概率。

此外,假设我们只有 100 张床位,根据历史病人数据,我们如何估计无法接纳病人的概率?我们将在回答一个更紧迫的问题后回到这个想法。

特别敏锐的人,或者那些稍微关注一点的人,会问出以下某种变体的问题

我们如何估计参数 lambda?我们选择什么值?

啊,是的,现在我们可以讨论估计量了。

估计量

一个估计量是任意一组随机变量的函数。重要的是要注意,估计量是一个函数,该函数在一组数据上给出估计值。请允许我详细说明。

以下方程是一个估计量

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者生成的图像

希望以函数形式书写有助于传达要点。这个函数,即样本均值,是总体均值的一个估计量。请注意,这这是一个任意一组独立且服从标准正态分布的随机变量的函数。

让我们对估计量应用期望算子

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者生成的图像

我们可以看到,这个估计量的期望值为零,这也与给定标准正态随机变量集的总体参数一致。请注意,我们在这里有总体参数,在我们上面的医院示例中,我们缺少总体参数 lambda,因此它必须被估计。稍后会详细讨论。

以下也是一个估计量

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者生成的图像

其中 zeta 是一个任意常数。

我们可以应用相同的期望算子

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

由作者生成的图像

期望值现在是 zeta,注意这与总体参数不一致。

我们这里有一个有趣的情况。我们知道总体参数,因此自然希望我们的估计量的期望值能够得到控制我们数据的总体参数。直观上,这是第一个没有 zeta 的估计量。

但如果我们不知道总体参数呢?如果我们不知道总体参数,不添加像 zeta 这样的任意常数是否是最优的?

再次,我们将很快在偏差和方差的背景下回到这个概念。

估计

这一部分很简单,我保证会很简短。估计是数据的一个函数,其中函数是给定的估计量。

需要注意的是,这些估计有其自身的分布,因为它们是我们提供给估计量的数据的函数。

假设我们有来自标准正态分布的以下实现随机变量

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

由作者生成的图像

这是数据没有任何不确定性。我们知道支配分布,并且知道每个z的值。根据这些数据,我们可以构造对总体均值的估计(我们刚好知道这个均值)。这是通过在实现的数据上应用g函数得到的。

这里有一个 Python 脚本来完成这个任务。

由作者生成的代码

Population Mean Estimate: -0.016672307230033958

注意,在这种情况下,n=1000,我们构造了一个总体均值的估计(我们知道这个均值是零),结果非常接近。每一个新的数据样本将产生不同的估计——然而,我们知道在期望值上,这些估计将与总体均值对齐。

这就是当说估计量遵循它们的自身分布时的含义。如果我们记录一系列 n=1000 个样本并创建直方图,我们实际上会观察到样本均值的分布

这里涉及大数法则和中心极限定理的影响——如果你对这些概念在完全不同的背景下的直观解释感兴趣,请参见以下视频。

由作者生成的视频

我们现在可以回到文章前面提出的问题。

在我们的医院例子中,我们需要估计参数,我们知道可以使用估计量来估计该参数。

我们面临的真正问题是,我们应该使用什么估计量,以及如何评估这个估计量?

偏差和方差

偏差指的是估计量对被估计参数的期望偏离。这听起来比较复杂,我们来看看数学公式。

假设我们有一系列独立且同分布的正态随机变量N。我们可以为总体均值 theta 构造一个估计。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者生成的图像

正式来说,偏差是估计量的期望与真实值之间的差异。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者生成的图像

在这种情况下,tau 等于总体均值 mu,因为这是我们试图估计的真实值。现在我们找出我们的估计量的期望值。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者生成的图像

因此

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者生成的图像

在这里,我们发现估计量 theta 对于总体均值是无偏的。

在以下情况下,出现了一个必要的问题:

一个具有零偏差和大量方差的估计量是否比一个具有小偏差和极低方差的估计量更好?

哪个估计量更好?如果有一个考虑了偏差和方差的度量就好了…

均方误差 (MSE)

均方误差考虑了估计量与被估计值之间的期望(平方)距离。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者生成的图像

MSE 实际上简化为方差与偏差的平方的组合。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者生成的图像

我们现在有了回答总体问题的方法:我们应该选择哪个估计量? 也就是说,我们可以比较估计量之间的 MSE,以确定在偏差和方差的背景下哪个可能更好。

在机器学习的背景下,我们面临着偏差-方差权衡的著名问题。当模型有更多的预测变量时,它们往往对训练数据的估计误差(偏差)较低——然而,在拟合新的训练样本时,这些估计的方差往往会更大。权衡指的是同时最小化偏差和方差的问题。

或许我们将在另一篇文章中更深入地探讨这个话题。现在,让我们最后讨论如何利用最大似然估计构造一个有效的估计量。

最大似然估计

目前为止在本文中所花费的时间为为什么最大似然估计有用奠定了基础。让我们回顾一下。

  • 我们有一个有 100 张床位的医院,并且有关于每天检查和出院患者数量的数据,我们想用泊松分布来建模这个随机变量。

  • 我们知道泊松分布由 lambda 参数化,但我们应该如何估计这个参数呢?我们应该选择什么估计量?

  • 我们知道我们希望选择的估计量在偏差和方差的背景下具有较低的 MSE,但我们如何凭空构造这样的估计量呢?

幸运的是,我们回顾中的最后一个问题可以通过最大似然估计的方法来回答。MLE 在估计量中具有几个理想的特性,例如

  • MLE 提供渐近零偏差

  • MLEs 在其数据的统治分布下是渐近正态的

  • MLEs 提供的方差对所有渐近正态估计量来说是尽可能低的

这里的渐近指的是给定数据集的样本量的任意增加。

但是MLE到底是什么?对于我们估计的任何参数,我们都是在寻找使观察到的数据的似然性最高的估计。在随机变量的背景下进行这种操作将得到 MLE 估计量,将该估计量应用于数据将得到 MLE 估计。

现在让我们推导泊松随机变量的最大似然估计量。之后,我们可以使用这个估计量来估计 lambda。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者生成的图像

在这里,我们可以使用之前假设的独立性概念,找到多变量函数作为泊松随机变量的单独概率质量函数的分解乘积。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者生成的图像

我们通常使用对数似然,因为它使计算更简单。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者生成的图像

不过,不要混淆,以下解法是等效的。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者生成的图像

现在让我们代入 theta 的概率质量函数。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者生成的图像

利用对数的性质,我们可以将乘积分解为求和。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者生成的图像

我们现在希望通过选择一个合适的 lambda 值来最大化这个概率。这成为了一个微积分中的优化问题,我们对 lambda 求导并将方程设为零——求解 lambda。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者生成的图像

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者生成的图像

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者生成的图像

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者生成的图像

真棒!总体参数 lambda 的 MLE 是样本均值。使用 Python,我将展示如何将这一切结合起来。

  • 给定一组假设遵循泊松分布的每日患者数据,使用 MLE 计算总体参数 lambda 的估计值

  • 使用 MLE 估计量建模泊松分布

  • 计算在特定一天内没有足够床位容纳患者的概率

  • 将其与已知的 lambda 进行比较(这里我们可以知道,因为这是模拟数据)以查看 MLE 估计的效果如何

作者生成的代码

MLE estimated probability of not having enough beds: 0.049382812868119186
Actual probability of not having enough beds: 0.0493453265906304

就是这样,这是对我们人口参数的相当准确的估计!这就是最大似然估计的方法。

结论

概率与统计的基础研究涵盖了随机变量的基本假设和功能。然而,抽象的研究往往让学生感觉缺少了什么:应用于实际问题以指导决策

在这里,我们讨论了如何基于泊松随机变量的理论和假设构建一个基础概率模型,并结合观察到的数据来分析假设被违反时可能的概率高估和低估。

如前所述,这个基础概率模型的一个关键假设是日常就医者之间的独立性,但这可能并不完全准确。以近期的证据为例,请回忆一下在 COVID-19 大流行期间,医院床位如何逐日变得越来越紧张。在这种情况下,这一过程的自协方差结构很可能是非零的,这意味着日常入住之间存在某种相关性。此外,我们的所有就医者不可能在同一天全部离院。

基于今天的入住情况来建模明天的入住情况可能更为合理。利用这个框架,我们现在可以将原始问题视为一个随时间演变的依赖系统,其中的未来状态仅依赖于现在的状态。幸运的是,我们有方法处理这样的依赖过程,通过将其视为马尔可夫链进行研究,这将是未来文章的主题——敬请关注!

Quant Guild

有兴趣提升你在量化金融、数学、统计学、数据科学、机器学习和人工智能方面的知识吗?

在 Quant Guild 查看我们的课程,使用代码 QGMEDIUM 享受 50% 折扣

## Introduction to Python

Python 已经成为学术界和从业者的首选语言…

## Algorithmic Trading System Development ## Introduction to Python

一个为那些有兴趣构建自己算法交易系统的人员设计的简明课程。

quantguild.com

查看我们的免费资源!

YouTube

YouTube [## Quant Guild

divitiae et educatione

www.youtube.com](https://www.youtube.com/@QuantGuild?source=post_page-----4a1a866dfa70--------------------------------)

GitHub

[## Quant-Guild

Quant-Guild 目前有 4 个可用的代码库。关注他们在 GitHub 上的代码。

github.com](https://github.com/Quant-Guild?source=post_page-----4a1a866dfa70--------------------------------)

Medium

[## Quant Guild

财富与教育

medium.com](https://medium.com/quant-guild?source=post_page-----4a1a866dfa70--------------------------------)

Discord

[## 加入 Quant Guild Discord 服务器!

查看 Quant Guild 社区的 Discord——与其他 1 名成员一起闲逛,享受免费的语音和文字聊天。

discord.gg](https://discord.gg/MJ4FU2c6c3?source=post_page-----4a1a866dfa70--------------------------------)

非常感谢阅读!希望你喜欢这篇文章——如果有任何问题,请随时留言或随时联系:roman@quantguild.com

除此之外,我们下次见!

RMP

五月刊:城市空间的数据

原文:towardsdatascience.com/may-edition-the-data-of-urban-spaces-815831aaf749?source=collection_archive---------8-----------------------#2023-05-03

月刊

数据如何帮助我们理解城市

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 TDS 编辑

·

关注 发布于 Towards Data Science ·4 分钟阅读·2023 年 5 月 3 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由 Lennart Jönsson 提供,来源于 Unsplash

有时候很难不将城市视为我们需要解决的问题集合:住房、食品安全和环境影响都会浮现在脑海中。今年早些时候,我们强调了数据专业人士当前如何通过利用气候和地理空间数据的创新方法来应对这些挑战。

城市远不止其痛点的总和:它们反映了社会价值(善的、恶的和模糊的),激发创造力和联系,并作为不断演变的文化的活跃存储库。

本月,我们邀请您探索城市空间的丰富性和复杂性。从公交车等待时间到我们穿越的街道名称,我们选择的作者们将他们的数据技能应用于城市生活的具体片段,使我们能够更深入地理解城市环境的运作方式。

在我们深入探讨之前,我们想感谢您一如既往地支持我们发布的工作。如果您希望做出有意义的贡献,考虑成为 Medium 会员;如果您是符合条件国家的学生,现在可以享受实质性折扣

TDS 编辑

TDS 编辑精选

  • 街道名称中的隐藏模式:数据科学故事 [第一部分] (2023 年 1 月,6 分钟)

    我们很少质疑谁应该以其名字命名街道,谁可能在这一过程中被排除。Dea Bardhoshi的迷人数据分析项目是对这种常见漠视形式的有效解药:它研究了阿尔巴尼亚地拉那街道名称的性别分布,并揭示了许多有趣的发现(有些比其他的更可预测)。

  • 德国住房租赁市场:使用 Python 进行探索性数据分析 (2023 年 4 月,27 分钟)

    随着全球许多城市的住房危机持续存在,围绕可靠数据构建解决方案导向的对话至关重要。Dmitrii Eliuseev对德国租赁市场的深入分析提供了一条强有力的分析路线图,可以服务于租户和政策制定者。

  • 使用 GeoPy 和 Folium 绘制黑人拥有的企业地图 (2021 年 1 月,5 分钟)

    来自边缘化社区的城市居民经常面临长期的排斥和隔离历史。Avonlea Fisher的项目旨在解决 COVID-19 对波士顿黑人拥有企业产生的巨大经济影响;两年后,它仍然可以激励数据科学家寻找对超本地问题的超本地答案。

  • 公交车在哪里?GTFS 将告诉我们! (2023 年 1 月,15 分钟)

    Leo van der Meulen的端到端教程的基本前提很难反驳:“公共交通与开放数据的结合具有巨大的潜力。” 将数据洪流转换为一个面向用户的界面,以告知乘客下一班公交车何时到达需要付出相当大的努力,但这个过程既有趣又可能扩展到其他背景和地点。

  • 数字文本分析:荷兰语地区的街头诗歌 (2021 年 11 月,5 分钟)

    地理空间数据、文本分析和诗歌在 Emma-Sophia Nagels的帖子中结合在一起,该帖子追踪了在荷兰及荷兰语区比利时编制、绘制和分析公开展示的诗歌的过程。

  • 通过地理空间关联规则挖掘发现便利店位置模式 (2023 年 3 月,7 分钟)

    东京的便利店可谓传奇:无处不在,随时可达,且充满了世界上最好的一些小吃。Elliot Humphrey超越了表面,试图检测商店位置中的模式,以推测出一个看似随意甚至混乱的现象背后的商业策略。

原创特色

探索我们最新的资源和阅读推荐。

  • 音频机器学习的新前沿

    我们精心策划了一系列强大的文章,涵盖了新模型、AI 接口和应用程序的崛起,这些使得处理音频和音乐的工作变得更加高效。

  • “最佳实践”究竟意味着什么? 我们收集的关于工作流程优化的亮点帖子,从更好的绘图到更有效的实验。

热门帖子

如果你错过了,这里是上个月在 TDS 上最受欢迎的一些帖子。

  • 零 ETL、ChatGPT 与数据工程的未来Barr Moses

  • 时间序列预测:深度学习 vs 统计学 — 谁胜出? by Nikos Kafritsas

  • 6 Python 最佳实践,区分高级开发者和初学者的关键 by Tomer Gabay

  • 如何将公司文档转变为可搜索的数据库:使用 OpenAI by Jacob Marks, Ph.D.

  • 你需要了解的 4 种自主 AI 代理 by Sophia Yang

  • Pandas 2.0 有哪些新特性? by Jeff Hale

我们非常高兴在三月份迎来了新一批 TDS 作者——他们包括Matt CollinsKrzysztof PałczyńskiColin HorganVictor GraffDr. Roi YehoshuaMark ChenBernardo FurtadoToon BeertenPeng QianEdozie OnyearugbulemWillem KoendersAaron Master和 Doron Bergman、Lingjuan LyuJacob Marks, Ph.D.Anthony MensierFrancisco Caio Lima PaivaAbhi SawhneyChris MauckMassimiliano CostacurtaLee Vaughan以及Davide Caffagni等人。如果你有有趣的项目或想法要与我们分享,我们非常愿意听取你的意见!

下个月见。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值