TowardsDataScience 2023 博客中文翻译(二十)

原文:TowardsDataScience

协议:CC BY-NC-SA 4.0

深入探讨统计期望的科学

原文:towardsdatascience.com/a-deep-dive-into-the-science-of-statistical-expectation-9dc0f80bd26

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

多佛白崖 (CC BY-SA 3.0)

我们如何形成对某件事的期望,这种期望的含义,以及产生这种含义的数学原理。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Sachin Date

·发表于Towards Data Science ·29 分钟阅读·2023 年 6 月 17 日

这是 1988 年的夏天,我第一次踏上了船。这是一艘从英格兰多佛出发、驶往法国加来乘客渡轮。我当时并不知道,但我正赶上了渡轮穿越英吉利海峡的黄金时代的尾声。这正是在廉价航空公司和英吉利海峡隧道几乎终结我仍认为的最佳旅行方式之前。

我曾期望渡轮看起来像我在儿童书籍中看到的许多船只之一。然而,实际看到的是一座异常庞大、闪闪发光的白色摩天大楼,带有小小的方形窗户。而且,这座摩天大楼似乎由于某种令人困惑的原因横着放置。从码头的视角来看,我看不到船体和烟囱。我看到的只是它长而平坦、带窗户的外观。我看到的是一座横向的摩天大楼。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

照片由Martin提供,Unsplash

回想起来,用统计学的语言来重新审视我的经历是颇为有趣的。我的大脑从我看到的船只图片的数据样本中计算出了期望的渡轮形状。但我的样本对于整体而言毫无代表性,这使得样本均值同样不代表总体均值。我是在用一个严重偏颇的样本均值来解码现实。

晕船

这次跨越英吉利海峡的旅行也是我第一次晕船。人们说,当你晕船时,应该走到甲板上,呼吸新鲜凉爽的海风,盯着地平线。对我来说唯一有效的办法是坐下来,闭上眼睛,喝着我最喜欢的汽水,直到我的思绪慢慢从搅动我胃部的痛苦恶心中脱离。顺便说一下,我并没有慢慢脱离本文的主题。我会很快进入统计学的内容。在此期间,让我解释一下你为什么会在船上生病的原因,以便你能看到与当前主题的联系。

在你大多数的生活日子里,你不会在船上摇晃。在陆地上,当你将身体倾斜到一侧时,你的内耳和身体的每一块肌肉都会告诉你的大脑你在倾斜。是的,你的肌肉也在和你的大脑沟通!你的眼睛热切地支持所有这些反馈,你也就安然无恙。然而在船上,眼睛和耳朵之间这段和谐的契约被打破了。

在船上,当海洋使船倾斜、摇晃、摆动、滚动、漂移、起伏或其他任何情况时,你的眼睛告诉大脑的东西可能会与肌肉和内耳告诉大脑的东西大相径庭。你的内耳可能会说:“小心!你在向左倾斜。你应该调整你对世界的预期。”但你的眼睛则说:“胡说!我坐着的桌子在我看来是完全水平的,桌子上的食物盘也是如此。墙上的那幅呐喊的画看起来也很直很水平。不要听内耳的。”

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

呐喊(公共领域)

你的眼睛可能向大脑报告更令人困惑的事情,比如“是的,你确实在倾斜。但这个倾斜的程度或速度并不像你那过于热心的内耳可能让你相信的那样严重。”

就好像你的眼睛和内耳各自要求你的大脑创建两个不同的预期,关于你的世界将如何改变。你的大脑显然做不到这一点。它感到困惑。而由于进化的原因,你的胃表达出强烈的欲望想要排空其内容。

让我们尝试通过统计推理的框架来解释这个令人痛苦的情况。这次,我们将使用一点数学来帮助解释。

你应该预期会晕船吗?深入统计学研究晕船问题

让我们定义一个 随机变量 X,它有两个取值:0 和 1。如果你眼睛的信号与内耳的信号一致,则X为 0。如果信号一致,则X为 1:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

随机变量 X (作者提供的图片)

理论上,X 的每个值都应该具有一定的概率 P(X=x)。概率 P(X=0) 和 P(X=1) 共同构成了 X概率质量函数。我们如下表述:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

X 的概率质量函数(图片来源于作者)

在绝大多数情况下,你的眼睛所接收到的信号会与内耳的信号一致。因此,p 几乎等于 1,(1 — p) 是一个非常非常小的数字。

让我们对 (1 — p) 的值做一个大胆的猜测。我们将使用以下推理来得出一个估计值:根据联合国的数据,2023 年出生时人类的平均预期寿命大约是 73 年。换算成秒,这大约是 2302128000 秒(约 23 亿)。假设一个普通人在其一生中经历 16 小时的晕船,即 28800 秒。现在,让我们不要对这 16 小时斤斤计较。这只是一个大胆的猜测,记住吗?所以,28800 秒给出了 (1 — p) 的一个工作估计值为 28000/2302128000 = 0.0000121626,而 p=(1 —0.0000121626) = 0.9999878374。因此,在一个普通人的一生中的任何一秒中,他们经历晕船的无条件概率仅为 0.0000121626。

根据这些概率,我们将进行一个持续 10 亿秒的模拟,模拟一个名叫 John Doe 的人的一生。这大约是 JD 模拟寿命的 50%。JD 更喜欢大部分时间待在坚实的地面上。他偶尔会进行海上巡游,并经常感到晕船。我们将模拟 JD 是否会在模拟的每一秒中经历晕船。为此,我们将进行 10 亿次伯努利随机变量的试验,其概率为 p 和 (1 — p)。每次试验的结果将是 1(如果 JD 晕船),或者 0(如果 JD 不晕船)。进行实验后,我们将得到 10 亿个结果。你也可以使用以下 Python 代码运行这个模拟:

import numpy as np

p = 0.9999878374
num_trials = 1000000000

outcomes = np.random.choice([0, 1], size=num_trials, p=[1 - p, p])

让我们计算结果为 1(即未晕船)和 0(即晕船)的次数:

num_outcomes_in_which_not_seasick = sum(outcomes)
num_outcomes_in_which_seasick = num_trials - num_outcomes_in_which_not_seasick

我们将打印这些计数。当我打印它们时,我得到以下值。你每次运行模拟时可能会得到稍有不同的结果:

num_outcomes_in_which_not_seasick= 999987794
num_outcomes_in_which_seasick= 12206

现在我们可以计算 JD 是否预计在这些 10 亿秒中的任何一秒中感到晕船。

期望值是两个可能结果的加权平均:即 1 和 0,加权是两个结果的频率。因此,让我们进行这个计算:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

结果的期望值(图片来源于作者)

期望结果是 0.999987794,这实际上接近 1.0。数学告诉我们,在 JD 模拟的 10 亿秒中的任何随机选择的一秒钟内,JD 应该不会预期感到晕船。数据似乎几乎不允许这种情况发生。

现在让我们对上述公式稍作调整。我们将按如下方式重新排列它:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

结果的期望值(图像由作者提供)

当以这种方式重新排列时,我们看到一个令人愉快的子结构逐渐显现出来。两个括号中的比率表示与两个结果相关的概率,具体来说是从我们 10 亿强数据样本中得出的样本概率,而不是总体概率。它们是样本概率,因为我们使用了来自我们 10 亿强数据样本的数据进行计算。话虽如此,0.999987794 和 0.000012206 这两个值应与总体值 p 和 (1 — p) 非常接近。

通过插入概率,我们可以将期望公式重新表述如下:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

X 的期望值(图像由作者提供)

注意到我们使用了期望的符号,即 E()。由于X 是一个 Bernoulli§ 随机变量,上述公式还告诉我们如何计算Bernoulli 随机变量的期望值X ~ Bernoulli§ 的期望值就是 p。

关于样本均值、总体均值,以及一个让你听起来很酷的词

E(X) 也被称为总体均值,用 μ 表示,因为它使用了概率 p 和 (1 — p),这些是总体层面的概率值。这些是你如果能访问到全部数据总体时将观察到的‘真实’概率,但实际上几乎不可能做到。统计学家在提到这些和类似的测量时使用了“渐近”这个词。它们被称为渐近的,因为它们的意义仅在样本大小趋近于无穷大或整个总体的大小时才具有显著性。现在问题来了:我认为人们就是喜欢说‘渐近’。我也认为这是掩盖一个麻烦的真相,即你永远无法精确测量任何事物的值。

从积极的一面来看,无法接触到总体是统计科学领域的‘伟大平衡者’。无论你是新近毕业的学生还是诺贝尔经济学奖得主,这扇通往‘总体’的门对你始终紧闭。作为统计学家,你只能使用样本来工作,你必须默默忍受其缺陷。但情况实际上并没有听起来那么糟糕。想象一下,如果你能知道所有事物的确切值会发生什么。如果你可以接触到总体。如果你能够精确地计算均值、中位数和方差。如果你能够以精确的预测未来。那时就不再需要估计任何东西。统计学的许多分支将会消失。世界将需要减少成千上万的统计学家,更不用说数据科学家了。想象一下对失业、世界经济和世界和平的影响……

但我岔开话题了。我的观点是,如果X是伯努利分布(p),那么要计算 E(X),你不能使用实际的 p 和(1 — p)值。相反,你必须使用 p 和(1 — p)的估计值。这些估计值,你将使用一个适中规模的数据样本来计算,而不是整个总体——没有机会做到这一点。因此,我很遗憾地告诉你,你所能做的最好的是得到随机变量X期望值的估计。按照惯例,我们将 p 的估计值表示为 p_hat(带小帽的 p),将估计的期望值表示为 E_cap(X)。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

X 的估计期望(图源作者)

由于 E_cap(X)使用样本概率,因此称为样本均值。它用 x̄或‘x bar’表示,就是在 x 上方加一条横杠。

总体均值样本均值是统计学中的蝙蝠侠和罗宾。

统计学的大部分内容都致力于计算样本均值,并将样本均值作为总体均值的估计值。

这就是它——用一句话概括了统计学的广阔领域。😉

深入期望的深渊

我们对伯努利随机变量的思维实验在某种程度上揭示了期望的本质。伯努利变量二元变量,它的操作非常简单。然而,我们经常使用的随机变量可能取多种不同的值。幸运的是,我们可以轻松地将期望的概念和公式扩展到多值随机变量。让我们通过另一个例子来说明。

多值离散随机变量的期望值

下表显示了 205 辆汽车的数据子集。具体来说,表中展示了每辆车引擎中的气缸数量。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

汽车的气缸数(数据来源:UCI 机器学习数据集库,许可证:CC BY 4.0)(图片来自作者)

Y为一个随机变量,包含了从该数据集中随机选择的车辆的气缸数。我们知道数据集中包含气缸数为 2、3、4、5、6、8 或 12 的车辆。因此,Y的范围是集合 E=[2, 3, 4, 5, 6, 8, 12]。

我们将数据行按气缸数分组。下表显示了分组计数。最后一列表示每个计数的样本出现概率。该概率是通过将组大小除以 205 计算得出的:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

气缸数的频率分布

使用样本概率,我们可以构建概率质量函数 P(Y) 来表示Y。如果我们将其与Y进行绘图,效果如下:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

Y的 PMF(图片来自作者)

如果一辆随机选择的车辆在你面前驶过,你会期望它的气缸数是多少?仅通过查看 PMF,你会想猜测 4 个气缸。然而,这个猜测背后有严谨的数学支持。类似于伯努利X,你可以按如下方式计算Y的期望值:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

Y的期望值(图片来自作者)

如果你计算这个和,它的结果是 4.38049,这与你猜测的 4 个气缸非常接近。

由于Y的范围是集合E=[2,3,4,5,6,8,12],我们可以将此和式表示为对 E 的求和,如下所示:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

离散随机变量Y的期望值公式(图片来自作者)

你可以使用上述公式来计算任何离散随机变量的期望值,其范围是集合E

连续随机变量的期望值

如果你处理的是连续随机变量,情况会有所不同,如下所述。

让我们回到我们的车辆数据集。具体来说,让我们查看车辆的长度:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

汽车长度(数据来源:UCI 机器学习数据集库,许可证:CC BY 4.0)(图片来自作者)

假设Z表示随机选取的车辆的长度(单位:英寸)。Z的范围不再是离散的值集合,而是实数集合的一个子集。由于长度总是正数,因此它是所有正实数的集合,记作>0。

由于所有正实数的集合具有(不可数)无限多个值,因此将概率分配给Z的某个特定值是没有意义的。如果你不相信我,可以考虑一个简单的思想实验:想象一下给Z的每一个可能值分配一个正概率。你会发现这些概率的和会趋向于无穷大,这显然是不合理的。因此,概率 P(Z=z)根本不存在。相反,你必须使用概率密度函数 f(Z=z),它为不同的Z值分配一个概率密度

我们之前讨论了如何使用概率质量函数计算离散随机变量的期望值。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

离散随机变量Y期望值的公式(图片由作者提供)

我们可以将这个公式用于连续随机变量吗?答案是可以的。要了解如何进行,想象一下你拿着一台电子显微镜。

拿起那台显微镜,聚焦在Z的范围内,即所有正实数的集合(>0)。现在,放大到一个极其微小的区间(z, z+δz]。在这个微观尺度下,你可能会观察到,从实际角度来看(现在,不是一个有用的术语),概率密度 f(Z=z) 在 δz 上是常量。因此,f(Z=z)和 δz 的乘积可以近似为随机选择的车辆长度落在开闭区间(z, z+δz]内的概率

拥有这个近似概率后,你可以将Z的期望值近似如下:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

Z是连续的时,E(Z)的近似评估(图片由作者提供)

注意我们是如何从 E(Y)的公式跳跃到这个近似公式的。要从 E(Y)得到 E(Z),我们做了以下工作:

  • 我们将离散的 y_i 替换为实值的 z_i。

  • 我们将Y的 PMF P(Y=y)替换为Z的近似概率 f(Z=z)δz,它表示在微观区间(z, z+δz]中找到 z 的概率。

  • 我们不再对E的离散有限范围Y求和,而是对>0 的连续无限范围Z求和。

  • 最后,我们将等号替换为近似符号。罪行就在这里。我们作弊了。我们偷偷使用了概率 f(Z=z)δz,它是对精确概率 P(Z=z) 的近似。我们作弊的原因是,精确概率 P(Z=z) 对于连续Z而言是不存在的。我们必须为这个罪过做出补偿,这正是我们接下来要做的。

我们现在进行我们的绝技,我们的拿手好戏,并在这样做中救赎自己。

由于 >0 是正实数的集合,在 >0 中有无限多个大小为 δz 的微小区间。因此,对 >0 的求和是对无限多个项的求和。这一事实为我们提供了一个绝佳的机会,可以用精确积分来替代近似求和,如下所示:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

Z 的期望值(作者提供的图片)

一般来说,如果 Z 的范围是实数区间 [a, b],我们将定积分的上下限设置为 a 和 b,而不是 0 和 ∞。

如果你知道 Z 的概率密度函数(PDF),并且在 [a, b] 区间内 z 乘以 f(Z=z) 的积分存在,你将解出上述积分,从而得到 E(Z)。

如果 Z 在区间 [a, b] 上均匀分布,其 PDF 如下:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

PDF of Z ~ Uniform(a, b)(作者提供的图片)

如果你设置 a=1 和 b=5,

f(Z=z) = 1/(5–1) = 0.25。

概率密度在 Z=1 到 Z=5 的区间内为常数 0.25,而其他地方为零。Z 的 PDF 如下所示:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

PDF of Z ~ Uniform(1, 5)(作者提供的图片)

基本上,它是从 (1,0.25) 到 (5,0.25) 的一条连续平坦的水平线,其他地方的值为零。

一般来说,如果 Z 的概率密度在区间 [a, b] 上是均匀分布的,Z 的 PDF 在 [a, b] 上是 1/(b-a),其他地方为零。你可以使用以下过程计算 E(Z):

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

计算均匀分布在区间 [a, b] 上的连续随机变量的期望值的过程(作者提供的图片)

如果 a=1 且 b=5,Z ~ Uniform(1, 5) 的均值就是 (1+5)/2 = 3。这与我们的直觉一致。如果在 1 和 5 之间的每一个无限多的值都是同样可能的,我们会期望均值等于 1 和 5 的简单平均值。

现在我不想打击你的积极性,但实际上,你更可能在前院看到双彩虹,而不是遇到需要用积分法来计算其期望值的连续随机变量。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

双彩虹 (CC BY-SA 2.0)

你会发现,看似精美的概率密度函数(PDF)通常会被嵌入到大学教科书的章节末尾练习中。它们就像家猫一样,不“出门”。但作为一个实践中的统计学家,“外面”就是你生活的地方。在外面,你会发现自己面对着连续值的数据样本,比如车辆的长度。为了对这些真实世界的随机变量进行建模,你很可能会使用一些著名的连续函数,例如正态分布、对数正态分布、卡方分布、指数分布、威布尔分布等等,或者混合分布,即最适合你数据的模型。

这里有几个这样的分布:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

正态分布和卡方分布的连续随机变量的 PDF 和期望值(图片由作者提供)

对于许多常用的 PDF,已经有人花费心力通过积分(x 乘以 f(x))来推导分布的均值,就像我们对均匀分布所做的那样。这里有几个这样的分布:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

指数分布和伽马分布的连续随机变量的 PDF 和期望值

最后,在一些情况下,实际上是许多情况下,现实生活中的数据集表现出过于复杂的模式,无法用任何一个分布来建模。这就像你感染了一种病毒,带来了一堆症状。为了帮助你克服这些症状,你的医生会给你开一系列药物,每种药物的强度、剂量和作用机制都不同。当你面对的数据展现出许多复杂的模式时,你必须动用一小支概率分布的“军队”来进行建模。这种不同分布的组合被称为混合分布。一种常用的混合分布是强大的高斯混合模型,它是多个正态分布随机变量的几个概率密度函数的加权和,每个随机变量具有不同的均值和方差组合。

给定一个真实值数据的样本,你可能会发现自己在做一些非常简单的事情:你将计算连续值数据列的平均值,并将其称为样本均值。例如,如果你计算汽车数据集中汽车的平均长度,它会是 174.04927 英寸,仅此而已。就这么完成了。但是,这还不是全部,你还有一个问题需要回答。

你的样本均值有多准确?感受它的准确性

你如何知道样本均值对总体均值的估计有多准确?在收集数据时,你可能运气不好,或者懒惰,或者‘数据受限’(这通常是懒惰的一个极好的委婉说法)。无论哪种方式,你都在面对一个非随机的样本。它没有按比例代表总体的不同特征。以汽车数据集为例:你可能收集了大量中型车的数据,而大型车的数据则太少。而且伸缩豪华轿车可能完全没有出现在你的样本中。结果,你计算的平均长度将严重偏向于总体中仅中型车的平均长度。无论你是否喜欢,你现在都在以几乎每个人都开中型车为信念进行工作。

对自己要诚实

如果你收集了一个严重偏倚的样本,而你不知道或者不在意,那么愿上天保佑你在你选择的职业道路上。然而,如果你愿意考虑偏倚的可能性,并且你有一些线索关于你可能遗漏了哪些数据(例如跑车),那么统计学将通过强有力的机制来帮助你估计这种偏倚

不幸的是,无论你多么努力,你永远也无法收集到一个完美平衡的样本。它总是包含偏倚,因为总体中各种元素的确切比例对你来说永远无法访问。记住那扇通向总体的门吗?记得门上的标志上总是写着‘CLOSED’吗?

你最有效的行动方案是收集一个大致包含总体中所有事物的相同比例的样本——即所谓的良好平衡样本。这个良好平衡样本的均值是你可以出发的最佳样本均值。

但是自然法则并不总是让统计学家的风帆黯然失色。自然界有一个宏伟的特性,体现在一个被称为中心极限定理(CLT)的定理中。你可以使用 CLT 来确定你的样本均值估计总体均值的效果如何。

CLT 并不是应对严重偏倚样本的灵丹妙药。如果你的样本主要由中型车组成,你实际上已经重新定义了你对总体的概念。如果你是有意只研究中型车,那么你就免责了。在这种情况下,尽管使用 CLT。它将帮助你估计你的样本均值与中型车的总体均值有多接近。

另一方面,如果你的存在目的在于研究所有生产的车辆,但你的样本主要是中型车,那么你遇到了问题。对于统计学的学生,让我用稍微不同的词重述一下。如果你的大学论文是关于宠物打哈欠的频率,但你的样本是 20 只猫和你邻居的贵宾犬,那么无论中心极限定理是否适用,再多的统计技巧也无法帮助你评估样本均值的准确性。

中心极限定理的要点

对于中心极限定理(CLT)的全面理解是另一个话题,但其要点如下:

如果你从总体中随机抽取数据点,并计算样本均值,然后重复这个过程多次,你将得到……许多不同的样本均值。嗯,显然如此!但接下来会发生一些令人惊讶的事情。如果你绘制所有这些样本均值的频率分布,你会发现它们 总是 正态分布的。更重要的是,这种正态分布的均值总是你正在研究的总体均值。这种我们宇宙个性中诡异而迷人的方面正是中心极限定理用(还有什么?)数学语言描述的。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

标记为 174.04927 英寸的样本均值长度在一个假设总体均值为 180 英寸的正态分布 Z 上(图由作者提供)

让我们来看看如何使用中心极限定理。我们将按以下步骤进行:

使用仅从一个样本得出的样本均值 Z_bar,我们将说总体均值 μ 落在区间 [μ_low, μ_high] 内的概率是 (1 — α):

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

总体均值的下限和上限置信区间(图由作者提供)

你可以将 α 设置为 0 到 1 之间的任何值。例如,如果你将 α 设置为 0.05,你将得到 (1 — α) 为 0.95,即 95%。

要使这种概率 (1 — α) 成立,应该按如下方式计算下界 μ_low 和上界 μ_high:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

总体均值的下限和上限(图由作者提供)

在上述公式中,我们知道 Z_bar、α、μ_low 和 μ_high。其余的符号需要一些解释。

变量 s 是数据 样本 的标准差。

N 是样本量。

现在我们来讨论 z_α/2。

z_α/2 是你在标准正态分布的概率密度函数(PDF)的 X 轴上读出的值。标准正态分布是一个均值为零,标准差为一的正态分布连续随机变量的 PDF。z_α/2 是该分布 X 轴上使得 PDF 曲线左侧面积为 (1 — α/2) 的值。这一区域当你设置 α 为 0.05 时的样子如下:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

在 X 轴上某个值 X 左侧的 PDF 区域。在这种情况下,x=1.96(图由作者提供)

蓝色区域计算为 (1 — 0.05/2) = 0.975。请记住,任何 PDF 曲线下的总面积始终为 1.0。

总结一下,一旦你从一个样本中计算出均值 (Z_bar),你可以围绕这个均值建立界限,使得总体均值落在这些界限内的概率是你选择的值。

让我们重新检查估计这些界限的公式:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

总体均值的下界和上界(图像来源于作者)

这些公式给我们一些关于样本均值性质的见解:

  1. 随着样本方差 s 的增加,下界 (μ_low) 的值会降低,而上界 (μ_high) 的值会增加。这会有效地使 μ_low 和 μ_high 彼此远离,并远离样本均值。相反,随着样本方差的减少,μ_low 从下方更接近Z_bar,μ_high 从上方更接近Z_bar。区间界限本质上从两侧趋向于样本均值。实际上,区间 [μ_low, μ_high] 与样本方差成正比。如果样本在均值周围分布得很广泛(或紧密),则较大的(或较小的)分散会降低(或增加)样本均值作为总体均值估计的可靠性。

  2. 注意,区间的宽度与样本大小 (N) 成反比。在两个方差相似的样本之间,较大的样本会产生围绕其均值的更紧密区间,而较小的样本则不会。

让我们看看如何计算汽车数据集的这个区间。我们将计算 [μ_low, μ_high],以便有 95% 的概率使总体均值 μ 落在这些范围内。

为了获得 95% 的概率,我们应该将 α 设置为 0.05,这样 (1 — α) = 0.95。

我们知道 Z_bar 为 174.04927 英寸。

N 为 205 辆车辆。

样本标准差可以很容易地计算出来。它为 12.33729 英寸。

接下来,我们将处理 z_α/2。由于 α 为 0.05,α/2 为 0.025。我们需要找到 z_α/2 的值,即 z_0.025。这是在标准正态随机变量的 PDF 曲线的 X 轴上的值,其中曲线下的区域是 (1 — α/2) = (1 — 0.025) = 0.975。通过查阅 标准正态分布表,我们发现这个值对应于X=1.96 的左侧区域。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

包含标准正态分布 CDF 值的表格。包含不同 X 值的 P(X ≤ x)(来源:维基百科

插入这些值,我们得到以下界限:

μ_low = Z_bar — ( z_α/2 · s/√N) = 174.04927 — (1.96 · 12.33729/205) = 173.93131

μ_high = Z_bar + ( z_α/2 · s/√N) = 174.04927 + (1.96 · 12.33729/205) = 174.16723

因此,[μ_low, μ_high] = [173.93131 英寸, 174.16723 英寸]

有 95%的概率人口均值位于这个区间内。看看这个区间有多紧凑。它的宽度仅为 0.23592 英寸。在这个狭小的间隙中包含了样本均值 174.04927 英寸。尽管样本中可能存在各种偏差,我们的分析表明,样本均值 174.04927 英寸是对未知人口均值的极其良好的估计。

超越第一维度:多维样本空间中的期望

目前,我们关于期望的讨论仅限于一维,但它不一定非得如此。我们可以轻松地将期望的概念扩展到二维、三维或更高维度。要计算多维空间中的期望,我们只需要一个定义在 N 维空间上的联合概率质量函数(或密度函数)。联合 PMF 或 PDF 以多个随机变量作为参数,返回这些值同时出现的概率。

在文章前面,我们定义了一个随机变量Y,表示从汽车数据集中随机选择的车辆中的气缸数量。Y是你的典型一维离散随机变量,它的期望值由以下公式给出:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

单维离散随机变量的期望值(图像由作者提供)

让我们引入一个新的离散随机变量XXY联合概率质量函数用 P(X=x_i, Y=y_j)表示,或简写为 P(X, Y)。这个联合 PMF 将我们从Y所处的舒适一维空间中带出,带入到一个更有趣的二维空间。在这个二维空间中,一个数据点或结果由元组(x_i, y_i)表示。如果X的范围包含‘p’个结果,Y的范围包含‘q’个结果,则二维空间将具有(p x q)个联合结果。我们用元组(x_i, y_i)来表示这些联合结果中的每一个。要计算这个二维空间中的 E(Y),我们必须将 E(Y )的公式适应如下:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

离散随机变量Y在二维空间中的期望值(图像由作者提供)

请注意,我们正在对二维空间中所有可能的元组(x_i, y_i)进行求和。让我们将这个求和拆解成嵌套求和如下:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

离散随机变量Y在二维空间中的期望值(图像由作者提供)

在嵌套求和中,内层求和计算 y_j 和 P(X=x_i, Y=y_j)在所有 y_j 值上的乘积。然后,外层求和对每个 x_i 值重复内层求和。之后,它将所有这些单独的和收集起来并加总,以计算 E(Y )。

我们可以通过将求和嵌套在彼此之间,将上述公式扩展到任意数量的维度。你需要的只是一个在 N 维空间上定义的联合 PMF。例如,以下是如何将公式扩展到 4 维空间:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

离散随机变量Y在 4 维空间上的期望值(图片由作者提供)

注意我们总是将Y的求和放在最深层次。你可以按任何顺序安排其余的求和——你将得到相同的 E(Y)结果。

你可能会问,为什么要定义一个联合 PMF 并为所有这些嵌套求和而发狂?在 N 维空间上计算的 E(Y)是什么意思?

理解多维空间中期望值含义的最佳方式是用实际的多维数据来说明其使用。

我们将使用的数据来自一艘特定的船,它与我渡过英吉利海峡的船不同,不幸的是没能到达另一边。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

RMS 泰坦尼克号于 1912 年 4 月 10 日从南安普顿出发(公有领域)

下图展示了 887 名乘客在 RMS 泰坦尼克号上的数据集中的一些行:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

泰坦尼克号数据集 (CC0)

Pclass列表示乘客的舱位级别,整数值为 1、2 或 3。Siblings/Spouses AboardParents/Children Aboard变量是二元(0/1)变量,表示乘客是否有兄弟姐妹、配偶、父母或子女在船上。在统计学中,我们常常有些残酷地称这些二元指示变量虚拟变量。它们并没有什么愚蠢的地方以至于配得上这样的贬义称呼。

从表中可以看出,有 8 个变量共同标识数据集中的每个乘客。每一个这 8 个变量都是一个随机变量。我们面临的任务有三方面:

  1. 我们希望在这些随机变量的一个子集上定义一个联合概率质量函数,并且,

  2. 使用这个联合 PMF,我们希望说明如何在这个多维 PMF 上计算这些变量的期望值,并且,

  3. 我们希望理解如何解读这个期望值。

为了简化问题,我们将Age变量分成 5 年为一个区间,并将这些区间标记为 5、10、15、20、…、80。例如,20 岁区间意味着乘客的实际年龄在(15,20]年区间内。我们将这个分箱后的随机变量称为Age_Range

一旦Age被分箱,我们将按PclassAge_Range对数据进行分组。以下是分组计数:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

按乘客的舱位和(分组的)年龄的频率分布(作者提供的图像)

上表包含了每个群体(组)的泰坦尼克号乘客数量,这些群体是由PclassAge_Range的特征定义的。顺便提一下,群体也是统计学家非常崇拜的另一个词(以及渐进)。这里有个小提示:每当你想说“组”时,直接说“群体”。我向你保证,无论你本来打算说什么,瞬间都会显得十倍重要。例如:“八个不同的群体的酒精爱好者(请原谅,酒类鉴赏家)喝了假酒,他们的反应被记录下来。”明白了吗?

说实话,“群体”确实有“组”没有的精准含义。尽管如此,偶尔说一遍“群体”,观察听众脸上的尊敬感是有启发性的。

无论如何,我们将向频率表中添加另一列。这一新列将保存观察到特定PclassAge_Range组合的概率。这个概率 P(Pclass, Age_Range)是频率(即Name列中的数量)与数据集中乘客总数(即 887)的比率。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

按乘客的舱位和(分组的)年龄的频率分布(作者提供的图像)

概率 P(Pclass, Age_Range)是随机变量PclassAge_Range联合概率质量函数。它给出了观察到被特定PclassAge_Range组合描述的乘客的概率。例如,查看Pclass为 3 且Age_Range为 25 的行。相应的联合概率为 0.116122。这个数字告诉我们,大约 12%的泰坦尼克号 3 等舱乘客年龄在 20 到 25 岁之间。

与一维 PMF 类似,当对其所有组成随机变量的值组合进行评估时,联合 PMF 也会加和为完美的 1.0。如果你的联合 PMF 没有加和为 1.0,你应该仔细检查一下定义。可能在公式中存在错误,或者更糟糕的是实验设计中有问题。

在上面的数据集中,联合 PMF 确实加和为 1.0。相信我吧!

要直观感受联合 PMF P(Pclass, Age_Range),你可以在 3 维中绘制它。在 3-D 图中,将 X 和 Y 轴分别设置为PclassAge_Range,将 Z 轴设置为概率 P(Pclass, Age_Range)。你会看到一个引人入胜的 3-D 图表。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

PclassAge_Range的联合 PMF 的 3-D 图(作者提供的图像)

如果你仔细观察,你会注意到联合 PMF 包含三个平行的图,分别对应于泰坦尼克号上的每一个船舱等级。三维图展示了不幸海轮上的一些人口统计数据。例如,在所有三个舱等级中,15 到 40 岁的乘客占据了大部分。

现在让我们计算在这个二维空间上的 E(Age_Range)。E(Age_Range)由下式给出:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

Age_Range的期望值(图片来源:作者)

我们对所有Age_Range的值进行内部求和:5,10,15,…,80。我们对所有Pclass的值进行外部求和:[1, 2, 3]。对于每一个(Pclass, Age_Range)的组合,我们从表中选择联合概率。Age_Range的期望值是 31.48252537 岁,这对应于 35 的分箱值。我们可以预期,泰坦尼克号上的“平均”乘客年龄在 30 到 35 岁之间。

如果你取泰坦尼克号数据集中Age_Range列的平均值,你将得到完全相同的数值:31.48252537 岁。那么为什么不直接取Age_Range列的平均值来得到 E(Age_Range)? 为什么要构建一个嵌套求和的鲁布·戈德堡机器来计算同样的值呢?

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

鲁布·戈德堡的“自动操作餐巾纸”机器(公共领域)

这是因为在某些情况下,你所拥有的只是联合 PMF 和随机变量的范围。在这种情况下,如果你只有 P(Pclass, Age_Range)且知道Pclass的范围是[1,2,3],以及 Age_Range 的范围是[5,10,15,20,…,80],你仍然可以使用嵌套求和技术来计算 E(Pclass) E(Age_Range)。

如果随机变量是连续的,那么可以使用多重积分在多维空间中找到期望值。例如,如果XYZ是连续随机变量,并且 f(X,Y,Z)是定义在三维连续空间(x, y, z)的联合概率密度函数,则Y在该三维空间中的期望值如图所示:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

连续随机变量Y在连续三维空间中的期望值(图片来源:作者)

就像在离散情况下,你首先对你想计算其期望值的变量进行积分,然后对其他变量进行积分。

一个著名的例子展示了用于计算期望值的多重积分方法,其规模小到人眼无法感知。我指的是量子力学中的波函数。波函数在笛卡尔坐标中表示为Ψ(x, y, z, t),在极坐标中表示为Ψ(r, θ, ɸ, t)。它用于描述那些喜欢待在极其狭小空间里的微小物体的性质,例如原子中的电子。波函数Ψ返回一个形式为 A + jB 的复数,其中 A 代表实部,B 代表虚部。我们可以将Ψ的绝对值平方解释为定义在四维空间(x, y, z, t)或(r, θ, ɸ, t)上的联合概率密度函数。特别是对于氢原子中的电子,我们可以将|Ψ|²解释为在时间 t 时电子在(x, y, z)或(r, θ, ɸ)周围一个极其微小的空间体积中的大致概率。通过知道|Ψ|²,我们可以在 x, y, z 和 t 上进行四重积分,以计算电子在时间 t 沿 X、Y 或 Z 轴(或其极坐标等效轴)的期望位置

结束语

我以自己对晕船的经历开始这篇文章。如果你对用伯努利随机变量来建模这一非常复杂且尚未完全理解的人类困境感到不满,我也不会怪你。我的目的是说明期望如何从生物学层面实际影响我们。一种解释这一困境的方法是使用随机变量的酷炫且舒缓的语言。

从看似简单的伯努利变量开始,我们将我们的插图画笔从统计画布扫到量子波函数的宏伟多维复杂性。在整个过程中,我们力求理解期望如何在离散和连续尺度、单维和多维、以及微观尺度上运作。

还有一个领域,期望发挥了巨大的影响。这个领域是条件概率,其中计算随机变量X取值‘x’的概率,假设某些其他随机变量ABC等已经取值‘a’、‘b’、‘c’。XABC的条件下的概率表示为 P(X=x|A=a,B=b,C=c),或简写为 P(X|ABC)。在我们见过的所有期望公式中,如果将概率(或概率密度)替换为同一条件版本,得到的就是条件期望的相应公式。它表示为 E(X=x|A=a,B=b,C=c),它位于回归分析和估计的广泛领域的核心。这是未来文章的素材!

引用和版权

数据集

汽车数据集下载自加州大学欧文分校机器学习库,根据知识共享署名 4.0 国际(CC BY 4.0)许可协议使用。

泰坦尼克号数据集下载自Kaggle,根据CC0 许可使用。

图片

本文中的所有图片版权归Sachin Date所有,采用CC-BY-NC-SA许可,除非图片下方另有说明。

如果你喜欢这篇文章,请关注我,访问 Sachin Date 以获取关于回归、时间序列分析和预测主题的提示、操作指南和编程建议。

归纳偏差的一个童话故事

原文:towardsdatascience.com/a-fairy-tale-of-the-inductive-bias-d418fc61726c

|归纳偏差| 变换器| 计算机视觉|

我们需要归纳偏差吗?简单模型如何达到复杂模型的性能

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Salvatore Raieli

·发布于 Towards Data Science ·18 分钟阅读·2023 年 7 月 10 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由 Natalia Y. 提供,Unsplash

正如我们近年来所见,深度学习在使用量和模型数量上都经历了指数级增长。这一成功的铺路石或许就是 迁移学习 本身——即一个模型可以通过大量数据进行训练,然后用于各种具体任务的理念。

近年来,出现了一种范式:变换器(或基于此模型的其他变体)被用于 NLP 应用。而在图像领域,则使用 视觉变换器卷积网络

## LLMs 的无限巴别图书馆

开源、数据与注意力:LLMs 的未来将如何改变

towardsdatascience.com ## META 的 Hiera:减少复杂性以提高准确性

简单性使 AI 能够达到惊人的性能和令人惊讶的速度

towardsdatascience.com

另一方面,虽然我们有大量实践工作证明这些模型效果良好,但理论上的理解却滞后。这是因为这些模型非常广泛,实验起来很困难。视觉变换器的表现优于卷积神经网络,因为它们在视觉上具有理论上更少的归纳偏差,这表明存在一个需要填补的理论空白。

本文重点讨论:

  • 归纳偏差究竟是什么?为什么这很重要,我们最喜欢的模型有什么归纳偏差?

  • 变换器和 CNN 的归纳偏差。这两种模型之间有什么区别,为什么这些讨论很重要?

  • 我们如何研究归纳偏差?如何利用不同模型之间的相似性来捕捉它们的差异。

  • 具有弱归纳偏差的模型能否在计算机视觉领域取得成功?这是一个传统上被认为归纳偏差很重要的领域。

什么是归纳偏差?

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

照片由Raphael Schaller提供,拍摄于Unsplash

学习是通过观察和与世界互动来获取有用知识的过程。它涉及在解决方案空间中搜索,以找到一个能够更好地解释数据或获得更高奖励的解决方案。但在许多情况下,存在多个同样好的解决方案。 (source)

想象一下在湖中遇到一只天鹅。从这只简单的天鹅,我们可能会假设所有的天鹅都是白色的(直到我们看到一只黑天鹅),它们是水禽,它们以鱼为食,等等。

这个过程被称为归纳推理。从一个简单的观察中,我们可能能够推导出成千上万(甚至数十亿)个假设,显然,所有的假设并不都是真实的。实际上,我们可能会认为天鹅无法飞行,因为我们当时只观察到它在游泳。

显然,没有直接观察很难决定哪个假设是正确的。所以根据奥卡姆剃刀原则,我们可以说“天鹅可以在湖中游泳”。

为什么这对机器学习很重要?

数据集是观察结果的集合,我们想要创建一个可以从这些观察结果中泛化的模型。其理念是,从我们的数据集中,我们可以推断出一些对总体人群也适用的规则。换句话说,我们可以将我们的模型视为一组假设。

理论上,假设空间是无限的。实际上,如果我们考虑笛卡尔空间中的两个点,它可以通过一条直线但无限多的曲线。在没有更多点的情况下,我们无法知道哪种假设是最正确的。

通常,最简单的假设是最正确的。一个完美拟合点的曲线通常是过拟合

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片来自这里

归纳偏差可以定义为对某些假设的优先考虑(从而减少假设空间)。例如,当我们面对一个回归任务时,我们决定考虑线性模型,这时我们通过使假设仅限于线性模型来减少我们的假设空间。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

线性回归。图片来源:这里

归纳偏差使得学习算法可以在观察到的数据之外优先考虑一种解决方案(或解释)而非另一种,来源

一方面,我们有不同类型的数据,以及具有不同假设和不同归纳偏差的不同类型模型(即,假设空间的不同简化)。因此,人们可能会倾向于为所有类型的数据选择一个模型。

然而在 1997 年,无免费午餐定理结束了这种诱惑。没有一个模型可以适用于所有情况。实际上,没有一种最优的偏差可以使模型对所有任务进行泛化。换句话说,一个任务的最优假设可能对另一个任务并不最优。

这就是我们为什么对图像使用卷积神经网络,对文本序列使用 RNN(或 LSTM)等的原因。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

为了更好地理解,以下是一些归纳偏差的例子:

  • 决策树基于这样的假设:一个任务可以通过一系列的二元决策(二元拆分)来解决。

  • 正则化的假设是指向参数具有小值的解决方案。

  • 全连接层是一种全对全的偏差,其中层 i 的所有单元都与下一层 j 相连(一个层中的所有神经元都与下一层相连)。这意味着存在一种非常弱的关系偏差,因为任何单元都可以与其他单元进行交互。

  • 卷积神经网络基于局部性的理念,即特征通过局部像素提取,并以层次化模式组合。在另一种世界观中,我们假设相邻的像素实际上是相关的,这种关系应当被模型考虑(在卷积步骤中)。

  • 递归神经网络具有与序列性相关的偏差,因为每个词是按顺序处理的。由于权重在序列的所有元素中被重用(我们更新隐藏状态),因此还存在时间等变性(或递归性)。

  • 变换器 它没有强的归纳偏差,这应当提供更多的灵活性(但需要大量的训练数据)。实际上,在数据较少的情况下,该模型的表现往往不如其他模型。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片来源: 这里

CNN 和变换器的归纳偏差

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由Tudose Alexandru拍摄,发布于Unsplash

卷积神经网络长期以来主导了计算机视觉领域,直到视觉变换器的出现。正如我们前面提到的,CNN 基于相邻像素之间存在关系的原理。因此,在卷积过程中,几个像素共享相同的权重。

此外,池化层的使用旨在实现平移不变性。这意味着,无论模式出现在图像的何处(例如,图像的左角或右角),它都会被识别。

这些偏差对于处理自然图像数据非常有效,因为局部邻域内具有较高的协方差,而随着距离的增加,这种协方差会减小,并且统计特性在整张图像上大致是稳定的。(来源

这些偏差实际上受到了下颞皮层的启发,该区域似乎提供了对应的生物学基础,用于尺度、平移和旋转不变性。这些偏差被认为对 CNN 在面对图像平移、缩放或其他变形时的鲁棒性很重要,因此通过卷积和池化加以应用。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片来源:这里

另一方面,图像是复杂且信息丰富的对象。鉴于其用途,尝试更详细地理解CNN所看到的内容以及存在的其他偏差。

在 2017 年的一项研究中,作者展示了Inception 模型(一种 CNN)的“形状偏差”很强。换句话说,CNN 在识别对象时更依赖于对象的形状而非其他类型的模式。作者使用了一个图像三联体来分类一个对象,并使用了颜色相同但形状不同(颜色匹配)或形状相同但颜色不同(形状匹配)的图像来研究模式是否更注重形状或颜色。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片来源:这里

在一项后续研究中,一些作者则展示了,模型更关注的是纹理而非颜色。作者使用了 ResNet50 来测试这一假设。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片来源:这里

他们展示了在纹理与形状冲突的情况下,模型倾向于使用纹理。因此,对作者来说,CNN 具有强烈的“纹理偏差”。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片来源:这里

然而,作者总结道,具有形状偏差的模型更具鲁棒性:

值得注意的是,具有较高形状偏差的网络在对许多不同图像扭曲的鲁棒性上天生更强(对于一些甚至达到了或超过了人类表现,尽管从未接受过这些扭曲的训练),并在分类和物体识别任务中表现更好。 (这里)

实际上,对于图像来说,具有形状偏差是理想的。这可以通过使用适当的数据集或使用数据增强技术来实现,这些技术包括颜色失真、噪声和模糊(这些恰好减少了纹理偏差)。相反,随机裁剪会增加纹理偏差。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片来源:这里

具体来说,我们可以说偏差不仅依赖于卷积神经网络(CNN)的架构,还依赖于训练时使用的数据集。根据数据集的不同,CNN 会倾向于形状或纹理。

一项研究的作者 表示,这些偏差是互补的。模型可以专注于纹理或形状进行预测。然而,有时,仅这两种元素中的一种不足以进行正确预测(降低了性能)。作者表示,由于模型可以学习任一偏差,它还可以“自动找出如何避免对形状或纹理有偏见。”换句话说,使用具有冲突的(纹理和形状)示例可以指导模型避免偏见。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图像来源:这里

视觉变换器 来源于变换器,正如前面提到的,它是一个没有强偏差的模型。

一些研究表明,CNNs 和 ViTs 之间仍有几个相似之处。实际上,ViTs 也学习了层级视图,并且可以被可视化。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图像:来源

[## 视觉变换器所见的视觉之旅

一些最大的模型如何看待世界

pub.towardsai.net

然而,后来的研究表明,ViTs 实际上具有比 CNNs 更高的形状偏差。这实际上令人惊讶。此外,作者指出这种形状偏差在图像损坏的鲁棒性方面发挥了积极作用:

[…] 强调了形状偏差与均值损坏误差之间的一般性反向关系。模型对常见损坏的鲁棒性越高(即更小的 mCE),其形状偏差就越大。(来源)

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

视觉变换器的形状偏差。图像来源:这里

几个研究小组假设,通过添加适当的归纳偏差,可能使 ViTs 即使不使用数百万张图像进行训练也能超越 CNNs。另一方面,这一假设导致了大量模型的创建但使训练极其低效。

所以问题仍然存在:

参数和训练样本数量的扩展能在多大程度上弥补缺乏归纳偏差的问题?

如何研究归纳偏差?

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

照片由Aaron Burden拍摄,来自Unsplash

正如我们所见,仍然存在几个未解的问题。尽管有大量关于CNNsViTs的研究,但这些性能提升背后的许多理论背景仍然不清楚。

“MLP 是这些神经网络架构中最简单的一种,依赖于这种堆叠思想,因此提供了一个有效深度学习理论的最简模型。” (source)

通常,许多关于更理论方面的研究都是使用多层感知器(MLP)进行的。这是因为它是由简单的矩阵乘法组成的层,封装在一个非线性函数中。其简单性允许在较低的计算成本下进行许多实验。然后对更简单模型进行的研究被转化为更复杂和精细的模型。然而,MLP 在许多情况下性能较差,这留给我们的是,如何将观察到的内容转移到具有远超性能的模型中。

另一方面,MLP 还有一个优势,即其具有较弱的归纳偏差。这使得它成为 ViT 研究的一个良好候选模型。还有一个衍生模型,归纳偏差更小:MLP-Mixer

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片来源:这里

有趣的是,MLP-Mixer 既不使用卷积也不使用自注意力。相反,它依赖于多层感知器层,这些层应用于空间位置或特征通道。这一切都得益于矩阵乘法和非线性的巧妙使用。

简而言之,图像块被线性投影到嵌入空间(然后转换为可以被 MLP 利用的表格数据)。之后,我们有一系列混合层。输入数据进入并转置,然后我们有一个简单的全连接层。这个层识别在图像块中常见的特征(聚合通道)。然后结果被转置,并通过第二个全连接层来识别图像块本身的特征(与通道关联)。

此外,还有跳跃连接、作为非线性函数的GELU层归一化。另外,作者评论道:

我们的架构可以看作是一个独特的 CNN,它使用(1×1)卷积进行通道混合,使用单通道深度卷积进行令牌混合。然而,反之则不然,因为 CNN 不是 Mixer 的特例。(来源

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片来源:这里

另一个有趣的关系是,卷积可以看作是 MLP 的特例,其中 W 权重矩阵是稀疏的并具有共享的条目。权重的这种共享确实导致学习在空间上是局部化的(正如我们上文提到的卷积的空间偏差)。

考虑一个矩阵 W、一个 2x3x1 像素的图像和一个 2x2 的滤波器 f,这种关系变得非常清晰:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片来源:这里

这具有使模型具有平移不变性的优势,但如果图像中存在排列变换,则牺牲了MLPs的鲁棒性。

那么,关于视觉变换器(Vision Transformers)呢?

ViTs和卷积之间也有密切的关系(尽管它们具有相同的偏差)。实际上,正如上文所示,自注意力层以类似于卷积层的方式处理图像。在a 2020 paper中,作者展示了自注意力层如何表达任何卷积层。

所以正如我们所说,MLP、MLP-mixer、卷积网络和视觉变换器之间存在强烈的关系。虽然这些模型在归纳偏差和处理图像的方式上有很大的不同。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片来源:这里

总的来说,由于各种模型之间存在强烈的关系和对应性,但也有归纳偏差的差异,我们可以使用MLP作为一个简单的模型来理解是否通过缩放和训练集中示例的增加可以弥补缺乏归纳偏差的问题。

大卫与歌利亚

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由肖恩·罗伯逊Unsplash拍摄

在一篇近期的论文中,他们正是这样做的。他们采用了MLP,一个结构上简单的模型,试图理解缩放时发生了什么。缩放能改善简单全连接层的性能吗?

作者采用了一个 MLP 并构建了一个模型,他们堆叠了相同大小的 MLP 层。利用最近的文献,他们添加了层归一化和跳跃连接,以查看这些是否使训练更稳定。他们还创建了一个简单的架构,称为反向瓶颈,在这个架构中,通过两个权重矩阵,他们在同一块中扩展和收缩输入:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

反向块。图像来源:这里

一方面,确实这些添加增加了归纳偏差,但与现代复杂架构相比,这几乎可以忽略不计。之后,他们决定探索将MLP与其他模型在计算机视觉任务中比较时的情况(通常MLP的表现远远逊色)。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图像来源:这里

作者在一些流行的计算机视觉数据集上测试了这些架构,得到了有趣的结果:

  • MLP 标准直接进入过拟合状态。

  • 添加数据增强略微提高了性能。

  • 使用瓶颈增加了性能。使用反向瓶颈的数据增强对性能有显著更高的影响(约 20% 的性能提升)。

  • 尽管如此,ResNet18 的性能远远优于。

这些数据与文献一致,文献指出,在样本量较小的情况下(毕竟这些数据集较小),归纳偏差很重要。事实上,ViTs和 MLP mixers 也观察到了同样的现象。

近年来,大模型的优势在于它们可以在大量图像上进行训练,然后将知识转移到较小的数据集(迁移学习)。为此,作者使用了ImageNet21k(1200 万张图像和 11k 类别)。之后,他们在新任务上进行了微调

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图像来源:这里

结果令人惊讶,该模型能够将其对数据集的学习转移到另一个任务上。此外,结果远远优于以往所见。

尽管在大量数据上进行过预训练,我们仍然想强调的是,这样的 MLP 在所有数据集上与从头开始训练的 ResNet18 竞争,除了在 ImageNet1k 上表现意外不佳。 (source)

这证实了MLP是分析迁移学习、数据增强和其他理论元素的良好代理。这很令人惊讶,因为与现代模型相比,它是一个基础模型。

另一个令人惊讶的结果是,在训练中使用大的批量大小可以提高性能。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图像来源:这里

一般而言,观察到相反的效果。特别是在CNN的情况下,当用更多的示例进行训练时,试图保持小批量大小的性能。毕竟,使用小批量意味着在一个周期中进行更多的梯度更新(尽管训练时间更长)。另一方面,大批量大小更快,并且可以跨多个设备分配,从而节省时间。

此外,对变换器的观察表明,即使是这些大模型也从更大的批量大小中受益。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图像来源:这里

一般来说,近年来对规模定律的讨论很多:根据这一理论,随着参数的增加,性能有一个高度可预测的提高(并且遵循一个可量化的幂律)。这种规模定律在 LLMs 中得到了观察,尽管最近一些团队对其提出了质疑。

## 人工智能中的涌现能力:我们是否在追逐一个神话?

改变对大型语言模型涌现特性的看法

towardsdatascience.com

尽管关于规模定律的讨论仍在进行中,但分析这种情况是否也适用于像MLP这样的简单模型仍然很有趣(毕竟,MLP 通过增加参数数量应倾向于过拟合)。

在这项研究中,作者还定义了一个具有递增参数数量的模型家族。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图像来源:here

的确,MLP 似乎也展现了类似幂律的行为。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图像来源:here

这确实是一个有趣的结果,因为它表明,即使像 MLP 这样的简单模型也可以展示假定的幂律行为。

MLP 是一个并非为处理图像而设计的模型。事实上,作者指出,由于 MLP 的归纳偏差较差,它更依赖于示例的数量。因此,虽然可以通过大量示例来弥补弱的 归纳偏差,但这需要大量的示例。

一个非常有趣的点是,这些模型都在单个 GPU 上运行。对于 ImageNet21k 的最大架构,单个周期在单个 24 GB GPU 上花费了 450 秒。换句话说,这些实验可以在任何商业 GPU 上快速运行。

作者指出,MLPs 显然更高效,可以使用更大的批量:

正如很快会显而易见的那样,MLP 在对单个图像进行预测时需要显著更少的 FLOPs,本质上更有条理地利用其参数。因此,与其他候选架构相比,延迟和吞吐量显著更好。(来源)

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图像来源:here

结论

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

照片由 Philip Myrtorp 拍摄,发布在 Unsplash 上。

归纳偏差 是机器学习的基本概念之一。一般来说,这是我们根据数据类型选择一个模型而非另一个模型的主要原因之一。尽管已经有很多研究,但理论上仍存在空白。

令人着迷的是,考虑到假设先验领域的狭窄可能导致更好的结果。不过,这也付出了代价,包括理论水平和模型复杂性。如前所述,尝试向 ViTs 模型添加归纳偏差会导致创建越来越复杂且计算上低效的模型。

尽管 MLP 是一个极其简单的模型,但它在计算上具有高效的优势,这也是它被用于许多研究以填补理论空白的原因之一。主要问题之一是 MLP 在计算机视觉中的表现远远逊色于其他模型。

最近的结果显示,通过适当的调整,可以克服这一差距。此外,缺乏归纳偏差可以通过扩展来弥补。因此,MLP 可以作为研究现代架构及其在不同情况下表现的良好代理。

为什么这一切如此重要?

一般来说,近年来的 AI 研究集中在一个单一的范式上:更多的参数,更多的数据。为了提高准确率,出现了新的竞争。尽管如此,自 2017 年以来,变换器的架构一直没有改变。

这些庞大的模型具有相当高的训练成本。最近几个月,对替代方案的研究兴趣开始增长:既包括用更少的参数获得相同的结果,也包括寻找替代于变换器(及其平方计算成本)的方法。

## Welcome Back 80s: Transformers Could Be Blown Away by Convolution

Hyena 模型展示了卷积如何比自注意力更快

levelup.gitconnected.com META’s LLaMA: A small language model beating giants

META 开源模型将帮助我们理解语言模型的偏见是如何产生的

medium.com

在每种情况下,学术研究都被迫追赶以行业为主导的研究。很少有机构能够从头训练一个大型语言模型。然而,像这样的研究 表明,即使是简单的模型如 MLP 也可以大规模获得结果。这为更好地理解模型行为并开始思考变换器的替代方案提供了非常有趣的视角。

你怎么看?请在评论中告诉我。

如果你觉得这有趣:

你可以查看我的其他文章,也可以 订阅 以在我发布文章时获得通知,你还可以 成为 Medium 会员 来访问所有故事(平台的附属链接,我从中获得少量收入,您无需支付额外费用),也可以在 LinkedIn 上与我联系或找到我。

这是我 GitHub 仓库的链接,我计划在这里收集与机器学习、人工智能等相关的代码和资源。

[## GitHub - SalvatoreRa/tutorial: 机器学习、人工智能、数据科学的教程…

机器学习、人工智能和数据科学的教程,包含数学解释和可重复使用的代码(用 Python 编写)

github.com

或者你可能对我最近的一篇文章感兴趣:

[## AI 大学生重返实验室

大型语言模型如何解决大学考试以及这为何重要

levelup.gitconnected.com [## 我们能检测 AI 生成的文本吗?

水印可能是检测的解决方案

levelup.gitconnected.com ## 说一次!重复单词对 AI 无帮助

重复标记如何以及为何会伤害大型语言模型?这是一个问题吗?

[towardsdatascience.com

参考文献

这是我撰写本文时参考的主要文献列表,仅列出了每篇文章的第一个名字。

  1. Goodman, Nelson. 《事实、虚构与预测》(第四版)。哈佛大学出版社,1983 年

  2. Battaglia 等人, 2018, 《关系归纳偏差、深度学习与图网络》,链接

  3. Kauderer-Abrams, 2017, 《卷积神经网络中的平移不变性定量化》,链接

  4. Ritter 等人, 2017, 《深度神经网络的认知心理学:形状偏置案例研究》,链接

  5. Conway 等人, 2018, 《下颞皮层的组织与功能》,链接

  6. Geirhos 等人, 2022, 《ImageNet 训练的 CNN 对纹理存在偏见;增加形状偏置可以提高准确性和鲁棒性》,链接

  7. Hermann 等人, 2020, 《卷积神经网络中的纹理偏置的起源与流行》,链接

  8. Li 等人, 2021, 《形状-纹理去偏神经网络训练》,链接

  9. Ghiasi 等人, 2022, 视觉变换器学到了什么?视觉探索,link

  10. Morrison 等人, 2021, 探索腐败鲁棒性:视觉变换器和 MLP-Mixer 的归纳偏差,link

  11. Mormille 等人, 2023, 通过基于 Gram 矩阵相似度的正则化在视觉变换器上引入归纳偏差,link

  12. Tolstikhin 等人, 2021, MLP-Mixer:一种全 MLP 架构用于视觉,link

  13. Cordonnier 等人, 2020, 自注意力与卷积层之间的关系,link

  14. Bachmann 等人, 2023, 扩展 MLP:归纳偏差的故事,link

  15. Kaplan 等人, 2020, 神经语言模型的扩展规律,link

  16. Lei Ba 等人, 2016, 层归一化,link

  17. He 等人, 2015, 深度残差学习用于图像识别,link

  18. Ridnik 等人, 2021, 面向大众的 ImageNet-21K 预训练,link

  19. Sharad Joshi, 2022, 你需要了解的一切:归纳偏差,MLearning.ai

医疗 AI 的基础模型

原文:towardsdatascience.com/a-foundation-model-for-medical-ai-7b97e3ab3893?source=collection_archive---------2-----------------------#2023-09-19

介绍 PLIP,一个病理学基础模型

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Federico Bianchi

·

关注 发表在 Towards Data Science ·10 分钟阅读·2023 年 9 月 19 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由 Tara Winstead 提供: www.pexels.com/photo/person-reaching-out-to-a-robot-8386434/

介绍

正在进行的 AI 革新带来了各个方面的创新。OpenAI 的 GPT 模型正在引领发展,并展示了基础模型如何实际上使我们的一些日常任务变得更轻松。从帮助我们写得更好到简化一些任务,我们每天都会看到新模型的发布。

许多机会正在我们面前展开。能够帮助我们工作生活的 AI 产品将成为我们在未来几年中获得的最重要工具之一。

我们将在哪里看到最具影响力的变化?我们可以在哪里帮助人们更快地完成任务?人工智能模型最令人兴奋的一个方向是将我们引向医疗 AI 工具。

在这篇博客文章中,我将PLIP(病理语言和图像预训练)描述为病理学的第一个基础模型之一。PLIP 是一个视觉-语言模型,可以用于将图像和文本嵌入到同一向量空间中,从而实现多模态应用。PLIP 源自 2021 年 OpenAI 提出的原始CLIP模型,并已在《Nature Medicine》上发表:

Huang, Z., Bianchi, F., Yuksekgonul, M., Montine, T., Zou, J., 一种用于病理图像分析的视觉-语言基础模型,通过医疗 Twitter。2023, Nature Medicine.

在开始我们的冒险之前,一些有用的链接:

除非另有说明,所有图像均由作者提供。

对比预训练 101

我们展示了通过社交媒体上的数据收集以及一些额外的技巧,我们可以构建一个可以在医疗 AI 病理任务中取得良好结果的模型——而无需标注数据。

虽然介绍 CLIP(PLIP 衍生的模型)及其对比损失超出了这篇博客文章的范围,但了解一下还是很有帮助。CLIP 背后的非常简单的想法是,我们可以构建一个将图像和文本放入一个向量空间的模型,其中“图像及其描述将会彼此接近”。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

对比模型——如 PLIP/CLIP——将图像和文本置于同一向量空间进行比较。黄色框中的描述与黄色框中的图像匹配,因此它们在向量空间中也非常接近。

上面的 GIF 还展示了一个示例,说明了如何将图像和文本嵌入同一向量空间的模型用于分类:通过将所有内容置于同一向量空间,我们可以通过考虑向量空间中的距离,将每个图像与一个或多个标签关联起来:描述与图像越接近,效果越好。我们期望最接近的标签是图像的真实标签。

明确来说:一旦 CLIP 训练完成,你可以嵌入任何图像或任何文本。请注意,这个 GIF 展示的是二维空间,但一般而言,CLIP 使用的空间维度要高得多。

这意味着,一旦图像和文本处于相同的向量空间中,我们可以做很多事情:从零样本分类找到哪个文本标签与图像更相似)到检索找到哪个图像与给定描述更相似)。

我们如何训练 CLIP?简单来说,模型会接收大量的图像-文本对,并尝试将相似的匹配项放在一起(如上图所示),而将其他项远离。图像-文本对越多,你学到的表示就越好。

我们将在这里结束对 CLIP 背景的介绍,这应该足以理解本文的其余部分。我在 Towards Data Science 上有一篇关于 CLIP 的更深入的博客文章。

## 如何训练你的 CLIP

介绍 CLIP 及其在 HuggingFace 社区周期间如何为意大利语进行微调。

towardsdatascience.com

CLIP 已被训练成为一个非常通用的图像-文本模型,但对于特定的使用案例(例如,时尚(Chia 等,2022 年))效果不佳,而且在某些情况下,CLIP 表现不佳,而领域特定的实现效果更好(Zhang 等,2023 年)。

病理学语言和图像预训练(PLIP)

我们现在描述如何构建 PLIP,这是我们对原始 CLIP 模型进行微调的版本,专门针对病理学设计。

为病理学语言和图像预训练构建数据集

我们需要数据,而这些数据必须足够好,以用于训练模型。问题是我们如何找到这些数据? 我们需要的是具有相关描述的图像——就像我们在上面的 GIF 中看到的那样。

尽管网络上有大量的病理数据,但这些数据通常缺乏注释,并且可能以非标准格式存在,如 PDF 文件、幻灯片或 YouTube 视频。

我们需要换个地方寻找,而这个地方就是社交媒体。通过利用社交媒体平台,我们有可能接触到大量与病理学相关的内容。病理学家使用社交媒体在线分享自己的研究,并向同事提问(请参见 Isom 等,2017 年,讨论了病理学家如何使用社交媒体)。此外,还有一组一般推荐的Twitter 标签,病理学家可以用来进行沟通。点击这里 查看这些标签。

除了 Twitter 数据外,我们还收集了LAION 数据集(Schuhmann 等,2022 年)中的一个子集,LAION 是一个包含 50 亿图像-文本对的大型数据集。LAION 通过爬取网络收集而来,也是许多流行的 OpenCLIP 模型训练所使用的数据集。

病理学 Twitter

我们使用病理学 Twitter 标签收集了超过 10 万条推文。过程相当简单,我们使用 API 收集与特定标签相关的推文。我们去除包含问号的推文,因为这些推文通常包含对其他病理(例如,“这是什么肿瘤?”)的请求,而不是我们实际需要的信息来构建模型。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

我们提取包含特定关键词的推文,并去除敏感内容。此外,我们还去除了所有包含问号的推文,这些推文通常是病理学家向同事询问一些可能罕见的病例的提问。

从 LAION 采样

LAION 包含 50 亿图像-文本对,我们收集数据的计划如下:我们可以使用来自 Twitter 的图像,并在这个大型语料库中查找相似的图像;这样,我们应该能够获得相当相似的图像,并且这些相似的图像也可能是病理图像。

现在,手动进行这些操作是不可行的,嵌入和搜索 50 亿个嵌入是一个非常耗时的任务。幸运的是,LAION 有预计算的向量索引,我们可以通过 API 用实际图像查询!因此,我们简单地嵌入我们的图像并使用 K-NN 搜索在 LAION 中查找相似图像。请记住,这些图像每个都有一个标题,非常适合我们的用例。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

我们通过在 LAION 数据集上使用 K-NN 搜索扩展数据集的设置非常简单。我们从原始语料库中的图像开始,然后在 LAION 数据集中搜索相似的图像。我们得到的每张图像都有一个实际标题。

确保数据质量

并非所有我们收集的图像都是好的。例如,我们从 Twitter 收集了大量医疗会议的合影。从 LAION,我们有时会得到一些类似分形的图像,这些图像可能模糊地类似于某些病理模式。

我们做的事情非常简单:我们使用一些病理数据作为正类数据,用 ImageNet 数据作为负类数据来训练分类器。这种分类器具有极高的精度(实际上,区分病理图像和网络上的随机图像很容易)。

此外,对于 LAION 数据,我们应用了英语语言分类器来去除非英语的示例。

训练病理语言和图像预训练

数据收集是最困难的部分。一旦完成且我们信任我们的数据,就可以开始训练。

为了训练 PLIP,我们使用了原始的 OpenAI 代码进行训练——我们实现了训练循环,添加了损失的余弦退火,并做了一些调整,以确保一切顺利进行并且可验证(例如,Comet ML 跟踪)。

我们训练了许多不同的模型(数百个),并比较了参数和优化技术。最终,我们得出了一个令人满意的模型。详细信息请参见论文,但在构建这种对比模型时,最重要的组成部分之一是确保在训练过程中批量大小尽可能大,这可以使模型学会区分尽可能多的元素。

医学 AI 的病理语言和图像预训练

现在是测试我们的 PLIP 模型的时候了。这个基础模型在标准基准测试上表现如何?

我们进行了不同的测试来评估 PLIP 模型的性能。其中最有趣的三个是零样本分类、线性探测和检索,但我主要关注前两个。在这里为了简洁,我将忽略实验配置,但这些都可以在手稿中找到。

PLIP 作为零样本分类器

下面的 GIF 演示了如何使用类似 PLIP 的模型进行零样本分类。我们使用点积作为向量空间中的相似性度量(点积越高,相似度越高)。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

进行零样本分类的过程。我们将图像和所有标签进行嵌入,然后在向量空间中找出与图像最接近的标签。

在下图中,你可以看到 PLIP 与 CLIP 在我们用于零样本分类的一个数据集上的快速比较。使用 PLIP 替代 CLIP 能显著提升性能。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

PLIP 与 CLIP 在两个数据集上的零样本分类性能(加权宏 F1)。注意 y 轴在约 0.6 处停止,而不是 1。

PLIP 作为线性探测的特征提取器

使用 PLIP 的另一种方法是作为病理图像的特征提取器。在训练过程中,PLIP 处理了许多病理图像,并学习为这些图像构建向量嵌入。

假设你有一些标注数据,想要训练一个新的病理分类器。你可以使用 PLIP 提取图像嵌入,然后在这些嵌入上训练一个逻辑回归(或你喜欢的任何回归器)。这是一种简单有效的分类任务方法。

为什么这样有效?这个想法是,PLIP 嵌入具有病理特异性,应该比 CLIP 嵌入(通用目的)更好。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

PLIP 图像编码器允许我们为每个图像提取一个向量,并在其上训练一个图像分类器。

下面是 CLIP 和 PLIP 在两个数据集上的性能比较示例。虽然 CLIP 的表现不错,但我们使用 PLIP 得到的结果要高得多。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

PLIP 与 CLIP 在两个数据集上进行线性探测的表现(宏 F1)。注意 y 轴从 0.65 开始,而不是 0。

使用病理语言和图像预训练

如何使用 PLIP?以下是一些使用 PLIP 的 Python 示例,以及一个你可以用来稍微玩一下模型的 Streamlit 演示。

代码:使用 PLIP 的 API

我们的 GitHub 仓库提供了一些额外的示例,你可以参考。我们已经构建了一个 API,允许你轻松地与模型进行交互:

from plip.plip import PLIP
import numpy as np

plip = PLIP('vinid/plip')

# we create image embeddings and text embeddings
image_embeddings = plip.encode_images(images, batch_size=32)
text_embeddings = plip.encode_text(texts, batch_size=32)

# we normalize the embeddings to unit norm (so that we can use dot product instead of cosine similarity to do comparisons)
image_embeddings = image_embeddings/np.linalg.norm(image_embeddings, ord=2, axis=-1, keepdims=True)
text_embeddings = text_embeddings/np.linalg.norm(text_embeddings, ord=2, axis=-1, keepdims=True)

你还可以使用更标准的 HF API 来加载和使用模型:

from PIL import Image
from transformers import CLIPProcessor, CLIPModel

model = CLIPModel.from_pretrained("vinid/plip")
processor = CLIPProcessor.from_pretrained("vinid/plip")

image = Image.open("images/image1.jpg")

inputs = processor(text=["a photo of label 1", "a photo of label 2"],
                   images=image, return_tensors="pt", padding=True)

outputs = model(**inputs)
logits_per_image = outputs.logits_per_image 
probs = logits_per_image.softmax(dim=1) 

演示:PLIP 作为教育工具

我们还相信 PLIP 和未来的模型可以作为医学 AI 的有效教育工具。PLIP 允许用户进行零样本检索:用户可以搜索特定的关键词,PLIP 将尝试找到最相似/匹配的图像。我们在 Streamlit 中构建了一个简单的 Web 应用,你可以在这里找到它。

结论

感谢阅读这些内容!我们对这项技术未来可能的发展感到兴奋。

我将通过讨论 PLIP 的一些非常重要的局限性以及建议一些可能感兴趣的附加内容来结束这篇博客文章。

局限性

尽管我们的结果很有趣,但 PLIP 存在许多不同的局限性。数据不足以学习病理学所有复杂的方面。我们已经构建了数据过滤器以确保数据质量,但我们需要更好的评估指标来理解模型的正确与错误。

更重要的是,PLIP 并没有解决病理学当前的挑战;PLIP 不是一个完美的工具,可能会犯很多需要调查的错误。我们看到的结果无疑是有前景的,它们为未来在病理学中结合视觉和语言的模型打开了许多可能性。然而,在我们能看到这些工具在日常医学中应用之前,还有很多工作要做。

杂项

我还有一些关于 CLIP 建模和 CLIP 局限性的博客文章。例如:

## 教授 CLIP 一些时尚知识

训练 FashionCLIP,一个特定领域的 CLIP 模型用于时尚

towardsdatascience.com ## 你的视觉-语言模型可能是一个词袋

我们在 ICLR 2023 的口头报告中探讨了视觉-语言模型在语言方面的局限性

towardsdatascience.com

参考文献

Chia, P.J., Attanasio, G., Bianchi, F., Terragni, S., Magalhães, A.R., Gonçalves, D., Greco, C., & Tagliabue, J. (2022). 一般时尚概念的对比语言与视觉学习。Scientific Reports, 12

Isom, J.A., Walsh, M., & Gardner, J.M. (2017). 社交媒体与病理学:我们现在处于何处以及为何重要?Advances in Anatomic Pathology

Schuhmann, C., Beaumont, R., Vencu, R., Gordon, C., Wightman, R., Cherti, M., Coombes, T., Katta, A., Mullis, C., Wortsman, M., Schramowski, P., Kundurthy, S., Crowson, K., Schmidt, L., Kaczmarczyk, R., & Jitsev, J. (2022). LAION-5B:一个用于训练下一代图像-文本模型的开放大规模数据集。ArXiv, abs/2210.08402

Zhang, S., Xu, Y., Usuyama, N., Bagga, J.K., Tinn, R., Preston, S., Rao, R.N., Wei, M., Valluri, N., Wong, C., Lungren, M.P., Naumann, T., & Poon, H. (2023). 大规模领域特定预训练用于生物医学视觉-语言处理。ArXiv, abs/2303.00915

卫星图像基础模型

原文:towardsdatascience.com/a-foundation-model-for-satellite-images-dbf356c746a9?source=collection_archive---------8-----------------------#2023-11-04

Prithvi-100M IBM 地理空间 AI 基础模型用于 NASA 地球观测数据

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Caroline Arnold

·

关注 发表在 Towards Data Science ·7 分钟阅读·2023 年 11 月 4 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

阿尔巴尼亚卡拉瓦斯塔泻湖的卫星图像,2017 年。图像来源:www.esa.int/var/esa/storage/images/esa_multimedia/images/2017/03/karavasta_lagoon_albania/16854373-1-eng-GB/Karavasta_Lagoon_Albania.jpg。包含修改后的 Copernicus Sentinel 数据。

基础模型是灵活的深度学习算法,旨在处理通用任务,而不是立即专注于特定任务。在大量未标记数据上进行训练后,它们可以通过最少的微调应用于各种下游任务。基础模型在自然语言处理(BERT,GPT-x)和图像处理(DALL-E)中都很有名。

2023 年 8 月,NASA 和 IBM 发布了用于 NASA 地球观测数据的地理空间 AI 基础模型。该模型以 Prithvi 命名,开放源代码在Huggingface上,Prithvi 是印度教的大地女神。它已在 NASA 卫星数据上进行训练——根据 IBM超过 250 PB的数据可用。

在这篇博客文章中,我们讨论

  • 用于训练的 NASA 协调 Sentinel-2 Landsat 数据集,

  • Prithvi-100M 地理空间 AI 基础模型的架构,

  • 在 IBM 的 Vela 超级计算机上的训练过程,

  • 示例应用:洪水和作物类型识别。

训练数据

地理空间 AI 基础模型已在NASA 协调的 LandSat Sentinel-2 数据上进行训练。

Sentinel-2是由欧洲航天局协调的卫星任务,目前有两颗卫星在轨道上拍摄地球的高分辨率图像。它专注于陆地、沿海地区和特定的开放水域。Landsat 卫星由 NASA 发射,用于记录地表反射。协调数据结合了两个传感器的输入, resulting in a spatial resolution of about 30 meters and an average revisit time of two to three days. This resolution is sufficient for agricultural monitoring, land use classification, and natural disaster detection.

标准照片由红色、绿色和蓝色三种颜色组成。Sentinel-2 数据总共提供 13 种“颜色”,即所谓的波段,涵盖可见光、近红外和短波红外电磁谱范围。选择的波段可以用于识别不同的事物,例如,红外波段包含有关植被的信息。有关背景,请参见这篇文章关于 Sentinel-2 波段组合。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

夏威夷机场的伪彩色红外图像。图像来源:ESA sentinel 卫星图像,CC BY-SA 4.0 <creativecommons.org/licenses/by-sa/4.0>, 通过维基媒体共享资源。

云层阻碍了地球观测卫星的视线。为应对这一影响,Sentinel-2 提供了一个可用于识别云层覆盖的波段。受影响的像素被屏蔽,以免干扰图像处理算法。

因此,Sentinel-2 和 Landsat 数据是未标记的。需要大量的人力和专业知识才能提供逐像素的土地使用类别分类。基础模型高度通用,并从数据中提取结构,而无需在训练过程的初始阶段提供标记数据。因此,它们在地球观测数据方面显得非常有前途。

模型架构

Prithvi-100M 地理空间 AI 基础模型基于时间序列视觉变换器和掩蔽自编码器。模型卡显示在 Huggingface 上:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

Huggingface 上的 Prithvi-100M 模型卡。图像来源:huggingface.co/ibm-nasa-geospatial/Prithvi-100M/blob/main/GFM.png

该模型接受视频格式的 Landsat 图像作为输入。来自同一地点的图像被加载为时间序列,而静态图像可以通过将时间序列长度设置为 1 进行处理。波段对应于视觉变换器的通道。

视觉变换器

在 2020 年,Google Research 的团队展示了变换器不仅可以应用于自然语言处理,还可以应用于图像 (Dosovitsky et al, ICLR 2020)。在那之前,卷积神经网络一直是图像处理的事实上的标准。

视觉变换器首先将图像切割成小块,类似于对语言处理变换器进行句子的标记化。然后,添加可学习的嵌入和位置编码。在原始论文中,展示了在大量训练数据下,视觉变换器可以超越典型的计算机视觉架构,如 ResNet。

[## 视觉变换器(ViT)放大镜下,第一部分

嵌入

yurkovak.medium.com](https://yurkovak.medium.com/vision-transformer-vit-under-the-magnifying-glass-part-1-70be8d6661a7?source=post_page-----dbf356c746a9--------------------------------)

掩蔽自编码器

Prithvi-100M 掩蔽自编码器基于 He 等人(2021)的原始实现,arxiv.org/pdf/2111.06377.pdf。概念很简单:

图像中的随机块被掩蔽。自编码器学习预测缺失的像素。这类似于大型语言模型的训练,其中模型学习预测句子中缺失的单词。

在原始论文中,考虑了带有 RGB(红色、绿色、蓝色)颜色通道的 2D 图像。论文中广泛讨论了在语言数据和图像数据上进行训练的区别。

编码器仅在未被遮挡的图像块上工作,这样可以节省计算时间。嵌入由对单独图像块的线性投影来处理,该投影包含可学习参数。

位置嵌入很重要,以便算法知道图像块在原始图像中的位置。在遮挡自编码器的情况下,位置嵌入通过 2D 正弦-余弦函数提供,这种函数通常用于变换器模型。它对图像中的 2D 网格位置进行编码。位置嵌入可能包含可学习的参数,但在MAE 库的实现中似乎并非如此。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

遮挡自编码器的应用。左侧:原始图像的遮挡图像块。中间:重建。右侧:真实情况。图像来源:arxiv.org/pdf/2111.06377.pdf(图 2)

MAE 架构的变化

为了处理具有更多通道的卫星数据时间序列,NASA 和 IBM 团队对遮挡自编码器架构进行了若干修改。

  • 2D 图像块嵌入被更改为 3D 图像块嵌入。

  • 2D 位置嵌入被更改为 3D 位置嵌入。

  • 图像块创建考虑到数据的 3D 特性。

  • 除了 RGB 颜色外,还增加了一个近红外和两个短波红外波段。

损失函数

均方误差(MSE)损失用于训练,通过逐像素比较原始图像和重建图像。

模型训练

模型训练过程描述在 IBM 博客中:research.ibm.com/blog/nasa-hugging-face-ibm。遗憾的是,提供的细节不多。然而,IBM 提到他们在公司 AI 超级计算机 Vela 上进行了训练。Vela是一个完全基于云的超级计算机,仅为 IBM 研究部门运营。

超级计算机由 200 个节点组成。每个节点配备了 8 个 NVIDIA A100 GPU,每个 GPU 有 80 GB 的内存。节点 RAM 为 1.5 TB,并且配备四个 3.2 TB 的本地硬盘。这些配置能够处理训练基础模型所需的大数据集。节点之间通过一个能传输高达 100 GB/秒的网络连接。

应用

Prithvi-100M 地理空间 AI 基础模型可以应用于多种下游任务。我们专注于两个任务:洪水和作物类型识别。

洪水

保留 Prithvi-100M 的原始编码器部分,模型现在被调整为预测卫星图像中洪水的扩展。详细信息描述在 HuggingfaceSen1Floods11 数据集用于微调,涵盖了六大洲的 11 次洪水事件。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

微调地理空间 AI 基础模型以进行洪水检测。图像来源:huggingface.co/ibm-nasa-geospatial/Prithvi-100M-sen1floods11/blob/main/sen1floods11-finetuning.png

为了将 Prithvi-100M 准备好以应对下游任务,需要将嵌入形状转换回原始图像形状。然后,添加一个最终的 2D 卷积层,应用特定任务的分类。

图像中的每个像素被分类为水域或非水域(陆地)。由于这是一个分类问题,因此使用了二元交叉熵损失。一次只处理一张图像,因此未使用 Prithvi-100M 的时间序列功能。

作者报告了在玻利维亚的一个保留洪水事件中,平均准确率为 93%,平均交并比为 86%。

提供了一个演示页面,用户可以上传自己的 Sentinel-2 图像,并要求 Prithvi-100M 识别洪水。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

洪水识别演示的快照。黑色像素对应陆地,白色像素对应水域。图像来源:huggingface.co/spaces/ibm-nasa-geospatial/Prithvi-100M-sen1floods11-demo(使用 India_900498_S2Hand.tif)

作物类型识别

为了利用时间序列功能,作者提供了作物类型识别的演示。作物类型的实际情况由标记图像提供。这是一个多类分类问题,训练时使用了交叉熵损失。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作为 Prithvi-100M 的下游任务,进行多时相作物类型分类。图像来源:huggingface.co/ibm-nasa-geospatial/Prithvi-100M-multi-temporal-crop-classification/blob/main/multi_temporal_crop_classification.png

作者报告了不同作物类型的不同准确率。平均准确率为 64%,交并比为 46%。然而,作者指出实际情况存在噪声,更准确的标签将有助于改进这一下游任务。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

Prithvi-100M 的作物类型演示。左侧三幅图显示卫星图像的时间序列。右侧图显示模型预测,每个像素根据作物类型着色。图片来源:huggingface.co/spaces/ibm-nasa-geospatial/Prithvi-100M-multi-temporal-crop-classification-demo

总结

我们已经介绍了地球空间 AI 基础模型,目前(2023 年)在 Huggingface 上以 Prithvi-100M 的名义是最大的地球空间模型。该模型由 IBM Research 和 NASA 开发,使用 Landsat 数据集进行训练。

我们已经介绍了地球空间 AI 基础模型的训练数据、架构和训练过程。该模型开放源代码,可以进行更具体任务的微调。洪水检测和作物类型识别应用展示了地球空间 AI 基础模型的巨大潜力。

由于 Sentinel-2 数据可用于个人非商业用途,有兴趣的用户可以创建适用于特定下游任务的自己的模型。在未来的帖子中,我将展示如何为植被识别和超分辨率微调地球空间 AI 基础模型。

进一步阅读

## 环境数据科学:简介

处理环境数据的示例、挑战和展望

towardsdatascience.com

基于自然法则的人本中心 AI 框架

原文:towardsdatascience.com/a-framework-for-a-human-centered-ai-based-on-the-laws-of-nature-a8bfbb233250?source=collection_archive---------16-----------------------#2023-05-08

整合自然智能与人工智能

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Tom Kehler

·

关注 发表在 Towards Data Science ·9 分钟阅读·2023 年 5 月 8 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

自然的涌现秩序(图片来源 iStock Getty Images 2090608323

在哈佛大学教师俱乐部举办的波士顿全球论坛高级会议“AI 助手监管峰会:促进科技启蒙经济联盟”上做了报告。此处展示的论文是该讲座的扩展。

我们面临许多十字路口。最近几个月的一个显著十字路口是 AI,导致了从恐惧到欣喜的广泛反应。毫无疑问,您现在已经体验到了与 ChatGPT 互动的乐趣。许多人已加入了采用的潮流。其他人则认为当前的 AI 表现不过是另一场追逐底线的竞赛,我们因为必须而抛弃了谨慎。其他人都在做,我们也必须这样做。汇聚的恶劣行为无人愿见——但存在因为没人知道如何建立信任的阴影。技术不是敌人。未能合作并建立信任会导致鲁莽采用,可能带来伤害。

在这个简要概述中,我希望为您提供一个建立信任和减少风险的 AI 未来框架。

这个框架最初由科学的创始人和启蒙时代的科学方法揭示。随后形成的科学方法为建立可信知识奠定了基础——这是一个依赖于集体人类智慧和对自然中出现优雅的信任的协作过程。

我们建议利用集体人类智慧和内置于生物系统物理中的智慧来指导我们前进。¹

几乎 70 年来,人工智能的科学探索集中在使用符号表示和推理工具构建自然智能和人类认知技能的手工模型上。它们能够解释如何解决问题。通过观察它们的推理来建立信任。

在过去 20 年里,互联网提供的数据爆炸带来的统计学习取得了显著成果——从自动驾驶汽车到今天将我们聚集在一起的大型语言模型。特别是,变压器深度学习架构解锁了生成型 AI 强大的潜力,这创造了我们今天看到的令人印象深刻的结果。

让我们今天关注的问题涉及三个基本问题。这是信息技术历史上第一次,我们没有执行数据来源的概念。 因此,这些巨大的生成能力可能成为误导性信息的有力传播者,破坏对知识的信任。第二个问题是可解释性——系统是黑箱。第三个问题是它们需要一个上下文感知。

这三点弱点与科学方法的三个支柱——引用、可重复性和结果的背景化——相矛盾。 我们该怎么办?

朱迪亚·珀尔说,“你比你的数据更聪明。”我们同意。人类的反事实思维能力远远强于我们从过去数据中的相关模式中学到的任何东西。

大型语言模型和深度学习架构通常基于数据中的模式识别和相关性模型来发展智能行为模型。LLM 的生成输出利用人工干预来过滤和训练结果。风险依然存在。过滤过程可能会遗漏包含错误信息的内容生成。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 1(图片作者提供)

五年前,在MIT Technology Review的采访中,深度学习的奠基人之一,Yoshua Bengio 说:

“我认为我们需要考虑人工智能的艰巨挑战,而不是满足于短期的、渐进的进步。我不是说我想忘记深度学习。相反,我想在此基础上进行构建。但我们必须扩展它,以进行推理、学习因果关系和探索世界以获取信息。”²

目前基于历史数据模式相关性的模型不太可能捕捉到人脑能力的复杂性。人脑的想象力和基于经验生成因果模型的能力必须成为未来人工智能模型的一个重要部分。我们提出了一种结合人类集体智能和人脑模型的方法。

拉里·佩奇、谢尔盖·布林和特里·温诺格发现引用索引可以成为对网络信息进行有规模排序的方法。³ PageRank 算法为网络带来了秩序。引用索引的数学为理解人类协作中的信息共享带来了秩序。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 2:图片作者提供

一种新一代的人工智能,融合了人类集体推理,经过过去八年的开发,使用引用索引方法作为知识发现过程。它允许大规模的知识发现,支持引用、可重复性和情境化。我们建议将其作为未来框架的一部分。

集体推理旨在了解一个社区或小组对预期结果的集体偏好和信念。产品发布会产生我们想要的结果吗?如果我们改变远程工作的政策,我们的生产力会增加还是减少?使用 ChatGPT 和 LLMs 的最佳政策是什么?这些问题需要了解一个小组对预测结果的‘集体思维’。集体推理过程利用人工智能技术学习集体思维模型。该过程是单盲的,减少了偏见。系统经过四年的测试,针对 20 至 30 名专家/投资者预测初创企业成功的情况,准确率超过 80%。⁴ 每项投资的集体信念和预测被映射为集体知识模型——贝叶斯信念网络⁵。这些因果模型是该小组集体推理的生成可执行表示。

我们可以将科学知识发现过程中的关键元素融入我们共同创造或协作解决复杂问题的方式中。我们建议使用 AI 来学习集体知识模型、保持来源、可解释性和背景的因果模型,而不是让 AI 破坏对知识的信任。这是一种新的启蒙关键组成部分——将科学方法带入协作中。

集体推理允许学习一个群体的意图。基于代理的模拟在预测提议解决方案的影响方面很有用。基于公共数据的合成模型允许对共同创建的解决方案进行规模化和预测,我们建议将其作为框架的一部分。该倡议中的一个合作伙伴公司已经建立了一个重要的能力来大规模模拟影响,并将其应用于疾病传播的社会影响。⁶

未来 AI 的基础是什么?自 1956 年夏天 AI 诞生以来的 68 年里我们学到了什么?前几代开发了形成当前 AI 格局的组件。合作现象的数学和磁性的物理学在将这一切联系在一起方面发挥了令人兴奋的作用。1982 年,霍普菲尔德证明了人工神经网络的集体计算能力的涌现直接映射到自旋玻璃的数学物理学。⁷ 相同的合作现象数学描述了从混乱中涌现出的秩序,如本文开头的燕群照片所示。

最近,MIT 的 Lin、Rolnick 和 Tegmark 显示,深度学习和廉价学习之所以效果如此好,与物理定律有关。贝叶斯学习被重新表述为量子和经典物理学中使用的基本方法——哈密顿量。⁸ 明确关注 AI 在自然法则中的根源应成为未来 AI 发展的重点。

一切的核心在于从无序中学习秩序。大脑中的新一波研究将学习在秩序/无序边界上的理论应用于创建活的智能系统——自由能原理。⁹

FEP 是一个基于贝叶斯学习的框架。大脑被认为是一个贝叶斯概率机器。如果感官输入与期望不匹配,主动推理会寻求最小化未来的不确定性。我们期望和感知之间的差异称为惊讶,并表示为自由能(可用于行动的能量)。寻找最小自由能的路径等同于寻找减少惊讶(不确定性)的路径。

基于 FEP 的 AI 在本地进行适应,并根据物理和生物科学中使用的变分自由能最小化原则进行扩展。Bioform Labs 正在构建一个适应和学习的生物 AI。¹⁰ 与需要大量训练数据集和复杂成本函数的第二代 AI 不同,基于生命系统物理学的 AI 是适应性的,并且存在于一个生态系统中。它可以被设计为尊重导致生命系统需求的状态。

启动这一新框架所需的技术今天就可以应用。我们不需要暂停 AI 的发展。集体推理适用于我们需要问自己关于 AI 在各种具体背景下影响的问题。AI 将如何影响技术投资?它将如何改变我们的招聘实践?它对我们的社区有什么影响?

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 3(作者提供的图片)

此外,可以在保留隐私边界的情况下,参与 ChatGPT 和 LLMs 的创意过程。来自 LLM 的创意可以在特定的私人背景中进行策划和使用。策划和情境化的贡献在一个获得专利的私人 LLM 环境中进行管理。¹¹

集体推理学习意图和可能的解决方案。基于代理的模拟预测影响。我们不再需要将组织视为僵化的。基于主动推理的新型组织治理支持适应性学习生存路径。我们相信这一框架是未来的愿景,将为新的 AI 赋能的启蒙提供基础。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 4(作者提供的图片)

一个新的 AI 赋能的启蒙。正如启蒙时代使科学摆脱了宗教权威的压迫一样,新倡议——AI 赋能的启蒙,提供了一条协作和共同创造解决方案的路径——使我们摆脱当前 AI 狂热潮的不良后果。

总结来说,大型语言模型提供了非常有用的能力,这些能力以令人印象深刻的速度展开。请阅读警告标签!ChatGPT 确实警告不要盲目相信结果,而要使用批判性思维。不要暴露私人数据。关于私人数据,图 3 和图 4 展示了一种通过允许 ChatGPT 或其他“代理”提供输入与人类专家进行精心策划的合作的方法,结果保留在私密管理的 LLM 环境中。这种方法允许在保留私人知识产权的情况下探索 LLM 的生成能力。

(1) 本文提出的框架源于 MIT 媒体实验室于 3 月 6 日举行的会议,由 BioForms 的 John Clippinger、Crowdsmart 的 Kim Polese 以及其他几位参与者共同发起。波士顿全球论坛的 CEO Tuan Nguyen 参加了会议,并在会后与波士顿全球论坛一起创建了 AI 治理框架bostonglobalforum.org/

(2)’Knight, Will,(2018 年 11 月 17 日)“AI 之父之一对其未来感到担忧”MIT Technology Review

(3) Page, L., Brin, S., Motwani, R. and Winograd, T. (1998) 页面排名引用排名:为网络带来秩序。技术报告 SIDL-WP-1999–0120,斯坦福数字图书馆技术项目。

(4) 更多信息请参阅AI-guided Co-creation

(5) 专利 US11366972 分配给CrowdSmart.ai

(6) Epistemix.com

(7) Hopfield JJ. 1982. 神经网络和具有突现集体计算能力的物理系统。Proc. Natl Acad. Sci. USA 79, 2554–2558

(8) Lin, H.W., Tegmark, M. & Rolnick, D. 为什么深度和廉价学习效果如此显著?. J Stat Phys 168, 1223–1247 (2017). doi.org/10.1007/s10955-017-1836-5

(9) Friston, K. 自由能原理:统一的大脑理论?. Nat Rev Neurosci 11, 127–138 (2010). doi.org/10.1038/nrn2787

(10) bioformlabs.org/

(11) 专利 — 管理和测量知识发现过程中的语义覆盖。2022/072895

分析流失的框架

原文:towardsdatascience.com/a-framework-for-analyzing-churn-370d2283b75c?source=collection_archive---------4-----------------------#2023-01-13

使用模拟数据集进行客户流失分析的逐步指南

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Gabriele Albini

·

阅读 发表在 Towards Data Science ·14 分钟阅读·2023 年 1 月 13 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片来源:JESHOOTS.COMUnsplash

介绍

客户流失”已经成为一个常见的商业词汇,它指的是流失率的概念,维基百科定义为:

“在给定时间段内离开供应商的合同客户或订阅者的比例”

从数据角度分析流失时,我们通常意味着使用现有工具提取有关现有客户群的信息,具体来说:量化当前的流失率并了解可能影响/预防未来流失的因素。

因此,当我们开发“流失模型”时,应该考虑使用现有数据并且要有两个目标:

  1. 为现有活跃客户预测流失

  2. 对影响客户流失决策的因素进行一些假设,识别出可能减少流失的潜在措施。

预测流失需要大量工作,这不是一项容易的任务,但更重要的是,它甚至不是最终目标:这是设计和实施客户“留存”策略的起点!

本文将重点介绍流失分析框架的实现,灵感来源于书籍:[1]《用数据对抗流失》,作者是卡尔·S·戈德。这是一本推荐给所有处理流失数据的人的优秀书籍:书中详细介绍了流失分析的全过程,提供了很多细节和示例(包括解释代码!)。在所有建议的步骤中,我提取了在我的经验中最相关和成功的部分,并将其调整为我熟悉的背景和数据集。本文将该框架应用于一个模拟数据集,灵感来源于一个真实的商业案例(Github 仓库链接)。

目录:

1- 数据

1.1- 开发流失模型时应考虑哪些数据?

1.2- 原始数据

2- 数据预处理:流失指标

2.1- 创建客户指标

2.2- 分析流失指标

3- 使用机器学习进行流失预测

3.1- 逻辑回归

3.2- 随机森林

3.3- XGBoost

4- 生成流失预测

5- 下一步

参考文献

1. 数据

1.1 开发流失模型时应考虑哪些数据?

这不是一个简单的问题!很多不同的信息可能与流失相关,制定通用规则永远无法涵盖所有可能的业务、系统、背景等。例如,在考虑流失相关信息时,我们可能会考虑:

  • 关于客户(或账户)的基本信息:性别、位置、年龄、任期等

  • 与订阅相关的信息:客户订阅的产品、激活的附加功能、激活和取消日期等

  • 支付信息:客户支付了多少?他们使用什么支付方式?他们是否定期付款?

  • 产品使用信息:登录信息、点击信息、与产品的互动分钟数等

  • 与客户支持互动相关的信息:客户进行的聊天或电话、支持服务的评分、投诉细节等

翻译成系统后,这些数据需要来自各种事务系统(CRM、ERP、计费等),并应适当地组织到某些数据湖/数据仓库中(理想情况下,频繁地拍摄覆盖几个月)。考虑到这一点,需要有大量的专业知识来了解哪些字段代表了哪些信息,通常,访问这些数据需要大量的批准,特别是如果外部顾问想要使用这些数据的话。

根据我的经验,所有这些数据(以及相关的历史记录)很少可用。通常,可用且已组织的数据是公司因财务或法律法规要求或仅因运行日常业务所需的数据。这些数据必须在某处可用。

例如:假设我们是一家提供按需视频培训的公司,我们需要知道客户拥有哪些订阅以及他们支付了多少,以提供我们的服务并制作财务报表。然而,我们不一定需要存储客户 XYZ 在完成特定视频之前暂停视频的具体时间。

鉴于所有这些原因,为了保持文章简洁和现实,我将重点关注一个“较小”的数据集,理想情况下这些数据应来自任何 CRM 中应有的数据。

1.2 原始数据

让我们假设我们是一家通过网站提供在线视频课程的 B2C(商业对消费者)公司。我们的业务运作方式如下:

  • 新用户可以订阅两个领域的课程:机器学习(领域 A)和吉他(领域 B)。他们可以购买多个订阅,从而允许不同用户同时登录。此外,他们还可以选择包含“附加服务”的选项,该服务包括每周与所选领域的专家进行在线直播。

  • 一旦订阅,用户将每月支付费用,并可能有或没有折扣。他们可以随时取消订阅,这意味着订阅将在月底不会续订。

  • 用户可以打开实时聊天并联系支持团队解决任何问题。

原始数据将如下所示:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

虚拟原始数据 | 图片来源作者

这种商业背景非常常见,适用于任何具有每月订阅和将附加服务添加到基本报价的 B2C 业务(例如:按需媒体内容、电信、公用事业、电子商务、保险和高级软件等企业)。

2. 数据预处理:流失指标

从原始数据开始,我们需要预测客户是否会流失。我们将把客户视为一个整体,无论他们有多少个订阅。

由于我们的目标是预测流失以制定留存策略,我们需要提前知道客户是否会流失,以便我们可以采取措施影响这一决定。

2.1 创建客户指标

考虑到以上原始数据,我们可以生成哪些 KPI?以下是一些想法(它们是我们数据集的列):

  • “mrr_ratio” = 这是按订阅计算的每月经常性收入。因此,对于每个客户:我们对每个有效订阅求和([每月费用 — 折扣]),然后计算有效订阅的数量,并将两者相除。

  • “mrr_ratio_A”和“mrr_ratio_B” = 这些是按领域计算的每月经常性收入(A 是机器学习;B 是吉他),考虑领域内的 mrr 和活跃订阅数量。

  • “subs_A”和“subs_B” = 按领域的活跃订阅数量

  • “discount_ratio” = 客户的折扣百分比,计算方法为:1 — ([每月费用 — 折扣] / [每月费用])

  • “has_addon” = 一个标志,指示客户是否有至少一个带附加组件的订阅

  • “support_chats” = 客户在一个期间内发起的聊天次数

  • “is_churn” = 一个标志,指示客户是否将要流失(1)或不流失(0)

我认为使用我们历史原始数据来计算这些 KPI 的最佳方法是:

  • 确定一些固定的观察期(例如每月 20 号),留出一些合理的时间从我们的续订(我们假设每月 30 号发生)。

  • 创建一个表“A”,在其中,对于每一个“流失”的客户,我们包括他们过去的流失日期。

  • 创建另一个表“B”,其中,对于每个客户,在每月 20 号,我们根据过去 30 天的数据计算 KPI。换句话说,我们每月 20 号对客户指标进行月度快照。

  • 我们将表“A”和“B”按客户 ID 连接,并标记所有将在下一个观察日期流失的行。

这些观察期和 KPI 通常在数据仓库中计算,然后导出到 Python。我为项目模拟的数据正好代表了这种情况。假设我们刚刚从数据仓库中获得了以下数据集:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

虚拟流失指标 | 图片由作者提供

(注意:这是一个模拟数据集。所有连续指标都从一个多变量高斯分布中提取,近似真实数据。这就是为什么我们有负值和应不为负值或小数的 KPI 的原因。此外,每一行应对应一个客户 ID,但此信息并不相关)。

2.2 分析流失指标

一旦我们拥有一些指标,我们可以开始检查它们与流失的关系。

最直观的方式来调查这种关系是通过队列分析。通常,通过将每个指标数据拆分成 10 个相等大小的桶来生成 10 个队列,具体取决于它们的值。然后,我们通过计算每个队列中的流失率,将每个指标与“is_churn”标志相关联。如果指标不是连续的且具有少于 10 个分类值,那么我们只考虑每个类别一个队列。

在左侧图表中,我们可以看到,平均而言,拥有更高 mrr_ratio 的客户流失更多,因为他们每个订阅支付更多:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

流失指标队列 | 作者提供的图像

每当我们看到这样的行为,具有逻辑意义,并且根据指标值,我们看到流失有显著差异时,我们可以期望该指标在我们的分析中是相关的。

相反,如果我们看到指标中,无论队列的平均值如何,对流失没有影响(例如水平线),那么我们可能会考虑将该指标从模型中排除。

3. 机器学习中的流失预测

我们现在将使用数据集来预测流失。

请注意,流失的预测是不简单的。决定流失是主观的,而且可能并不总是一个逻辑选择:一个客户可能因为费用问题而流失,其他客户可能因为质量问题而流失。此外,糟糕的客户服务或对产品/品牌的负面感受也可能主观地引发流失决定。

基于这些原因,模型的表现不会像其他机器学习任务那样高。根据 Carl S. Gold [1]的说法,一个健康的流失预测模型的 AUC 得分应在 0.6 到 0.8 之间。

需要考虑的一些因素:

  • 流失是一个二分类任务:模型将学习预测记录是否属于类 1(流失客户)或类 0(未流失)。然而,我们将关注每条记录属于每个类别的概率。在选择模型时,请记住这一点。

  • 模型表现不能通过准确率来衡量。通常,少数客户流失,因此我们的数据集是不平衡的:仅约 10%的虚拟数据属于类 1(流失客户)。任何总是预测类 0 的模型将具有 90%的准确率,但这样的模型完全没有帮助。相反,我们将使用roc_auc 得分来衡量性能。

  • 我们将使用交叉验证来调整模型的超参数。由于我们处理的是时间序列数据集,我们不能简单地使用随机记录分配到每个折叠。我们需要训练我们的模型使用当前或过去的数据,而不是未来的数据。因此,最佳实践建议使用时间序列分割(来自sklearn [2]),它适用于任何按时间排序的数据集。

(注意:在交叉验证中,通常使用 10 个拆分。这里由于数据量有限和数据对类 0 极度不平衡,使用了 3 个拆分)。

现在让我们比较三种分类模型。

3.1 逻辑回归

逻辑回归是一个广义的线性回归模型,这是一种非常常见的分类技术,尤其用于二分类问题。由于它是一个回归模型,许多假设需要事先验证;例如,我们不应违反“无多重共线性”假设,这意味着我们需要确保没有特征是相关的,即每个特征应提供独特且独立的信息。

尽管这很容易验证,我可以预见逻辑回归不会是性能最好的模型,因此我们不会使用我们可能获得的任何无效结果。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

逻辑回归 | 作者插图

3.2 随机森林

随机森林是一种基于树的集成方法。

基于树的方法 是非常强大的分类(或回归)算法,它们通过根据多个决策节点来划分我们的训练数据。每个决策节点根据特定特征执行一次“划分”,做出 True / False 决策。划分决策的确定方式是为了在树的下一层尽可能减少我们的数据集的“熵”。熵是数据无序程度的度量,它与我们在分类/回归任务中可以获得的“信息增益”相关。

  • 例如,在一个二分类问题中,如果我们注意到通过根据一个特征来划分数据,我们得到的每个 True/False 结果分支中——95% 的数据属于一个类别,5% 的数据属于另一个类别,那么我们就成功地从数据中获得了更多的信息,降低了数据的无序程度或“熵”。

随机森林(RF) 构建了多个不同的树,然后取这些树的平均值或最频繁的结果来做最终预测。RF 确保每棵树与其他树的构建方式不同,这得益于两种方法:

  • Bagging(自助聚合):每棵树都是通过使用整个训练集的样本进行训练的,因此每棵树都是使用不同的数据构建的。

  • 特征随机性:每棵树都是通过限制可用特征来构建的,使用所有可用特征的一个子集。

现在让我们在数据集上调整 RF 的超参数,选择最佳模型,并展示最“重要”的特征(即每个特征用于生成决策分裂的频率):

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

随机森林 | 作者插图

3.3 XGBoost

XGBoost 代表极端梯度提升,它是另一种基于树的集成技术,与 RF 类似,允许将多个决策树的预测结果进行结合。

XGBoost 是“梯度提升”方法的一个进化(“极端”)版本。因此,为了说明 XGBoost,让我们分别考察这两个方面。

  • 在“梯度提升”方法中,与随机森林(RF)不同,构建的树之间有很大关联。预测是由“弱学习者”(即简单树)做出的,这些树会不断改进。通常,初始预测是目标值的平均值,然后通过创建新树进行精炼。每棵新树是基于前一棵树的错误构建的:因此,从前一轮“弱学习者”的残差/错误预测开始,建立新树,最小化成本函数,并对产生错误的属性分配更多权重。最后,通过加权每棵树的结果来组合结果。

  • 从“梯度提升”开始,“极端梯度提升”是一个完整的算法,包括对梯度提升方法的几项改进,如性能优化和正则化参数(可以避免过拟合)。最重要的是,得益于这些附加元素,XGBoost 可以在像普通笔记本电脑这样的简单机器上运行。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

4. 生成流失预测

表现最好的模型是 XGBoost,我们将使用它来预测 测试集(包含在训练阶段未使用的新记录)的流失概率。

在导入测试集后,我们计算每条记录属于类别 1(流失客户)的模型预测概率,并绘制 ROC_AUC 分数:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

测试集 AUC 分数 | 图片由作者提供

让我们将预测的类别添加到原始数据中。默认情况下,所有预测概率 ≥ .5 的记录将被分配到类别 1。我们可以降低这个阈值,并比较结果的混淆矩阵:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

混淆矩阵 | 图片由作者提供

通过降低阈值,我们可以识别更多的流失客户(真正的正例和假阳性),但仍有相当数量的客户会流失但我们未能识别(假阴性),尽管我们的 xgboost 模型表现良好。

我们可以尝试找到更好的模型,但预测流失通常很困难。因此,除了使用简单的流失与非流失区分外,一个想法是利用我们预测的概率来定义一些不同的留存策略:

  • 对于预测概率大于 .75 的客户 = 高风险流失,我们可以设计一种“强力”的留存策略。由于我们预期的假阳性很少,因此我们可以更有信心地对这些客户进行投资。

  • 预测概率在 .5 和 .75 之间的客户 = 中等流失风险和“中等”留存策略。

  • 预测概率在 .25 和 .5 之间的客户 = 低风险流失和“弱”留存策略。

5. 下一步

在这个阶段,我们应该有一个能够为任何新数据分配“流失概率”的工作模型。

我们分析的下一步是进一步定义前面提到的保留策略。我们的策略应包括:(a)可能导致流失减少的行动;(b)如何衡量我们行动的成功;(c)最后,推广计划。

这里有一些解决这些问题的想法:

确定导致流失减少的行动:

让我们结合上面看到的特征重要性与我们的预测。例如,两个基于树的模型将“subs_B”列为树中使用最多的特征。我们需要深入了解流失和非流失客户在 subs_B 方面的情况。之前看到的群体分析将有助于这里:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

在训练数据上进行群体分析 | 作者图像

看起来高流失的客户有最低值(即 0 订阅,数据已经被转换,因此 x 轴值在这里不太易于解释),或者“subs_B”的数量过多。我们必须小心地得出“subs_B”和“is_churn”之间的因果结论,因为此分析并未证明任何因果关系。然而,我们可以测试一些假设:

  • 看起来客户对我们的 B 产品感到满意,将“B”产品交叉销售给仅拥有 A 产品的客户,是否有助于减少流失?

  • 我们还应该了解客户拥有这么多“B”订阅背后的业务原因。我们可以教育他们更有效地使用我们的产品,从而减少 B 订阅。

如何衡量我们行动的成功

一旦我们确定了一些建议的行动,我们可以规划我们的测量方法。

A/B 测试是一种非常常见的方式:

  • 我们从具有类似预测流失概率的客户中创建两个可比样本。一个样本将代表我们的处理组,并将暴露于我们的流失减少策略,另一个样本将代表我们的对照组,不会暴露于任何保留行动。

  • 我们希望证明我们的处理组的流失率显著低于对照组。

推广计划:

在建议保留行动时,我们不应忘记考虑其他背景因素。举几个例子:流失的担忧程度如何?(即是否有大量新客户以弥补流失?)解决问题的预算是多少?我们应该等多久才能看到结果?我们可以使用其他数据来改善模型吗?已经做了哪些工作?

这将帮助我们了解我们建议的可行性。

谢谢阅读!!

参考文献

[1] Carl S. Gold — “用数据对抗流失:客户保留的科学与策略”,2020 年

[2] Scikit-learn: Python 中的机器学习,Pedregosa ,JMLR 12,第 2825–2830 页,2011 年

构建生产就绪特征工程管道的框架

原文:towardsdatascience.com/a-framework-for-building-a-production-ready-feature-engineering-pipeline-f0b29609b20f

全栈 7 步 MLOps 框架

课程 1: 批量服务。特征存储。特征工程管道。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Paul Iusztin

·发表于 Towards Data Science ·13 分钟阅读·2023 年 4 月 28 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

Hassan PashaUnsplash 上的照片

本教程代表一个包含 7 课时的课程中的第 1 课,将逐步指导你如何设计、实施和部署 ML 系统,使用MLOps 优良实践。在课程中,你将构建一个生产就绪的模型,用于预测丹麦未来 24 小时内的能源消耗水平,涵盖多个消费类型。

完成本课程后,你将了解使用批量服务架构设计、编码和部署 ML 系统的所有基本知识。

本课程针对中级/高级机器学习工程师,旨在通过构建自己的端到端项目来提升技能。

如今,证书随处可见。构建先进的端到端项目并展示是获得专业工程师认可的最佳途径。

目录:

  • 课程介绍

  • 课程内容

  • 数据源

  • 课程 1: 批量服务。特征存储。特征工程管道。

  • 课程 1: 代码

  • 结论

  • 参考文献

介绍

在这 7 课时的课程结束时,你将学会如何:

  • 设计批量服务架构

  • 使用 Hopsworks 作为特征存储

  • 设计一个从 API 读取数据的特征工程管道

  • 构建带有超参数调优的训练管道

  • 使用 W&B 作为 ML 平台来跟踪你的实验、模型和元数据

  • 实现批量预测管道

  • 使用 Poetry 构建自己的 Python 包

  • 部署自己的私人 PyPi 服务器

  • 使用 Airflow 协调一切

  • 使用预测结果编码一个使用 FastAPI 和 Streamlit 的 Web 应用

  • 使用 Docker 容器化你的代码

  • 使用 Great Expectations 确保数据验证和完整性

  • 监控预测性能的变化

  • 将所有内容部署到 GCP

  • 使用 GitHub Actions 构建 CI/CD 流水线

如果这些听起来很多,不用担心,完成本课程后你将理解我之前说的一切。最重要的是,你将了解我为何使用这些工具以及它们如何作为一个系统协同工作。

如果你想最大化本课程的收益, 我建议你访问包含所有课程代码的 GitHub 仓库 。我设计了这些文章,使你在阅读课程的同时可以阅读并运行代码。

到课程结束时,你将学会如何实现下面的图示。如果有些内容对你来说不太明白,不用担心。我会详细解释一切。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

课程中你将构建的架构图 [图示来源于作者]。

为什么批量服务?

模型的部署主要有 4 种类型:

  • 批量服务

  • request-response

  • 流式处理

  • 嵌入式

批量服务是获取实际操作经验的绝佳起点,因为大多数 AI 应用程序从使用批量架构开始,然后转向请求响应或流式处理。

课程内容:

  1. 批量服务。特征存储。特征工程流水线。

  2. 训练流水线。ML 平台。超参数调整。

  3. 批量预测流水线。使用 Poetry 打包 Python 模块。

  4. 私人 PyPi 服务器。使用 Airflow 协调一切。

  5. 使用 GE 进行数据验证以确保质量和完整性。模型性能持续监控。

  6. 使用 FastAPI 和 Streamlit 消费和可视化你的模型预测。将一切容器化。

  7. 将所有 ML 组件部署到 GCP。使用 Github Actions 构建 CI/CD 流水线。

  8. [附加] ‘不完美’ ML 项目的幕后——教训与见解

数据源:

我们使用了一个开放 API,提供丹麦所有能源消费者类型的每小时能源消耗值。

他们提供了一个直观的界面,你可以轻松查询和可视化数据。你可以在这里访问数据 [1]。

数据有 4 个主要属性:

  • 小时 UTC: 观察到数据点时的 UTC 日期时间。

  • 价格区域: 丹麦被划分为两个价格区域:DK1 和 DK2——由大贝尔特海峡分隔。DK1 位于大贝尔特的西侧,DK2 位于东侧。

  • 消费者类型: 消费者类型为工业代码 DE35,由丹麦能源公司拥有和维护。

  • 总消耗: 总电力消耗(kWh)

注意: 观察值有 15 天的滞后!但对于我们的演示用例,这不是问题,因为我们可以模拟与实时相同的步骤。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

应用程序中的截图展示了我们如何预测区域 = 1 和消费者类型 = 212 的能源消耗 [作者提供的图片]。

数据点具有每小时的分辨率。例如:“2023–04–15 21:00Z”,“2023–04–15 20:00Z”,“2023–04–15 19:00Z”等等。

我们将把数据建模为多个时间序列。每个独特的价格区域消费者类型元组表示其独特的时间序列。

因此,我们将构建一个模型,独立预测每个时间序列接下来 24 小时的能源消耗。

查看下面的视频,更好地理解数据的样子 👇

课程与数据源概览 [作者提供的视频]。

第 1 课:批量服务。特征存储。特征工程管道。

第 1 课的目标

在第 1 课中,我们将关注图中蓝色突出显示的组件:“API”,“特征工程”和“特征存储”。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

最终架构的示意图,其中第 1 课的组件用蓝色突出显示 [作者提供的图片]。

具体来说,我们将构建一个 ETL 管道,从能源消耗 API 中提取数据,经过特征工程管道,该管道清洗和转换特征,并将特征加载到特征存储中,以便在系统中进一步使用。

如你所见,特征存储站在系统的核心位置。

理论概念与工具

批量服务: 在批量服务模式中,你可以离线准备数据、训练模型并进行预测。之后,你将预测结果存储在数据库中,客户端/应用程序将在后续使用这些预测结果。批量这个词来源于你可以同时处理多个样本,这在这种模式下通常是有效的。我们计算了所有预测结果并将其存储在 blob 存储/桶中。

如果我们将架构过于简化,仅反映批量架构的主要步骤,它将如下所示 👇

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

批处理架构 [作者提供的图片]。

批处理服务范式的最大缺点是你的预测几乎总是会滞后。例如,在我们的案例中,我们预测未来 24 小时的能耗,由于这种滞后,我们的预测可能会迟到 1 小时。

查看这篇文章以了解更多关于Google Cloud 建议的 标准化架构,这在几乎任何机器学习系统中都可以利用。

**特征存储:**特征存储位于任何机器学习系统的核心。使用特征存储,你可以轻松地存储和共享系统中的特征。你可以直观地将特征存储视为一个高级数据库,增加以下功能:

  • 数据版本控制和血缘

  • 数据验证

  • 创建数据集的能力

  • 保存训练/验证/测试拆分的能力

  • 两种存储类型:离线(便宜,但延迟高)和在线(更贵,但延迟低)。

  • 时间旅行:在给定时间窗口内轻松访问数据

  • 除了特征本身外,还保存特征转换

  • 数据监控等……

如果你想阅读关于特征存储的内容,请查看这篇文章 [3]。

我们选择了Hopsworks作为我们的特征存储,因为它是无服务器的,并提供了慷慨的免费计划,这足以创建本课程。

此外,Hopsworks 设计非常优秀,并提供了上述所有功能。如果你在寻找无服务器的特征存储,我推荐他们。

如果你还想在阅读本课程时运行代码,你需要去Hopswork,创建一个账户和项目。所有其他步骤将在课程的其余部分中解释。

我确保了本课程中的所有步骤都能保留在他们的免费计划中。因此,它不会花费你任何$$$。

**特征工程管道:**读取来自一个或多个数据源的数据,清洗、转换、验证数据并将其加载到特征存储中的代码片段(基本上是 ETL 管道)。

**Pandas vs. Spark:**我们在本课程中选择使用 Pandas 作为数据处理库,因为数据较小。因此,它可以轻松地适应计算机的内存,使用如 Spark 这样的分布式计算框架会使一切变得过于复杂。但在许多现实世界的场景中,当数据太大无法适应单台计算机(即大数据)时,你将使用 Spark(或其他分布式计算工具)来完成与本课程相同的步骤。查看这篇文章以了解 Spark 如何处理大数据预测流失。

课程 1: 代码

你可以在这里访问 GitHub 仓库。

注意: 所有安装说明都在仓库的 README 文件中。这里我们将直接进入代码部分。

Lesson 1 中的所有代码都位于 feature-pipeline文件夹下。

feature-pipeline文件夹下的文件结构如下:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

显示 feature-pipeline 文件夹结构的截图[作者提供]。

所有代码都位于feature_pipeline目录下(注意是"_“而不是”-")

准备凭证

在本课程中,你将使用一个单一服务作为你的特征存储:Hopsworks(在我们的用例中,它将是免费的)。

Hopsworks上创建一个账户和一个新项目(或使用默认项目)。注意不要将你的项目命名为“energy_consumption”,因为 Hopsworks 要求在其无服务器部署中项目名称唯一。

现在,你需要一个来自HopsworksAPI_KEY来登录并使用他们的 Python 模块访问云资源。

直接在你的 git 仓库中存储凭证是一个巨大的安全隐患。因此,你将使用**.env文件注入敏感信息。.env.default**是你必须配置的所有变量的示例。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

.env.default文件的截图[作者提供]。

从你的feature-pipeline目录中,在终端中运行:

cp .env.default .env

…并在FS_API_KEY变量下填写你新生成的 Hopsworks API KEY,在FS_PROJECT_NAME变量下填写你的 Hopsworks 项目名称(在我们的例子中,它是*“energy_consumption”*)。

查看下图,了解如何获取你自己的 Hopsworks API KEY 👇

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

进入你的Hopsworks项目。然后,在右上角点击你的用户名,再点击“账户设置”。最后,点击“新建 API KEY”,设置一个名称,选择所有作用域,点击“创建 API KEY”,复制 API KEY,你就完成了。你已经拥有了 Hopsworks API KEY[作者提供]。

然后,在feature_pipeline/settings.py文件中,我们将使用老牌的dotenv Python 包从**.env**文件中加载所有变量。

如果你想从当前目录以外的地方加载**.env文件,你可以在运行脚本时导出ML_PIPELINE_ROOT_DIR**环境变量。这是一个指向其余配置文件的"HOME"环境变量。

我们还将使用ML_PIPELINE_ROOT_DIR环境变量来指向一个单一目录,从中加载**.env**文件,并在所有进程中读写数据。

这是一个如何使用ML_PIPELINE_ROOT_DIR变量的示例:

export ML_PIPELINE_ROOT_DIR=/my/awesome/path python -m feature_pipeline.pipeline

使用以下代码,我们将通过 SETTINGS 字典访问代码中的所有凭据/敏感信息。

ETL 代码

feature_pipeline/pipeline.py 文件中,我们在**run()**方法下有管道的主要入口点。

如下所示,run 方法在高层次上遵循了 ETL 管道的确切步骤:

  1. extract.from_api() — 从能源消耗 API 提取数据。

  2. transform() — 转换提取的数据。

  3. validation.build_expectation_suite() — 构建数据验证和完整性套件。忽略这一步,因为我们将在第 6 课中重点讲解。

  4. load.to_feature_store() — 将数据加载到特征存储中。

请注意我如何使用日志记录器来反映系统的当前状态。当你的程序部署并全天候运行时,详细的日志记录对于调试系统至关重要。此外,总是使用 Python 的日志记录器而不是 print 方法,因为你可以选择不同的日志级别和输出流。

从高层次来看,这似乎很容易理解。让我们分别深入了解每个组件。

#1. 提取

在提取步骤中,我们请求给定窗口长度的数据。窗口的长度将等于days_export。窗口的第一个数据点是export_end_reference_datetime - days_delay - days_export,而窗口的最后一个数据点等于export_end_reference_datetime - days_delay

我们使用了参数days_delay来根据数据的延迟移动窗口。在我们的使用案例中,API 延迟为 15 天。

如上所述,该函数向 API 发出 HTTP GET 请求以请求数据。随后,响应被解码并加载到 Pandas DataFrame 中。

该函数返回 DataFrame 以及包含有关数据提取信息的附加元数据。

#2. 转换

转换步骤将原始 DataFrame 进行如下转换:

  • 将列重命名为 Python 标准化格式

  • 将列转换为其适合的类型

  • 将字符串列编码为整数

注意我们没有包括 EDA 步骤(例如,查找空值),因为我们的主要关注点是设计系统,而不是标准的数据科学过程。

#3. 数据验证

这是我们确保数据符合预期的地方。在我们的案例中,基于我们的 EDA 和转换,我们期望:

  • 数据中没有任何空值

  • 列的类型如预期

  • 值的范围如预期

有关此主题的更多内容,请参见第 6 课。

#4. 加载

这是我们将处理后的 DataFrame 加载到特征存储中的地方。

Hopsworks 有一系列很棒的教程,你可以在这里查看。但让我解释一下发生了什么:

  • 我们使用 API_KEY 登录到我们的 Hopsworks 项目中。

  • 我们获取特征存储的引用。

  • 我们获取或创建一个特征组,这基本上是一个数据库表,上面附加了特征存储的所有优点(更多信息请见这里 [5])。

  • 我们插入新的处理数据样本。

  • 我们为数据中的每个特征添加一组特征描述。

  • 我们指示 Hopsworks 为每个特征计算统计信息。

查看下面的视频,看看我刚才解释的内容在 Hopsworks 中是什么样的 👇

Hopsworks 概述[作者的视频]。

太棒了!现在我们有了一个 Python ETL 脚本,它从能耗 API 中提取数据,并将其加载到特征存储中。

创建特征视图与训练数据集

最后一步是创建一个特征视图和训练数据集,稍后将被引入训练管道中。

注意: 特征管道是唯一一个对特征存储进行写操作的过程。其他组件仅会查询特征存储中的各种数据集。通过这样做,我们可以安全地将特征存储作为唯一的真实来源,并在系统中共享特征。

feature_pipeline/feature_view.py文件中,我们有一个**create()**方法,它运行以下逻辑:

  1. 我们从特征管道中加载元数据。请记住,FE 元数据包含提取窗口的开始和结束时间、特征组的版本等。

  2. 我们登录 Hopswork 项目并创建对特征存储的引用。

  3. 我们删除所有旧的特征视图(通常,你不需要执行这一步。正好相反,你会希望保留旧的数据集。但是,Hopwork 的免费版本限制你只能使用 100 个特征视图。因此,我们想要保留我们的免费版本)。

  4. 我们根据给定版本获取特征组。

  5. 我们使用从加载的特征组中得到的所有数据创建一个特征视图。

  6. 我们仅使用给定的时间窗口创建训练数据集。

  7. 我们创建元数据的快照并保存到磁盘。

注意: 特征视图是一种将多个特征组组合成一个“数据集”的智能方法。它类似于 SQL 数据库中的 VIEW。你可以在这里 [4]了解更多关于特征视图的信息。

就这样。你建立了一个特征管道,它提取、转换并加载数据到特征存储中。基于特征存储中的数据,你创建了一个特征视图和训练数据集,这些将作为系统中的唯一真实来源。

注意: 你需要良好的软件工程原则和模式知识来构建健壮的特征工程管道。你可以在这里阅读一些实践示例

重要的设计决策

正如你所看到的,我们在这节课中实际上没有计算任何特征。我们只是清理、验证并确保数据已经准备好用于系统中。

但这被称为“特征管道”,为什么我们没有计算任何特征呢?

让我解释一下。

特征 = 原始数据 + 转换函数

如果我们将原始数据和转换函数存储在特征库中,而不是计算和存储特征,会怎样呢?

这样做我们可以获得以下好处:

  • 更快的实验,因为数据科学家不需要请求数据工程师计算新特征。他只需将新转换添加到特征库中。

  • 你可以节省大量存储空间。例如,与其保存 5 个从同一原始数据列计算出的特征,不如只保存原始数据列和 5 个转换,这样只使用原来的 1/5 的空间。

使用这种方法的缺点:

  • 你的特征将在运行时通过云端或推理管道进行计算。因此,你将在运行时增加额外的延迟。

但在使用批量服务范式时,延迟不是一个显著的限制。因此,我们确实这样做了!

查看第 2 课 以了解我们如何建模时间序列以预测接下来 24 小时的能源消耗。在 第 2 课 中,我们将展示如何将转换直接存储在特征库中。

结论

恭喜!你完成了第一课来自全栈 7 步 MLOps 框架课程。

你了解了如何设计批量服务架构以及开发自己的 ETL 管道,这些管道:

  • 从 HTTP API 中提取数据

  • 清理数据

  • 转换数据

  • 将数据加载到特征库中

  • 创建一个新的训练数据集版本

现在你已经理解了使用特征库的强大功能及其对任何 ML 系统的重要性,你可以在几周内而不是几个月内部署你的模型。

查看第 2 课 以了解有关训练管道、机器学习平台和超参数调整的信息。

此外你可以在这里访问 GitHub 仓库。

💡 我的目标是帮助机器学习工程师在设计和生产化 ML 系统方面提升水平。关注我 LinkedIn 或订阅我的 每周通讯 以获取更多见解!

🔥 如果你喜欢阅读类似的文章并希望支持我的写作,考虑 成为 Medium 会员。通过使用 我的推荐链接,你可以在没有额外费用的情况下支持我,同时享受 Medium 丰富故事的无限访问。

[## 通过我的推荐链接加入 Medium - 保罗·伊斯津]

🤖 加入以获取关于设计和构建生产级机器学习系统的独家内容 🚀 解锁完整访问权限…

pauliusztin.medium.com

参考文献

[1] 丹麦 API 中的 DE35 行业代码能源消耗丹麦能源数据服务

[2] Hopsworks 教程,Hopsworks 文档

[3] 吉姆·道林,特征存储与数据仓库(2020 年),KDnuggets

[4] Hopsworks 特征视图,Hopsworks 文档

[5] Hopsworks 特征组,Hopsworks 文档

《温和介绍:通过 LangChain 链接 LLMs、代理和工具》

原文:towardsdatascience.com/a-gentle-intro-to-chaining-llms-agents-and-utils-via-langchain-16cd385fca81?source=collection_archive---------0-----------------------#2023-04-21

#初学者的 LLM

理解代理、工具和提示的基础知识以及一些学习经验

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Varshita Sher 博士

·

关注 发布于 Towards Data Science ·20 分钟阅读·2023 年 4 月 21 日

受众:对于那些被庞大(但卓越)库感到不知所措的人…

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者使用DALL.E 2生成的图像

介绍

如果我说我掌握了整个 LangChain 库,那我就是在撒谎——实际上,我远远没有做到。但是,围绕它的热议足以让我摆脱写作 hiatus,去尝试一下 🚀。

最初的动机是看看 LangChain 在实践中添加了什么(在实际水平上),这使它不同于上个月我用openai包中的ChatCompletion.create()函数构建的聊天机器人。在这样做的过程中,我意识到需要先理解 LangChain 的基础构建块,然后再转向更复杂的部分。

这就是本文所做的事情。请注意,随着我对这个库的着迷和持续探索,将会有更多的部分出现。

让我们从理解 LangChain 的基本构建块 —— 即链条开始。如果你想跟进,请查看这个GitHub 仓库

LangChain 中的链条是什么?

链条是通过以逻辑方式连接一个或多个大型语言模型(LLMs)而得到的。 (虽然链条可以由除 LLMs 以外的实体构建,但现在让我们暂时使用这个定义以简化问题)。

OpenAI 是一种 LLM(提供者),你可以使用它,但还有其他像 Cohere、Bloom、Huggingface 等。

注意:几乎所有这些 LLM 提供者都需要您申请 API 密钥才能使用它们。所以请确保在继续阅读本博客的其余部分之前,您已经这样做了。例如:

import os
os.environ["OPENAI_API_KEY"] = "..."

P.S. 我将在本教程中使用 OpenAI,因为我有一个一个月后过期的积分密钥,但请随意替换为任何其他 LLM。无论如何,这里涵盖的概念都将是有用的。

链条可以简单(例如通用)或专业化(例如实用)。

  1. 通用 — 单个 LLM 是最简单的链条。它接受一个输入提示和 LLM 的名称,然后使用 LLM 进行文本生成(即输出提示的结果)。这里是一个例子:

让我们构建一个基本的链条 —— 创建一个提示并获取预测结果

在 Lanchain 中,使用PromptTemplate创建提示(Prompt)有点花哨,但这可能是因为根据用例的不同,可以有多种不同的方式来创建提示(我们将涵盖AIMessagePromptTemplate等等)。

HumanMessagePromptTemplate等等将在下一篇博客文章中涵盖。现在先看一个简单的例子:

from langchain.prompts import PromptTemplate

prompt = PromptTemplate(
    input_variables=["product"],
    template="What is a good name for a company that makes {product}?",
)

print(prompt.format(product="podcast player"))

# OUTPUT
# What is a good name for a company that makes podcast player?

注意:如果您需要多个 *input_variables*,例如:* *input_variables=["product", "audience"]* 用于模板,例如 *“一个公司的好名字,为{product}制作{audience}”*,则需要执行* print(prompt.format(product="podcast player", audience="children”) *以获取更新后的提示。

一旦您建立了一个提示,我们就可以调用所需的 LLM。为此,我们创建一个LLMChain实例(在我们的例子中,我们使用OpenAI的大型语言模型text-davinci-003)。要获取预测结果(即 AI 生成的文本),我们使用run函数和product的名称。

from langchain.llms import OpenAI
from langchain.chains import LLMChain

llm = OpenAI(
          model_name="text-davinci-003", # default model
          temperature=0.9) #temperature dictates how whacky the output should be
llmchain = LLMChain(llm=llm, prompt=prompt)
llmchain.run("podcast player")

# OUTPUT
# PodConneXion

如果你有多个输入变量,那么就不能使用run。相反,你需要将所有变量作为dict传递。例如,llmchain({"product": "podcast player", "audience": "children"})

注意 1:根据 OpenAI*davinci* 文本生成模型的费用是其聊天对应模型的 10 倍,即 *gpt-3.5-turbo*,因此我尝试从文本模型切换到聊天模型(即从 *OpenAI* *ChatOpenAI*),结果差别不大。

注意 2:你可能会看到一些教程使用 *OpenAIChat*而不是 *ChatOpenAI*。前者已经 弃用 并且将不再受支持,我们应使用 *ChatOpenAI*

from langchain.chat_models import ChatOpenAI

chatopenai = ChatOpenAI(
                model_name="gpt-3.5-turbo")
llmchain_chat = LLMChain(llm=chatopenai, prompt=prompt)
llmchain_chat.run("podcast player")

# OUTPUT
# PodcastStream

这部分关于简单链的介绍到此为止。需要注意的是,我们很少将通用链作为独立链使用。更常见的是它们作为实用链的构建块(正如我们接下来会看到的)。

2. 实用工具 — 这些是专门的链,由许多 LLM 组成,以帮助解决特定任务。例如,LangChain 支持一些端到端的链(如[AnalyzeDocumentChain](https://python.langchain.com/docs/use_cases/question_answering/how_to/analyze_document) 用于总结、问答等)和一些特定的链(如[GraphQnAChain](https://python.langchain.com/en/latest/modules/chains/index_examples/graph_qa.html#querying-the-graph) 用于创建、查询和保存图形)。在本教程中,我们将深入探讨一个名为 PalChain 的特定链。

PAL 代表 程序辅助语言模型PALChain 读取复杂的数学问题(用自然语言描述)并生成程序(用于解决数学问题)作为中间推理步骤,但将解决步骤委托给如 Python 解释器等运行时。

为了确认这一点,我们可以检查基础代码中的 _call() 这里。在底层,我们可以看到这个链:

附注:检查 *_call()* *base.py* 中是一个好习惯,可以查看 LangChain 中的任何链如何在底层工作。

from langchain.chains import PALChain
palchain = PALChain.from_math_prompt(llm=llm, verbose=True)
palchain.run("If my age is half of my dad's age and he is going to be 60 next year, what is my current age?")

# OUTPUT
# > Entering new PALChain chain...
# def solution():
#    """If my age is half of my dad's age and he is going to be 60 next year, what is my current age?"""
#    dad_age_next_year = 60
#    dad_age_now = dad_age_next_year - 1
#    my_age_now = dad_age_now / 2
#    result = my_age_now
#    return result
#
# > Finished chain.
# '29.5'

注意 1:如果你不需要看到中间步骤,*verbose* 可以设置为 *False*。*

现在,有些人可能会想 — 但提示呢?我们肯定没有像我们建立的通用 *llmchain* 那样传递它。 实际上,当使用.from_math_prompt()时,它会自动加载。您可以使用palchain.prompt.template检查默认提示,或者直接查看提示文件这里

print(palchain.prompt.template)
# OUTPUT
# 'Q: Olivia has $23\. She bought five bagels for $3 each. How much money does she have left?\n\n# solution in Python:\n\n\ndef solution():\n    """Olivia has $23\. She bought five bagels for $3 each. How much money does she have left?"""\n    money_initial = 23\n    bagels = 5\n    bagel_cost = 3\n    money_spent = bagels * bagel_cost\n    money_left = money_initial - money_spent\n    result = money_left\n    return result\n\n\n\n\n\nQ: Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday, he lost 2 more. How many golf balls did he have at the end of wednesday?\n\n# solution in Python:\n\n\ndef solution():\n    """Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday, he lost 2 more. How many golf balls did he have at the end of wednesday?"""\n    golf_balls_initial = 58\n    golf_balls_lost_tuesday = 23\n    golf_balls_lost_wednesday = 2\n    golf_balls_left = golf_balls_initial - golf_balls_lost_tuesday - golf_balls_lost_wednesday\n    result = golf_balls_left\n    return result\n\n\n\n\n\nQ: There were nine computers in the server room. Five more computers were installed each day, from monday to thursday. How many computers are now in the server room?\n\n# solution in Python:\n\n\ndef solution():\n    """There were nine computers in the server room. Five more computers were installed each day, from monday to thursday. How many computers are now in the server room?"""\n    computers_initial = 9\n    computers_per_day = 5\n    num_days = 4  # 4 days between monday and thursday\n    computers_added = computers_per_day * num_days\n    computers_total = computers_initial + computers_added\n    result = computers_total\n    return result\n\n\n\n\n\nQ: Shawn has five toys. For Christmas, he got two toys each from his mom and dad. How many toys does he have now?\n\n# solution in Python:\n\n\ndef solution():\n    """Shawn has five toys. For Christmas, he got two toys each from his mom and dad. How many toys does he have now?"""\n    toys_initial = 5\n    mom_toys = 2\n    dad_toys = 2\n    total_received = mom_toys + dad_toys\n    total_toys = toys_initial + total_received\n    result = total_toys\n    return result\n\n\n\n\n\nQ: Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12 lollipops. How many lollipops did Jason give to Denny?\n\n# solution in Python:\n\n\ndef solution():\n    """Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12 lollipops. How many lollipops did Jason give to Denny?"""\n    jason_lollipops_initial = 20\n    jason_lollipops_after = 12\n    denny_lollipops = jason_lollipops_initial - jason_lollipops_after\n    result = denny_lollipops\n    return result\n\n\n\n\n\nQ: Leah had 32 chocolates and her sister had 42\. If they ate 35, how many pieces do they have left in total?\n\n# solution in Python:\n\n\ndef solution():\n    """Leah had 32 chocolates and her sister had 42\. If they ate 35, how many pieces do they have left in total?"""\n    leah_chocolates = 32\n    sister_chocolates = 42\n    total_chocolates = leah_chocolates + sister_chocolates\n    chocolates_eaten = 35\n    chocolates_left = total_chocolates - chocolates_eaten\n    result = chocolates_left\n    return result\n\n\n\n\n\nQ: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?\n\n# solution in Python:\n\n\ndef solution():\n    """If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?"""\n    cars_initial = 3\n    cars_arrived = 2\n    total_cars = cars_initial + cars_arrived\n    result = total_cars\n    return result\n\n\n\n\n\nQ: There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?\n\n# solution in Python:\n\n\ndef solution():\n    """There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?"""\n    trees_initial = 15\n    trees_after = 21\n    trees_added = trees_after - trees_initial\n    result = trees_added\n    return result\n\n\n\n\n\nQ: {question}\n\n# solution in Python:\n\n\n'

注意:大多数实用链条的提示作为库的一部分是预定义的(在这里查看 这里)。它们有时非常详细(即:有很多令牌),因此在成本和 LLM 响应质量之间肯定存在权衡。

是否存在不需要 LLM 和提示的链条?

尽管 PalChain 需要一个 LLM(以及相应的提示)来解析用户用自然语言编写的问题,但在 LangChain 中有一些链条不需要。这些主要是预处理提示的转换链条,例如删除额外的空格,然后将其输入 LLM。你可以在另一个例子中看到 这里

我们能进入精彩部分并开始创建链条吗?

当然可以!我们已经有了开始逻辑地将 LLM 连接在一起的基本构建块。为此,我们将使用SimpleSequentialChain

文档中有一些很好的例子,例如,你可以在这里看到如何组合两个链条,其中链条#1 用于清理提示(删除额外空格,缩短提示等),而链条#2 用于使用这个干净的提示调用 LLM。这里还有另一个例子,其中链条#1 用于为一部戏剧生成简介,而链条#2 则用于基于此简介撰写评论。

虽然这些都是很好的例子,但我想专注于其他事情。如果你还记得,我提到链条可以由除了 LLM 以外的实体组成。更具体地说,我对将代理和 LLM 组合在一起很感兴趣。但首先,什么是代理?

使用代理动态调用 LLM

对于解释代理的作用与其是什么,将会更加容易。

假设我们想知道明天的天气预报。如果我们使用简单的 ChatGPT API 并给它一个提示Show me the weather for tomorrow in London,它不会知道答案,因为它无法访问实时数据。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

如果我们能有一个安排,利用 LLM 理解我们的查询(即提示),然后代表我们调用天气 API 来获取所需数据,那不是很有用吗?这正是代理所做的(当然还有其他事情)。

代理可以访问 LLM 和一套工具,例如 Google 搜索、Python REPL、数学计算器、天气 API 等。

LangChain 支持很多代理——完整列表请见这里,但坦率说,我在教程和 YouTube 视频中最常见的代理是 zero-shot-react-description。这个代理使用了ReAct(Reason + Act)框架,根据输入查询从工具列表中选择最合适的工具。

P.S.: 这里 有一篇深入探讨 ReAct 框架的好文章。

让我们使用 initialize_agent 初始化一个代理,并传递它所需的 toolsLLM。代理可以使用的工具清单可以在这里找到。对于我们的例子,我们使用了上面提到的同一个数学解决工具,叫做 pal-math。这个工具在初始化时需要一个 LLM,因此我们传递给它之前相同的 OpenAI LLM 实例。

from langchain.agents import initialize_agent
from langchain.agents import AgentType
from langchain.agents import load_tools

llm = OpenAI(temperature=0)
tools = load_tools(["pal-math"], llm=llm)

agent = initialize_agent(tools,
                         llm,
                         agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
                         verbose=True)

让我们在上述相同的例子上测试一下:

agent.run("If my age is half of my dad's age and he is going to be 60 next year, what is my current age?")

# OUTPUT
# > Entering new AgentExecutor chain...
# I need to figure out my dad's current age and then divide it by two.
# Action: PAL-MATH
# Action Input: What is my dad's current age if he is going to be 60 next year?
# Observation: 59
# Thought: I now know my dad's current age, so I can divide it by two to get my age.
# Action: Divide 59 by 2
# Action Input: 59/2
# Observation: Divide 59 by 2 is not a valid tool, try another one.
# Thought: I can use PAL-MATH to divide 59 by 2.
# Action: PAL-MATH
# Action Input: Divide 59 by 2
# Observation: 29.5
# Thought: I now know the final answer.
# Final Answer: My current age is 29.5 years old.

# > Finished chain.
# 'My current age is 29.5 years old.'

注意 1:在每一步,你会注意到代理做了三件事之一——它要么有一个 *observation*,要么有一个 *thought*,要么采取一个 *action*。这主要是由于 ReAct 框架和代理使用的相关提示:

print(agent.agent.llm_chain.prompt.template)
# OUTPUT
# Answer the following questions as best you can. You have access to the following tools:
# PAL-MATH: A language model that is really good at solving complex word math problems. Input should be a fully worded hard word math problem.

# Use the following format:

# Question: the input question you must answer
# Thought: you should always think about what to do
# Action: the action to take, should be one of [PAL-MATH]
# Action Input: the input to the action
# Observation: the result of the action
# ... (this Thought/Action/Action Input/Observation can repeat N times)
# Thought: I now know the final answer
# Final Answer: the final answer to the original input question
# Begin!
# Question: {input}
# Thought:{agent_scratchpad}

注意 2:你可能会想,为什么要让代理做 LLM 可以做的事情。一些应用不仅需要一个预定的 LLM/其他工具调用链,可能还需要一个取决于用户输入的未知链 [来源]。在这些类型的链中,有一个“代理”,可以访问一套工具。

例如,* 这是 一个代理的示例,它可以根据问题是指文档 A 还是文档 B,获取正确的文档(从向量存储中)。

为了有趣,我尝试使输入问题更复杂(用 Demi Moore 的年龄作为 Dad 实际年龄的占位符)。

agent.run("My age is half of my dad's age. Next year he is going to be same age as Demi Moore. What is my current age?")

不幸的是,答案有些偏差,因为代理没有使用最新的 Demi Moore 年龄(由于 OpenAI 模型的训练数据截至到 2020 年)。这可以通过包含另一个工具轻松修复——

tools = load_tools([“pal-math”, **"serpapi"**], llm=llm)serpapi 对于回答当前事件的问题非常有用。

注意: 添加尽可能多的相关工具对用户查询是很重要的。使用单一工具的问题在于,即使它不适用于特定的观察/行动步骤,代理也会继续尝试使用相同的工具。

这是另一个你可以使用的工具示例——podcast-api。你需要获取你自己的 API 密钥并将其插入下面的代码中。

 tools = load_tools(["podcast-api"], llm=llm, listen_api_key="...")
agent = initialize_agent(tools,
                         llm,
                         agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
                         verbose=True)

agent.run("Show me episodes for money saving tips.")

# OUTPUT
# > Entering new AgentExecutor chain...
# I should search for podcasts or episodes related to money saving
# Action: Podcast API
# Action Input: Money saving tips
# Observation:  The API call returned 3 podcasts related to money saving tips: The Money Nerds, The Rachel Cruze Show, and The Martin Lewis Podcast. These podcasts offer valuable money saving tips and advice to help people take control of their finances and create a life they love.
# Thought: I now have some options to choose from 
# Final Answer: The Money Nerds, The Rachel Cruze Show, and The Martin Lewis Podcast are great podcast options for money saving tips.

# > Finished chain.

# 'The Money Nerds, The Rachel Cruze Show, and The Martin Lewis Podcast are great podcast options for money saving tips.'

注意 1: 有一个 已知错误 *,在使用这个 API 时你可能会看到,*openai.error.InvalidRequestError: This model’s maximum context length is 4097 tokens, however you requested XXX tokens (XX in your prompt; XX for the completion). Please reduce your prompt; or completion length.* 当 API 返回的响应可能过大时会发生这种情况。为了解决这个问题,文档建议返回更少的搜索结果,例如,通过将问题更新为 "Show me episodes for money saving tips, return only 1 result"

注意 2: 在使用这个工具时,我注意到了一些不一致的地方。响应第一次生成时并不总是完整的,例如,以下是两次连续运行的输入和响应:

输入: “提高法语水平的播客”

回应 1: “学习法语的最佳播客是评价分数最高的那个。”

回应 2: “学习法语的最佳播客是‘FrenchPod101’。”

在底层,这个工具首先使用 LLMChain 来构建 API URL,根据我们的输入指令(类似于 [listen-api.listennotes.com/api/v2/search?q=french&type=podcast&page_size=3](https://listen-api.listennotes.com/api/v2/search?q=french&type=podcast&page_size=3%29))和进行 API 调用。接收到响应后,它使用另一个 LLMChain 来总结响应,以获得对我们原始问题的回答。你可以在这里查看两个 LLMchains 的提示,它们详细描述了这个过程。

我倾向于猜测上述不一致的结果是由于总结步骤造成的,因为我已经通过 Postman 单独调试并测试了由 LLMChain#1 创建的 API URL,并且得到了正确的响应。为了进一步确认我的疑虑,我还对总结链进行了压力测试,作为一个独立链使用了一个空的 API URL,希望它能抛出一个错误,但得到了*“发现了‘投资’播客,总共有 3 个结果。”* 🤷‍♀ 我很好奇其他人是否在使用这个工具时比我更幸运!

用例 2:结合链创建一个适合年龄的礼物生成器

让我们充分利用代理和顺序链的知识,创建我们自己的顺序链。我们将结合:

  • 链 #1 — 我们刚创建的agent,能够解决数学中的年龄问题

  • 链 #2 — 一个 LLM,它接受一个人的年龄并建议一个适合他们的礼物。

# Chain1 - solve math problem, get the age
chain_one = agent

# Chain2 - suggest age-appropriate gift
template = """You are a gift recommender. Given a person's age,\n
 it is your job to suggest an appropriate gift for them.

Person Age:
{age}
Suggest gift:"""
prompt_template = PromptTemplate(input_variables=["age"], template=template)
chain_two = LLMChain(llm=llm, prompt=prompt_template) 

现在我们已经准备好了两个链,我们可以使用SimpleSequentialChain将它们结合起来。

from langchain.chains import SimpleSequentialChain

overall_chain = SimpleSequentialChain(
                  chains=[chain_one, chain_two],
                  verbose=True)

需要注意几点:

  • 我们不需要为SimpleSequentialChain明确传递input_variablesoutput_variables,因为其基本假设是链 1 的输出作为链 2 的输入。

最后,我们可以用之前的数学问题来运行它:

question = "If my age is half of my dad's age and he is going to be 60 next year, what is my current age?"
overall_chain.run(question)

# OUTPUT
# > Entering new SimpleSequentialChain chain...

# > Entering new AgentExecutor chain...
# I need to figure out my dad's current age and then divide it by two.
# Action: PAL-MATH
# Action Input: What is my dad's current age if he is going to be 60 next year?
# Observation: 59
# Thought: I now know my dad's current age, so I can divide it by two to get my age.
# Action: Divide 59 by 2
# Action Input: 59/2
# Observation: Divide 59 by 2 is not a valid tool, try another one.
# Thought: I need to use PAL-MATH to divide 59 by 2.
# Action: PAL-MATH
# Action Input: Divide 59 by 2
# Observation: 29.5
# Thought: I now know the final answer.
# Final Answer: My current age is 29.5 years old.

# > Finished chain.
# My current age is 29.5 years old.

# Given your age, a great gift would be something that you can use and enjoy now like a nice bottle of wine, a luxury watch, a cookbook, or a gift card to a favorite store or restaurant. Or, you could get something that will last for years like a nice piece of jewelry or a quality leather wallet.

# > Finished chain.
# '\nGiven your age, a great gift would be something that you can use and enjoy now like a nice bottle of wine, a luxury watch, a cookbook, or a gift card to a favorite store or restaurant. Or, you could get something that will last for years like a nice piece of jewelry or a quality leather wallet

有时你可能需要将一些额外的上下文传递给第二个链,而不仅仅是从第一个链接收的内容。例如,我想根据第一个链返回的年龄为礼物设定预算。我们可以使用SimpleMemory来实现。

首先,让我们更新chain_two的提示,并在input_variables中传递一个名为budget的第二个变量。

template = """You are a gift recommender. Given a person's age,\n
 it is your job to suggest an appropriate gift for them. If age is under 10,\n
 the gift should cost no more than {budget} otherwise it should cost atleast 10 times {budget}.

Person Age:
{output}
Suggest gift:"""
prompt_template = PromptTemplate(input_variables=["output", "budget"], template=template)
chain_two = LLMChain(llm=llm, prompt=prompt_template)

如果你比较我们为SimpleSequentialChain准备的template与上述的模板,你会注意到我还将第一个输入的变量名从age更新为output。这是一个关键步骤,如果失败,将在链验证时引发错误 — *缺少必需的输入键:{age},只有 {input, output, budget}*

这是因为链中的第一个实体(即agent)的输出将作为第二个实体(即chain_two)的输入,因此变量名必须匹配**。** 检查agent的输出键时,我们发现输出变量叫做output,因此进行了更新。

print(agent.agent.llm_chain.output_keys)

# OUTPUT
["output"]

接下来,让我们更新我们正在制作的链的类型。我们不能再使用SimpleSequentialChain,因为它仅适用于单输入单输出的情况。由于chain_two现在需要两个input_variables,我们需要使用SequentialChain,它专门处理多个输入和输出。

overall_chain = SequentialChain(
                input_variables=["input"],
                memory=SimpleMemory(memories={"budget": "100 GBP"}),
                chains=[agent, chain_two],
                verbose=True)

需要注意几点:

  • SimpleSequentialChain不同,对于SequentialChain,传递input_variables参数是强制性的。它是一个包含链中第一个实体(即我们案例中的agent)期望的输入变量名称的列表。

    现在,你们中的一些人可能想知道如何知道agent将要使用的输入提示中使用的确切名称。我们确实没有为这个agent(如我们为chain_two所做的那样)编写过提示!事实上,通过检查llm_chain的提示模板,找出它其实非常简单。

print(agent.agent.llm_chain.prompt.template)

# OUTPUT
#Answer the following questions as best you can. You have access to the following tools:

#PAL-MATH: A language model that is really good at solving complex word math problems. Input should be a fully worded hard word math problem.

#Use the following format:

#Question: the input question you must answer
#Thought: you should always think about what to do
#Action: the action to take, should be one of [PAL-MATH]
#Action Input: the input to the action
#Observation: the result of the action
#... (this Thought/Action/Action Input/Observation can repeat N times)
#Thought: I now know the final answer
#Final Answer: the final answer to the original input question

#Begin!

#Question: {input}
#Thought:{agent_scratchpad}

正如您可以在提示的最后看到的那样,最终用户提出的问题存储在一个名为input的输入变量中。如果因某种原因您必须在提示中操纵这个名称,请确保在创建SequentialChain时同时更新input_variables

最后,您可以在不查看整个提示的情况下找到相同的信息:

print(agent.agent.llm_chain.prompt.input_variables)

# OUTPUT
# ['input', 'agent_scratchpad']
  • [SimpleMemory](https://github.com/hwchase17/langchain/blob/master/langchain/memory/simple.py#L6) 是一种存储上下文或其他信息片段的简便方法,这些信息在提示之间不应更改。它在初始化时需要一个参数 — memories。您可以以dict形式传递元素给它。例如,SimpleMemory(memories={“budget”: “100 GBP”})

最后,让我们用与之前相同的提示运行新链。您会注意到,最终输出包括一些奢侈礼品推荐,例如周末度假,与我们更新的提示中的更高预算相匹配。

overall_chain.run("If my age is half of my dad's age and he is going to be 60 next year, what is my current age?")

# OUTPUT
#> Entering new SequentialChain chain...

#> Entering new AgentExecutor chain...
# I need to figure out my dad's current age and then divide it by two.
#Action: PAL-MATH
#Action Input: What is my dad's current age if he is going to be 60 next year?
#Observation: 59
#Thought: I now know my dad's current age, so I can divide it by two to get my age.
#Action: Divide 59 by 2
#Action Input: 59/2
#Observation: Divide 59 by 2 is not a valid tool, try another one.
#Thought: I can use PAL-MATH to divide 59 by 2.
#Action: PAL-MATH
#Action Input: Divide 59 by 2
#Observation: 29.5
#Thought: I now know the final answer.
#Final Answer: My current age is 29.5 years old.

#> Finished chain.

# For someone of your age, a good gift would be something that is both practical and meaningful. Consider something like a nice watch, a piece of jewelry, a nice leather bag, or a gift card to a favorite store or restaurant.\nIf you have a larger budget, you could consider something like a weekend getaway, a spa package, or a special experience.'}

#> Finished chain.
For someone of your age, a good gift would be something that is both practical and meaningful. Consider something like a nice watch, a piece of jewelry, a nice leather bag, or a gift card to a favorite store or restaurant.\nIf you have a larger budget, you could consider something like a weekend getaway, a spa package, or a special experience.'}

结论

希望通过本文分享的学习内容能让您更轻松地深入了解这个库。本文只是皮毛,还有很多内容可以探讨。例如,如何在自己的数据集上构建问答聊天机器人,以及如何优化这些聊天机器人的记忆,以便您可以选择性地/总结性地发送对话,而不是将所有以前的聊天历史作为提示的一部分发送出去。

如往常一样,如果有更简单的方法来执行/解释本文中提到的一些内容,请告诉我。总的来说,避免未经请求的破坏性/垃圾/敌对的评论!

直到下次见 ✨

我喜欢撰写逐步初学者指南、如何教程、解码 ML/AI 术语等。如果您希望全面访问我的所有文章(以及 Medium 上的其他文章),可以使用 我的链接在这里注册

[## 逐步指南:在数据科学面试中解释您的 ML 项目。

并附带一个示例脚本,让您可以悄悄展示您的技术技能!

时间序列建模使用 Scikit、Pandas 和 Numpy [## 时间序列建模使用 Scikit、Pandas 和 Numpy

直观地使用季节性来提高模型准确性。

数据科学家实用的 GitHub Actions 介绍 [## 使用 GitHub Actions 进行动手实践的介绍

学习如何使用 Weights & Biases 自动化实验跟踪、单元测试、工件创建以及更多内容…

数据科学家实用的 GitHub Actions 介绍 [## 使用少量点击部署端到端深度学习项目:第二部分

将模型从 Jupyter notebook 转移到 Flask 应用程序,使用 Postman 测试 API 端点,并进行 Heroku 部署

将端到端深度学习项目部署到 Heroku

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值