贝叶斯深度学习的温和介绍
原文:
towardsdatascience.com/a-gentle-introduction-to-bayesian-deep-learning-d298c7243fd6
欢迎来到激动人心的概率编程世界!这篇文章是对该领域的温和介绍,你只需对深度学习和贝叶斯统计有基本了解。
·发表于 Towards Data Science ·阅读时间 8 分钟·2023 年 7 月 26 日
–
到文章结束时,你应该对这个领域、它的应用以及它与更传统的深度学习方法的不同之处有一个基本的了解。
如果你像我一样听说过贝叶斯深度学习,并且猜测它涉及贝叶斯统计,但你不确切知道它是如何使用的,那么你来对地方了。
传统深度学习的局限性
传统深度学习的主要局限性之一是,即使它们是非常强大的工具,它们也不能提供不确定性的度量。
Chat GPT 可能会以明显的自信说出错误的信息。分类器的输出概率通常未经过校准。
不确定性估计是决策过程中的一个关键方面, 尤其是在医疗保健、自动驾驶汽车等领域。我们希望模型能够估计在对脑癌的分类非常不确定时,并在这种情况下需要进一步的医疗专家诊断。同样,我们希望自主汽车能够在识别到新环境时减速。
为了说明神经网络在估计风险时可能有多糟糕,我们来看看一个非常简单的带有 softmax 层的分类器神经网络。
softmax 的名字很容易理解,它是一个软最大函数,这意味着它是一个“更平滑”的最大函数。这是因为如果我们选择了一个“硬”的最大函数,只取概率最高的类别,我们将对所有其他类别的梯度为零。
使用 softmax 时,一个类别的概率可以接近 1,但永远不可能正好是 1。由于所有类别的概率总和是 1,因此仍有一些梯度流向其他类别。
硬最大值与软最大值,图像来源:作者
然而,softmax 函数也存在一个问题。它输出的概率校准不佳。在应用 softmax 函数之前值的微小变化会被指数函数压缩,从而导致输出概率变化极小。
这通常会导致过度自信,模型在面对不确定性时仍然给出某些类别的高概率,这是 softmax 函数‘max’特性固有的特征。
比较传统的神经网络(NN)与贝叶斯神经网络(BNN)可以突显不确定性估计的重要性。BNN 在遇到训练数据中的熟悉分布时,确定性较高,但当我们远离已知分布时,不确定性增加,从而提供更现实的估计。
下面是不确定性估计的一个示例:
传统神经网络与贝叶斯神经网络,图像来源:作者
你可以看到,当我们接近训练中观察到的分布时,模型非常确定,但当我们远离已知分布时,不确定性增加。
贝叶斯统计的简短回顾
在贝叶斯统计中有一个中心定理:贝叶斯定理。
贝叶斯定理,图像来源:作者
-
先验是我们在任何观察之前认为最可能的 theta 分布。例如,对于抛硬币,我们可以假设得到正面概率是围绕 p = 0.5 的高斯分布。
-
如果我们想尽可能减少归纳偏置,我们也可以说 p 在[0,1]之间是均匀的。
-
似然性是给定参数 theta 的情况下,我们得到观察 X,Y 的可能性。
-
边际似然性是对所有可能的 theta 积分后的似然性。它被称为“边际”的原因是我们通过对所有概率进行平均来边际化 theta。
在贝叶斯统计中,关键的概念是从先验开始,它是你对参数可能值的最佳猜测(它是一个分布)。通过你所做的观察,你调整你的猜测,并获得一个后验分布。
请注意,先验和后验不是 theta 的点估计,而是概率分布。
以此为例:
图像来源:作者
在这张图中,你可以看到先验向右移动,但似然性将我们的先验重新调整到左侧,而后验位于两者之间。
贝叶斯深度学习简介
贝叶斯深度学习是一种结合了两种强大数学理论的方法:贝叶斯统计和深度学习。
与传统深度学习的区别在于对模型权重的处理:
在传统深度学习中,我们从头开始训练模型,随机初始化一组权重,并训练模型直到其收敛到一组新的参数。我们学习的是单一的一组权重。
相反,贝叶斯深度学习采用了更为动态的方法。我们从对权重的先验信念开始,通常假设它们遵循正态分布。当我们将模型暴露于数据时,我们调整这一信念,从而更新权重的后验分布。本质上,我们学习的是权重的概率分布,而不是单一的一组权重。
在推断过程中,我们对所有模型的预测进行平均,并根据后验概率加权它们的贡献。这意味着,如果一组权重的概率很高,则其对应的预测将获得更多的权重。
让我们将这些正式化:
推断,图片来自作者
贝叶斯深度学习中的推断通过使用后验分布对所有可能的θ(权重)值进行积分。
我们还可以看到,在贝叶斯统计中,积分无处不在。这实际上是贝叶斯框架的主要限制。这些积分往往是不可解的(我们并不总是知道后验的原始函数)。因此,我们必须进行非常计算密集的近似。
贝叶斯深度学习的优势
优势 1:不确定性估计
- 可以说,贝叶斯深度学习最显著的好处是其不确定性估计的能力。在包括医疗保健、自动驾驶、语言模型、计算机视觉和定量金融等许多领域中,量化不确定性的能力对做出明智的决策和管理风险至关重要。
优势 2:提高训练效率
- 与不确定性估计的概念密切相关的是提高的训练效率。由于贝叶斯模型能够意识到自身的不确定性,它们可以优先从那些不确定性——即学习潜力——最高的数据点中学习。这种方法被称为主动学习,能够实现令人印象深刻的有效和高效训练。
主动学习效果的演示,图片来自作者
如下图所示,使用主动学习的贝叶斯神经网络仅用 1,000 张训练图像就达到了 98%的准确率。相比之下,不利用不确定性估计的模型往往学习速度较慢。
优势 3:归纳偏差
贝叶斯深度学习的另一个优点是通过先验有效利用归纳偏置。先验允许我们编码对模型参数的初始信念或假设,这在存在领域知识的场景中尤为有用。
考虑生成式 AI,其思想是创建与训练数据类似的新数据(例如医学图像)。例如,如果你正在生成脑部图像,并且你已经知道脑部的一般布局——白质在内,灰质在外——这些知识可以包含在你的先验中。这意味着你可以给图像中心的白质分配更高的概率,而将灰质分配到边缘。
本质上,贝叶斯深度学习不仅使模型能够从数据中学习,还使其能够从知识点开始学习,而不是从头开始。这使其成为广泛应用的强大工具。
白质和灰质,图片由作者提供
贝叶斯深度学习的局限性
贝叶斯深度学习看起来非常不可思议!那么为什么这个领域会被低估呢?确实,我们经常谈论生成式 AI、Chat GPT、SAM 或更传统的神经网络,但我们几乎从未听说过贝叶斯深度学习,这是为什么呢?
限制 1: 贝叶斯深度学习非常慢
理解贝叶斯深度学习的关键在于我们“平均”模型的预测,而每当有平均时,就会有对参数集的积分。
但是计算积分通常是不可解的,这意味着没有一个封闭或显式的形式可以使积分计算快速。因此,我们不能直接计算它,我们必须通过采样一些点来近似积分,这使得推断非常慢。
想象一下,对于每个数据点 x,我们必须平均 10,000 个模型的预测,并且每个预测可能需要 1 秒来运行,这样我们最终得到的模型就是无法扩展到大量数据。
在大多数业务场景中,我们需要快速且可扩展的推断,这就是为什么贝叶斯深度学习不那么受欢迎。
限制 2: 近似误差
在贝叶斯深度学习中,通常需要使用近似方法,如变分推断,来计算权重的后验分布。这些近似可能导致最终模型中的错误。近似的质量取决于变分家族和散度度量的选择,这可能很难选择和调整。
限制 3: 模型复杂性和可解释性的增加
虽然贝叶斯方法提供了改进的不确定性度量,但这也增加了模型的复杂性。BNNs 可能难以解释,因为我们现在有的是一个可能权重的分布,而不是单一的权重集。这种复杂性可能会导致在解释模型决策时遇到挑战,特别是在解释性至关重要的领域。
对于 XAI(可解释人工智能)的兴趣日益增长,而传统深度神经网络本身就难以解释,因为难以理解权重,贝叶斯深度学习则更具挑战性。
感谢阅读!在你离开之前:
- 查看我在 Github 上的 AI 教程合集
[## GitHub - FrancoisPorcher/awesome-ai-tutorials: 最佳 AI 教程合集,让你成为…
最佳 AI 教程合集,让你成为数据科学的高手!- GitHub …
github.com](https://github.com/FrancoisPorcher/awesome-ai-tutorials?source=post_page-----d298c7243fd6--------------------------------)
你可以在你的收件箱中获取我的文章。 点击这里订阅。
如果你希望获得 Medium 上的高级文章,只需每月$5 的会员费用。如果你通过 我的链接注册,你将以不增加额外费用的方式支持我。
如果你觉得这篇文章有见地且有益,请考虑关注我并留下掌声,以便获取更多深入内容!你的支持帮助我继续制作有助于我们共同理解的内容。
参考文献
-
Ghahramani, Z. (2015). 概率机器学习与人工智能。自然,521(7553),452–459。 链接
-
Blundell, C., Cornebise, J., Kavukcuoglu, K., & Wierstra, D. (2015). 神经网络中的权重不确定性。arXiv 预印本 arXiv:1505.05424。 链接
-
Gal, Y., & Ghahramani, Z. (2016). Dropout 作为贝叶斯近似:在深度学习中表示模型不确定性。国际机器学习会议(第 1050–1059 页)。 链接
-
Louizos, C., Welling, M., & Kingma, D. P. (2017). 通过 L0 正则化学习稀疏神经网络。arXiv 预印本 arXiv:1712.01312。 链接
-
Neal, R. M. (2012). 贝叶斯神经网络学习(第 118 卷)。Springer Science & Business Media. 链接
补充对数-对数回归的温和介绍
一种在特殊条件下的逻辑回归替代方法
·
关注 发表在 Towards Data Science ·8 分钟阅读·2023 年 10 月 2 日
–
在统计建模和回归分析中,有许多技术可以选择。其中一种常被忽视但在某些场景中非常有用的方法是补充对数-对数(Cloglog)回归。在这篇文章中,我们将详细介绍什么是 Cloglog 回归、何时使用它以及它是如何工作的。
Cloglog 回归的前身
Cloglog 回归是一种用于分析二元响应变量的统计建模技术。我们知道,当涉及到建模二元结果时,首先想到的模型是逻辑回归。实际上,cloglog 是逻辑回归在特殊场景中的替代方案。我假设大家对逻辑回归有基本的了解。然而,如果你对逻辑回归不熟悉,建议首先获得对其的基本了解。网上有大量关于逻辑回归的资源,可以帮助你熟悉这个主题。
Cloglog 回归是逻辑回归模型的扩展,当事件的概率非常小或非常大时尤其有用。大多数时候,cloglog 回归用于处理稀有事件或结果极度偏斜的情况。
对 Cloglog 回归的需求
众所周知,逻辑回归遵循 S 型函数的形式。下面展示了 S 型曲线:
作者提供的图像
从这个图形表示中可以明显看出,对于较小的‘x’值,结果的概率保持相对较低,而对于较大的值,结果的概率变得更高。曲线在‘Y’的值为 0.5 处表现出对称性。这种对称性意味着在逻辑回归中,存在一个潜在的特征,即成功或事件发生的概率(Y = 1)围绕 0.5 对称分布。这意味着概率的最显著变化发生在图表的中间部分,而在极端的‘x’值下,概率相对较不敏感。当我们的结果变量有大量成功或事件的情况时,这一假设是成立的,示例包括:
抑郁症的流行情况
作者提供的图像
或学生考试及格
作者提供的图像
然而,当事件非常稀少或非常频繁时,这一假设可能不成立,在这种情况下,成功或事件发生的概率要么极低,要么极高。例如,考虑人们在心脏骤停后的生存情况,其中成功的可能性显著降低:
作者提供的图像
或医院内青光眼手术的成功率(成功的机会非常高):
作者提供的图像
在这种情况下,0.5 处的对称分布并不理想,建议使用不同的建模方法,这就是互补对数-对数回归的应用场景。
与 logit 和 probit 不同,Cloglog 函数是不对称的,并且偏向一侧。
互补对数-对数回归的工作原理
Cloglog 回归使用互补对数-对数函数,生成一个 S 形曲线但不对称。Cloglog 回归的形式如下:
作者提供的图像
方程的左侧称为互补对数-对数变换。与 logit 和 probit 变换类似,这种变换也将二元响应(0 或 1)转换为(-∞到+∞)。该模型也可以写成:
作者提供的图像
在下图中,我们可视化了在 R 中使用 logit、probit 和 cloglog 变换生成的曲线。
# Load the ggplot2 package
library(ggplot2)
# Create a sequence of values for the x-axis
x <- seq(-5, 5, by = 0.1)
# Calculate the values for the logit and probit functions
logit_vals <- plogis(x)
probit_vals <- pnorm(x)
# Calculate the values for the cloglog function manually
cloglog_vals <- 1 - exp(-exp(x))
# Create a data frame to store the values
data <- data.frame(x, logit_vals, probit_vals, cloglog_vals)
# Create the plot using ggplot2
ggplot(data, aes(x = x)) +
geom_line(aes(y = logit_vals, color = "Logit"), size = 1) +
geom_line(aes(y = probit_vals, color = "Probit"), size = 1) +
geom_line(aes(y = cloglog_vals, color = "CLogLog"), size = 1) +
labs(title = "Logit, Probit, and CLogLog Functions",
x = "x", y = "Probability") +
scale_color_manual(values = c("Logit" = "red", "Probit" = "blue", "CLogLog" = "green")) +
theme_minimal()
作者提供的图像
从图中我们可以观察到明显的差异:虽然 logit 和 probit 变换在值 0.5 附近是对称的,但 cloglog 变换表现出不对称。在逻辑回归和 probit 函数中,概率在接近 0 和 1 时以类似的速率变化。在数据在[0, 1]区间内不对称,且在小到中等值时变化缓慢但在接近 1 时急剧变化的情况下,logit 和 probit 模型可能不适合。在这些情况下,当响应变量的非对称性明显时,互补对数-对数模型(cloglog)成为一个有前景的替代方案,提供了更好的建模能力。从 Cloglog 函数的图中可以看到,P(Y = 1)在接近 0 时较慢,而在接近 1 时则急剧上升。
让我们以一个例子来说明:检查锌缺乏
我模拟了一个特定人群中的锌缺乏数据 [注:这些数据是作者为个人使用而创建的模拟数据]。数据集还包括年龄、性别和体重指数(BMI)等因素的数据。值得注意的是,数据集中只有 2.3%的人表现出锌缺乏,这表明在这个人群中锌缺乏的发生率相对较低。我们的结果变量是锌缺乏(二元变量(0 = 否,1 = 是)),预测变量是年龄、性别和体重指数(BMI)。我们在 R 中使用逻辑回归、概率回归和 Cloglog 回归,并通过 AIC 比较这三种模型。
> #tabulating zinc deficiency
> tab = table(zinc$zinc_def)
> rownames(tab) = c("No", "Yes")
> print(tab)
No Yes
8993 209
> #tabulating sex and zinc deficieny
> crosstab = table(zinc$sex, zinc$zinc_def)
> rownames(crosstab) = c("male" , "female")
> colnames(crosstab) = c("No", "Yes")
> print(crosstab)
No Yes
male 4216 159
female 4777 50
> #definig sex as a factor variable
> zinc$sex = as.factor(zinc$sex)
> #logistic regression of zinc deficiency predicted by age, sex and bmi
> model1 = glm(zinc_def ~ age + sex + bmi, data = zinc, family = binomial(link = "logit"))
> summary(model1)
Call:
glm(formula = zinc_def ~ age + sex + bmi, family = binomial(link = "logit"),
data = zinc)
Coefficients:
Estimate Std. Error z value Pr(>|z|)
(Intercept) -2.064053 0.415628 -4.966 6.83e-07 ***
age -0.034369 0.004538 -7.574 3.62e-14 ***
sex2 -1.271344 0.164012 -7.752 9.08e-15 ***
bmi 0.010059 0.015843 0.635 0.525
---
Signif. codes: 0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1
(Dispersion parameter for binomial family taken to be 1)
Null deviance: 1995.3 on 9201 degrees of freedom
Residual deviance: 1858.8 on 9198 degrees of freedom
(1149 observations deleted due to missingness)
AIC: 1866.8
Number of Fisher Scoring iterations: 7
> #probit model
> model2 = glm(zinc_def ~ age + sex + bmi, data = zinc, family = binomial(link = "probit"))
> summary(model2)
Call:
glm(formula = zinc_def ~ age + sex + bmi, family = binomial(link = "probit"),
data = zinc)
Coefficients:
Estimate Std. Error z value Pr(>|z|)
(Intercept) -1.280983 0.176118 -7.273 3.50e-13 ***
age -0.013956 0.001863 -7.493 6.75e-14 ***
sex2 -0.513252 0.064958 -7.901 2.76e-15 ***
bmi 0.003622 0.006642 0.545 0.586
---
Signif. codes: 0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1
(Dispersion parameter for binomial family taken to be 1)
Null deviance: 1995.3 on 9201 degrees of freedom
Residual deviance: 1861.7 on 9198 degrees of freedom
(1149 observations deleted due to missingness)
AIC: 1869.7
Number of Fisher Scoring iterations: 7
> #cloglog model
> model3 = glm(zinc_def ~ age + sex + bmi, data = zinc, family = binomial(link = "cloglog"))
> summary(model3)
Call:
glm(formula = zinc_def ~ age + sex + bmi, family = binomial(link = "cloglog"),
data = zinc)
Coefficients:
Estimate Std. Error z value Pr(>|z|)
(Intercept) -2.104644 0.407358 -5.167 2.38e-07 ***
age -0.033924 0.004467 -7.594 3.09e-14 ***
sex2 -1.255728 0.162247 -7.740 9.97e-15 ***
bmi 0.010068 0.015545 0.648 0.517
---
Signif. codes: 0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1
(Dispersion parameter for binomial family taken to be 1)
Null deviance: 1995.3 on 9201 degrees of freedom
Residual deviance: 1858.6 on 9198 degrees of freedom
(1149 observations deleted due to missingness)
AIC: 1866.6
Number of Fisher Scoring iterations: 7
> #extracting AIC value of each model for model comparison
> AIC_Val = AIC(model1, model2, model3)
> print(AIC_Val)
df AIC
model1 4 1866.832
model2 4 1869.724
model3 4 1866.587
系数的解释
在 Cloglog 回归中,系数的解释类似于逻辑回归。每个系数代表预测变量变化一个单位时,结果对数几率的变化。通过对系数取指数,我们得到比值比。
在我们特定的模型中,年龄的系数是-0.034。这意味着年龄每增加一年,锌缺乏的对数几率减少 0.034 单位。通过对这个系数取指数,我们可以计算出比值比:
比值比 = exp(-0.034) = 0.97
这表明年龄增加一年与锌缺乏的几率降低 3%相关。
类似地,对于变量‘性别’:
比值比 = exp(-1.25) = 0.28
这表明,与男性相比,女性经历锌缺乏的几率降低了 72%。
我们也可以解释 BMI 系数,但需要注意的是,BMI 的 p 值为 0.52,表明在这个模型中,它与锌缺乏并没有显著关联。
应用及使用
Cloglog 回归被广泛应用于各种研究领域,包括稀有疾病流行病学、药物疗效研究、信用风险评估、缺陷检测和生存分析。特别是,Cloglog 模型在生存分析中具有重要意义,因为它与事件发生的连续时间模型密切相关。
互补对数-对数回归是一种强大且常被忽视的统计技术,在传统的逻辑回归不适用的情况下,它可能非常有价值。通过理解其原理和应用,你可以将这个多功能工具加入到你的数据分析工具箱中。
《深入浅出 JAX 中的深度强化学习》
在一秒钟内用 DQN 解决 CartPole 环境
·
关注 发表在Towards Data Science · 10 分钟阅读·2023 年 11 月 21 日
–
由Thomas Despeyroux拍摄,发布于Unsplash
最近在强化学习(RL)方面的进展,例如 Waymo 的自动驾驶出租车或 DeepMind 的超人类棋类代理,结合了经典 RL和深度学习组件,如神经网络和梯度优化方法。
在之前介绍的基础和编码原则上构建,我们将探索并学习如何使用 JAX 实现深度 Q 网络(DQN)和回放缓冲区来解决 OpenAI 的CartPole环境。所有这些操作都在不到一秒的时间内完成!
对于JAX、向量化环境和Q-learning的介绍,请参阅以下内容:
## 使用 JAX 向量化和并行化 RL 环境:Q-learning 的光速⚡
学习如何在 CPU 上向量化 GridWorld 环境,并同时训练 30 个 Q-learning 代理,每个代理进行 180 万步…
towardsdatascience.com
我们选择的深度学习框架是 DeepMind 的Haiku库,我最近在 Transformer 的上下文中介绍过:
## 使用 JAX 和 Haiku 从头开始实现 Transformer 编码器 🤖
理解 Transformer 的基础构建模块。
towardsdatascience.com
本文将涵盖以下几个部分:
-
为什么我们需要深度 RL?
-
深度 Q 网络的理论和实践
-
回放缓冲区
-
将CartPole环境转换为JAX
-
JAX编写高效训练循环的方式
如往常一样,本文中提供的所有代码都可在 GitHub 上找到:
[## GitHub - RPegoud/jym:JAX 实现的 RL 算法和向量化环境
JAX 实现的 RL 算法和向量化环境 - GitHub - RPegoud/jym: JAX 实现的 RL…
github.com](https://github.com/RPegoud/jym?source=post_page-----c1e45a179b92--------------------------------)
为什么我们需要深度 RL?
在之前的文章中,我们介绍了时间差分学习算法,特别是 Q-learning。
简单来说,Q-learning 是一种离策略算法(目标策略与用于决策的策略不同),用于维护和更新Q 表,一个明确的状态到相应动作值的映射。
尽管 Q 学习是离散行动空间和受限观察空间环境的实际解决方案,但在更复杂的环境中很难扩展。事实上,创建 Q 表需要定义行动和观察空间。
考虑自动驾驶的例子,观察空间由来自摄像头和其他感知输入的无限潜在配置组成。另一方面,行动空间包括广泛的方向盘位置以及施加到刹车和油门的不同力度。
尽管理论上我们可以离散化行动空间,但实际应用中可能会导致不切实际的 Q 表,因为可能的状态和行动数量庞大。
在大型复杂状态-动作空间中寻找最优行动因此需要强大的函数逼近算法,这正是神经网络所擅长的。在深度强化学习中,神经网络用作Q 表的替代品,并为大状态空间引入的维度灾难提供了高效的解决方案。此外,我们不需要显式定义观察空间。
深度 Q 网络与重播缓冲区
DQN 同时使用两种类型的神经网络,并行进行,首先是用于Q 值预测和决策的“在线”网络。另一方面,“目标”网络用于通过损失函数评估在线网络的性能,以生成稳定的 Q 目标。
与 Q 学习类似,DQN 代理由两个函数定义:act
和update
。
行动
act
函数实现了关于 Q 值的ε-贪心策略,Q 值由在线神经网络估计。换句话说,代理根据给定状态的最大预测 Q 值选择动作,同时以一定概率随机执行动作。
您可能还记得 Q 学习在每一步之后更新其 Q 表,但在深度学习中,通常使用梯度下降在输入批次上计算更新。
因此,DQN 将经验(包含state, action, reward, next_state, done_flag
的元组)存储在重播缓冲区中。为了训练网络,我们将从此缓冲区中随机抽取一批经验,而不仅仅使用最后一次经验(有关更多详细信息,请参见重播缓冲区部分)。
展示了DQN 行动选择过程的视觉表示(作者制作)
这里是 DQN 行动选择部分的 JAX 实现:
这个代码片段的唯一细微之处在于 model
属性不包含任何内部参数,这与 PyTorch 或 TensorFlow 等框架中通常的情况不同。
在这里,模型是一个 函数,表示我们架构中的 前向传递,但 可变的 权重是外部存储 并作为 参数 传递。这解释了为什么我们可以在传递 self
参数时使用 jit
,作为 *静态(**模型在其他类属性中是无状态的)*。
更新
update
函数负责训练网络。它根据 时间差(TD) 误差 计算 均方误差(MSE)损失:
DQN 中使用的均方误差
在这个损失函数中,θ 表示 在线网络的参数,而 θ− 代表 目标网络的参数。目标网络的参数每隔 N 步被设置为在线网络的参数,类似于一个 检查点 (N 是一个超参数)。
参数的分离(当前 Q 值的 θ 和目标 Q 值的 θ−)对于稳定训练至关重要。
如果对两个网络使用相同的参数,就类似于瞄准一个移动的目标,因为 对网络的更新 会 立即改变目标值。通过 定期更新 θ−(即在设定步数内冻结这些参数),我们确保了 稳定的 Q 目标,同时在线网络继续学习。
最后,(1-done) 项 调整目标 用于 终止状态。实际上,当一个回合结束时(即 ‘done’ 等于 1),没有下一个状态。因此,下一状态的 Q 值被设为 0。
DQN 参数更新 过程的可视化表示(作者制作)
实现 DQN 的更新函数稍微复杂一些,我们分解一下:
-
首先,
_loss_fn
函数实现了之前描述的用于 单一经验 的平方误差。 -
然后,
_batch_loss_fn
作为_loss_fn
的包装器,并通过vmap
装饰它,将损失函数应用于 一批经验。然后我们返回这一批的平均误差。 -
最后,
update
作为我们损失函数的最终层,计算其 梯度 相对于在线网络参数、目标网络参数以及一批经验。然后我们使用 Optax (一个常用于优化的 JAX 库) 执行优化步骤并更新在线参数。
注意,与回放缓冲区类似,模型和优化器是 纯函数,修改 外部状态。以下一行很好地说明了这一原则:
updates, optimizer_state = optimizer.update(grads, optimizer_state)
这也解释了为什么我们可以对在线网络和目标网络使用一个模型,因为参数是外部存储和更新的。
# target network predictions
self.model.apply(target_net_params, None, state)
# online network predictions
self.model.apply(online_net_params, None, state)
为了提供背景,我们在本文中使用的模型是一个多层感知机,定义如下:
N_ACTIONS = 2
NEURONS_PER_LAYER = [64, 64, 64, N_ACTIONS]
online_key, target_key = vmap(random.PRNGKey)(jnp.arange(2) + RANDOM_SEED)
@hk.transform
def model(x):
# simple multi-layer perceptron
mlp = hk.nets.MLP(output_sizes=NEURONS_PER_LAYER)
return mlp(x)
online_net_params = model.init(online_key, jnp.zeros((STATE_SHAPE,)))
target_net_params = model.init(target_key, jnp.zeros((STATE_SHAPE,)))
prediction = model.apply(online_net_params, None, state)
重放缓冲区
现在,让我们退一步,更详细地看一下重放缓冲区。它们在强化学习中被广泛使用,原因有很多:
-
泛化: 通过从重放缓冲区采样,我们打破了连续经验之间的相关性,通过混合它们的顺序。这种方式避免了对特定经验序列的过拟合。
-
多样性: 由于采样不限于最近的经验,我们通常会观察到更新的方差较低,并且避免对最新经验的过拟合。
-
增加样本效率: 每个经验可以从缓冲区中被多次采样,使模型能够从个体经验中学习更多。
最后,我们可以使用几种采样方案来管理我们的重放缓冲区:
-
均匀采样: 经验以均匀的随机方式进行采样。这种采样类型易于实现,并允许模型从经验中独立于它们被收集的时间步长中学习。
-
优先采样: 这个类别包括不同的算法,如优先经验重放(“PER”,Schaul 等,2015)或梯度经验重放(“GER”,Lahire 等,2022)。这些方法试图根据与其“学习潜力”(PER 的 TD 误差幅度和 GER 的经验梯度的范数)相关的某些指标优先选择经验。
为了简单起见,我们将在本文中实现一个均匀重放缓冲区。然而,我计划在未来详细讨论优先采样。
正如承诺的那样,均匀重放缓冲区的实现非常简单,但与 JAX 和函数式编程的使用相关的一些复杂性需要解决。与往常一样,我们必须使用纯函数,这些函数没有副作用。换句话说,我们不能将缓冲区定义为具有变量内部状态的类实例。
相反,我们初始化一个buffer_state
字典,该字典将键映射到具有预定义形状的空数组,因为 JAX 在对 XLA 进行 JIT 编译时要求固定大小的数组。
buffer_state = {
"states": jnp.empty((BUFFER_SIZE, STATE_SHAPE), dtype=jnp.float32),
"actions": jnp.empty((BUFFER_SIZE,), dtype=jnp.int32),
"rewards": jnp.empty((BUFFER_SIZE,), dtype=jnp.int32),
"next_states": jnp.empty((BUFFER_SIZE, STATE_SHAPE), dtype=jnp.float32),
"dones": jnp.empty((BUFFER_SIZE,), dtype=jnp.bool_),
}
我们将使用UniformReplayBuffer
类与缓冲区状态进行交互。这个类有两个方法:
-
add
:解开一个经验元组,并将其组件映射到特定索引。idx = idx % self.buffer_size
确保当缓冲区满时,添加的新经验会覆盖旧的经验。 -
sample
:从均匀随机分布中采样一系列随机索引。序列长度由batch_size
设定,而索引的范围为[0, current_buffer_size-1]
。这确保了在缓冲区尚未满时不会采样到空数组。最后,我们使用 JAX 的vmap
结合tree_map
返回一批经验。
将CartPole环境转换为JAX
现在我们的 DQN 代理已准备好进行训练,我们将快速实现一个使用与早期文章介绍的相同框架的矢量化 CartPole 环境。 CartPole 是一个具有大型连续观察空间的控制环境,这使得测试我们的 DQN 变得相关。
CartPole 环境的可视化表示(鸣谢和文档:OpenAI Gymnasium,MIT 许可证)
这个过程非常简单,我们大部分都重用了OpenAI 的 Gymnasium 实现,同时确保我们使用 JAX 数组和 lax 控制流,而不是 Python 或 Numpy 的替代方案,例如:
# Python implementation
force = self.force_mag if action == 1 else -self.force_mag
# Jax implementation
force = lax.select(jnp.all(action) == 1, self.force_mag, -self.force_mag) )
# Python
costheta, sintheta = math.cos(theta), math.sin(theta)
# Jax
cos_theta, sin_theta = jnp.cos(theta), jnp.sin(theta)
# Python
if not terminated:
reward = 1.0
...
else:
reward = 0.0
# Jax
reward = jnp.float32(jnp.invert(done))
为了简洁起见,完整的环境代码在此处可用:
[## jym/src/envs/control/cartpole.py at main · RPegoud/jym
JAX 实现的 RL 算法和矢量化环境 - jym/src/envs/control/cartpole.py at main ·…
以JAX的方式编写高效的训练循环
DQN 实现的最后部分是训练循环(也称为推演)。 正如前文所述,为了利用 JAX 的速度,我们必须遵守特定的格式。
推演函数可能一开始看起来令人生畏,但它大部分的复杂性纯粹是语法上的,因为我们已经涵盖了大多数构建块。 这是一个伪代码演示:
1\. Initialization:
* Create empty arrays that will store the states, actions, rewards
and done flags for each timestep. Initialize the networks and optimizer
with dummy arrays.
* Wrap all the initialized objects in a val tuple
2\. Training loop (repeat for i steps):
* Unpack the val tuple
* (Optional) Decay epsilon using a decay function
* Take an action depending on the state and model parameters
* Perform an environment step and observe the next state, reward
and done flag
* Create an experience tuple (state, action, reward, new_state, done)
and add it to the replay buffer
* Sample a batch of experiences depending on the current buffer size
(i.e. sample only from experiences that have non-zero values)
* Update the model parameters using experience batch
* Every N steps, update the target network's weights
(set target_params = online_params)
* Store the experience's values for the current episode and return
the updated `val` tuple
现在我们可以运行 DQN 进行20,000 步并观察其表现。 大约在 45 集后,代理成功地保持了超过 100 步的稳定平衡。
绿色条表示代理成功地在200 多步内平衡了杆,解决了环境。 值得注意的是,代理在第 51 集上创下了393 步的记录。
DQN 的性能报告(作者制作)
20,000 训练步骤在一秒多一点内执行完毕,速度为每秒 15,807 步(在单个 CPU 上)!
这些表现提示了 JAX 令人印象深刻的扩展能力,允许从业者使用最小的硬件要求进行大规模并行实验。
Running for 20,000 iterations: 100%|██████████| 20000/20000 [00:01<00:00, 15807.81it/s]
我们将更详细地看看并行化的推演过程,以运行具有统计显著性的实验和超参数搜索在未来的文章中!
与此同时,随时可以使用这本笔记本重现实验并尝试不同的超参数:
[## jym/notebooks/control/cartpole/dqn_cartpole.ipynb at main · RPegoud/jym
JAX 强化学习算法和向量化环境的实现 - jym/notebooks/control/cartpole/dqn_cartpole.ipynb at…
github.com](https://github.com/RPegoud/jym/blob/main/notebooks/control/cartpole/dqn_cartpole.ipynb?source=post_page-----c1e45a179b92--------------------------------)
结论
如往常一样,**感谢您读到这里!**希望本文为您在 JAX 中的深度强化学习提供了一个不错的介绍。如果您对本文内容有任何问题或反馈,请务必告诉我,我总是乐意聊一聊 😉
直到下次见面 👋
致谢:
- Cartpole Gif,OpenAI Gymnasium 库,(MIT 许可证)
《初学者友好的生成式 AI 介绍》
原文:
towardsdatascience.com/a-gentle-introduction-to-generative-ai-for-beginners-8c8752085900
让我们深入了解生成式 AI 的整体图景
·发表于 Towards Data Science ·阅读时间 8 分钟·2023 年 6 月 29 日
–
图片由 Susan Cipriano 提供,来源于 Pixabay
过去几个月里,所谓的“生成式 AI”得到了广泛关注,这是一种人工智能(AI)的子领域。像 ChatGPT 这样的工具已经成为最常被提及的词汇之一,并且正成为许多工作中日常任务的基础工具(甚至用来学习编程)。
诸如“DALL-E”、“ChatGPT”和“生成式 AI”这些词汇在过去几个月里充斥了社交网络、媒体、同事聊天和我们世界的方方面面。几乎每个人都在谈论这些。
那么,什么是生成式 AI?它与“普通” AI 有何不同?
在这篇文章中,我们将澄清生成式 AI 背后的大图景。因此,如果你参与了相关讨论但对这个话题没有明确的理解,这篇文章绝对适合你。
这是一种探讨性解释,以理解生成式 AI 背后的基础内容。因此,不用担心:这里不会有任何代码。只有想法和描述,它们将以非常简洁的方式呈现。特别是,我们将重点关注大规模语言模型和图像生成模型。
这里是你将学习的内容概要:
**Table of Contents:**
What is generative AI and how does it differ from trditional AI?
Large Language Models
Image generation
什么是生成式 AI,它与传统 AI 有何不同?
生成式 AI 是 AI 的一个子领域,涉及创建能够生成新数据的算法,如图像、文本、代码和音乐。
生成式人工智能和“传统人工智能”之间的主要区别在于前者根据训练数据生成新的数据。此外,它可以处理“传统人工智能”无法处理的数据类型。
让我们稍微技术性地说一下:
-
“传统人工智能”可以定义为区分性人工智能。在这种情况下,实际上,我们训练机器学习模型,使其能够对新的、未见过的数据进行预测或分类。这些机器学习模型只能处理数字,有时也处理文本(例如,在自然语言处理的情况下)。
-
在生成式人工智能中,我们训练一个机器学习模型,它创建的输出类似于它所训练的数据。这些类型的机器学习模型可以处理不同类型的数据,如数字、文本、图像和音频。
让我们可视化这些过程:
传统人工智能的过程。图片来源于作者。
所以,在传统人工智能中,我们训练一个机器学习模型从数据中学习。然后,我们将其输入新的、未见过的数据,它可以进行区分,做出预测或分类。
关于所展示的示例,我们已经训练了一个机器学习模型来识别图像中的狗。然后,我们将训练好的机器学习模型输入新的、未见过的狗的图片,它将能够分类这些新图像是否代表狗。
这是深度学习算法在分类问题中的典型任务。
生成式人工智能的过程。图片来源于作者。
在生成式人工智能的情况下,我们用来自各种来源的数据训练一个机器学习模型,使用大量的数据。然后,得益于一个提示(用户插入的自然语言查询),模型给出一个类似于它所训练的数据的输出。
以这个示例为例,我们的模型已经在大量(文本)数据上进行了训练,其中包括解释什么是狗的数据。然后,如果用户向模型查询什么是狗,模型将用自然语言描述什么是狗。
这是像 ChatGPT 这样的工具执行的典型任务。
现在,让我们看看一些生成式人工智能模型的类型。
大型语言模型
让我们从大型语言模型(LLMs)开始,深入了解各种生成式人工智能子领域。[LLM 是](https://en.wikipedia.org/wiki/Large_language_model#:~:text=A%20large%20language%20model%20(LLM,learning%20or%20semi%2Dsupervised%20learning.)(来自维基百科):
是一个计算机化的语言模型,由一个具有大量参数(从几千万到几十亿)的人工神经网络组成,使用自监督学习或半监督学习在大量未标记的文本上进行训练。
尽管“大型语言模型”这个术语没有正式定义,但它通常指的是具有数百万甚至数十亿个参数的深度学习模型,这些模型已在大量语料库上“预训练”。
因此,LLM 是深度学习(DL)模型(即神经网络),使用数百万个参数在大量文本上进行训练(这就是我们称之为“大型”的原因),并且对解决一些语言问题很有用,如:
-
文本分类
-
问答
-
文档总结
-
文本生成
因此,标准机器学习模型的另一个重要区别在于,在这种情况下,我们可以训练一个可以用于不同任务的深度学习算法。
让我进一步解释。
如果我们需要开发一个可以识别图像中狗的系统,如前所述,我们需要训练一个深度学习算法来解决分类任务,即:告诉我们新的、未见过的图像是否代表狗。仅此而已。
相反,训练一个 LLM 可以帮助我们完成上述所有任务。因此,这也证明了训练 LLM 所需的计算能力(和资金!)的必要性(这需要 PB 级的数据!)。
我们知道,LLM 是通过用户的提示来查询的。现在,我们需要区分提示设计和提示工程:
-
提示设计。这是创造一个专门适用于系统执行的具体任务的提示的艺术。例如,如果我们想让我们的 LLM 将文本从英语翻译成意大利语,我们必须用英语写一个具体的提示,要求模型将我们粘贴的文本翻译成意大利语。
-
提示工程。这是创建提示以提高我们的 LLM 性能的过程。这意味着使用我们的领域知识来向提示中添加细节,如特定的关键词、特定的上下文和示例,以及必要时所需的输出。
当然,当我们进行提示时,有时会使用两者的混合。例如,我们可能希望将从英语到意大利语的翻译应用于特定知识领域,如力学。
例如,一个提示可能是:“将以下内容翻译成意大利语:
光束受到正常应力作用。
考虑到我们处于力学领域,因此‘正常应力’必须与其相关。
因为,你知道:“正常”和“应力”可能会被模型(甚至是人类!)误解。
三种类型的 LLM
有三种类型的 LLM:
-
通用语言模型。这些模型能够基于训练数据中的语言预测一个词(或一个短语)。例如,可以考虑你的电子邮件自动完成功能来理解这种类型。
-
指令调优模型。这些模型被训练以预测对输入中给出的指令的响应。总结给定文本是一个典型的例子。
-
对话调优模型。这些模型被训练与用户进行对话,使用后续的回应。一个典型的例子是 AI 驱动的聊天机器人。
无论如何,请注意,实际分发的模型具有混合特征。或者,至少,它们可以执行多个这些类型典型的操作。
例如,如果我们考虑 ChatGPT,我们可以明确地说它:
-
可以根据输入预测对指令的响应。事实上,例如,它可以总结文本,提供我们通过提示提供的某个论点的见解,等等……因此,它具有指令调优模型等功能。
-
经过训练以与用户对话。这很明显,因为它会根据后续提示进行工作,直到我们对其答案感到满意。因此,它还具有对话调优模型等功能。
图像生成
图像生成已经存在了一段时间,这与人们的看法相反。
无论如何,近年来它得到了流行,特别是像“DALL-E”或“稳定扩散”这样的工具,它们明确了其用途,使这一技术对全球大众变得可及。
我们可以说图像生成可以分为四类:
-
变分自编码器(VAEs)。变分自编码器是 “需要神经网络作为其整体结构的一部分的概率生成模型”。用操作性的话来说,它们将图像编码成压缩大小,并解码回原始大小。在此过程中,它们学习数据的分布。
-
生成对抗模型(GANs)。这些通常是最知名的,至少在生成 AI 领域的术语中是如此。GAN 是 “一个机器学习框架的类别,其中两个神经网络相互对抗,其中一个的收益是另一个的损失”。这意味着一个神经网络创建图像,而另一个预测它是真实的还是伪造的。
-
自回归模型。在统计学中,自回归模型是随机过程的表示。在生成图像的背景下,这些模型通过将图像视为像素序列来生成图像。
-
扩散模型。扩散模型受到热力学的启发,毫无疑问是图像生成子领域中最具前景和最有趣的模型。
这是在扩散模型的“引擎盖下”工作的过程:
-
前向分布过程。我们有一个初始的迭代过程,其中图像的结构在数据分布中被“破坏”。简单来说,就像我们反复向图像中添加噪声,直到所有像素变成纯噪声,图像无法被识别(通过人眼)。
-
反向扩散过程。然后,有一个反向扩散过程,它是实际的学习过程:它恢复数据的结构。这就像我们的模型学习如何“去噪”像素以重建图像。
连接这一切的力量
如果你保持了注意力到现在,你的脑海中自然会浮现一个问题:“好的,费德里科,这很清楚。但我还有一点没弄明白:当我使用‘DALL-E’时,我输入一个提示,它输出一张图像:我们还没有讨论过这一点,是吗?!”。
不,我们没有。
上面我们简要描述了目前最有前景(也是使用最多)的图像生成模型,但缺少的部分是提示。
我们实际上讨论了它们在高层次上的工作原理。也就是说,我们简要解释了它们的学习过程是如何运作的。
但这些模型的真正力量体现在它们与大型语言模型的结合上。实际上,这种结合使我们能够结合提示工程的力量来请求模型输出。
换句话说:我们结合了使用自然语言作为输入的可能性,这些模型可以实际理解这些输入并生成相应的图像。
这难道不是一种超级能力吗?!
结论
总结来说,我们可以说生成型人工智能是人工智能的一个子领域,旨在生成类似于训练数据的新数据。
一方面,大型语言模型可以根据训练数据生成文本,而图像生成模型可以基于训练图像生成新图像,生成型人工智能的真正力量,至少在图像的情况下,依赖于大型语言模型和图像生成模型的结合。这让我们可以根据提示输入创建图像。
注意:本文灵感来自谷歌提供的生成型人工智能课程,部分参考资料来自其中。我建议* 参加此课程,以更好地理解生成型人工智能。
费德里科·特罗塔
嗨,我是费德里科·特罗塔,我是一名自由职业的技术写作员。
想与我合作吗? 联系我。
GPT 模型简介
欢迎来到新的令牌生成器的世界
·
关注 发表于 Towards Data Science ·9 分钟阅读·2023 年 4 月 12 日
–
图片来源于 Pixabay — 作者修改
随着 ChatGPT 和 GPT-4 的最近发布,GPT 模型引起了科学界的极大关注。这些 OpenAI 的 GPT 模型新版本如此强大和多才多艺,以至于我们可能需要很长时间才能充分挖掘它们的潜力。
尽管它们令人印象深刻,但你可能不知道,GPT 模型背后的主要思想和算法远非新颖。
无论你是经验丰富的数据科学家还是仅仅对 GPT 感到好奇的人,了解 GPT 模型的演变对数据的影响以及未来几年的预期都是特别有启发性的。
在这篇文章中,我解释了 GPT 模型如何发展到今天的状态。我将主要关注 OpenAI 如何在这些年里扩展 GPT 模型。如果你想开始使用 GPT 模型,我也会给出一些指引。
生成预训练语言模型
GPT 模型是语言模型。
语言模型已经存在了 超过 50 年。
第一代语言模型是“n-gram 基础”的。它们对给定一些前置词的情况下,预测一个词的概率。
例如,如果你有以下句子:
猫在厨房里睡觉。
使用 n=3,你可以从一个 3-gram 语言模型中获得“in”跟在“cat sleeps”后面的概率。
n-gram 模型在许多自然语言和语音处理任务中仍然很有用,直到 2010 年代初。
这些模型存在几个限制。计算复杂性随着 n 的增加而急剧增加。因此,这些模型通常限制在 n=5 或更低。
然后,得益于神经网络和更强大的机器,这一主要限制得到了缓解,可以计算更长 n-gram 的概率,例如 n=20 或更高。
使用这些模型生成文本也是可能的,但它们的输出质量很差,因此很少用于这个目的。
然后,在 2018 年,OpenAI 提出了第一个 GPT 模型。
GPT 代表“生成预训练”。“预训练”意味着模型只是基于大量文本进行训练,以建模概率,除了语言建模没有其他目的。GPT 模型可以进一步微调,即进一步训练,以执行更具体的任务。
例如,你可以使用一个小的数据集来获得一个在新闻摘要方面表现非常好的 GPT 模型,或者在法英翻译上进行微调,以获得一个能够将法语翻译成英语的机器翻译系统。
注意:术语“预训练”暗示模型尚未完全训练,还需要另一个步骤。随着最近模型的出现,微调的需求趋于消失。预训练模型现在可以直接在应用中使用。
GPT 模型现在在几乎所有自然语言处理任务中表现都很优秀。我特别研究了它们在机器翻译方面的能力,你可以在以下文章中阅读:
机器翻译,但没有机器翻译系统
towardsdatascience.com
训练的规模和它们利用的 Transformer 神经网络架构是它们能够生成流畅文本的主要原因。
自 2018 年首次发布 GPT 以来,出现了多个版本和子版本的 GPT。
4 个版本和更多子版本
GPT 和 GPT-2
GPT-2 在首次发布 GPT 后的几个月内推出。注:在描述首次 GPT 的科学论文中从未提到“GPT”这个术语。可以说,“GPT-1”实际上并不存在。据我所知,它也从未发布。
GPT 和 GPT-2 有什么区别?
规模。GPT-2 比 GPT 大得多。
GPT 是在包含 7000 本书的 BookCorpus 上进行训练的。该模型有 1.2 亿个参数。
什么是参数?
参数是模型训练过程中学习到的变量。通常,参数更多的模型更大,表现更好。
120 百万在 2018 年是一个巨大的数字。
借助 GPT-2,OpenAI 提出了一个包含 15 亿参数的更大模型。
它是在一个未公开的语料库 WebText 上进行训练的。这个语料库比 BookCorpus 大 10 倍(根据描述 GPT-2 的论文)。
OpenAI 逐步发布了 4 个版本的 GPT-2:
-
small: 124 百万参数
-
medium: 355 百万参数
-
large: 774 百万参数
-
xl: 15 亿参数
它们都是公开的,可以用于商业产品中。
虽然 GPT-2-XL 在生成自然流畅的文本方面表现出色,即没有任何特定的指令或微调,但在特定任务上仍然远不如更新的 GPT 模型强大。
GPT-2-XL 的发布是 OpenAI 最后一次公开发布的 GPT 模型。GPT-3 和 GPT-4 只能通过 OpenAI 的 API 使用。
GPT-3
GPT-3 于 2020 年发布。其拥有 1750 亿参数,比 GPT-2 的跳跃更大。
这也是因为 OpenAI 停止公开 GPT 模型的精确训练信息。
今天,通过 OpenAI 的 API 提供了 7 个 GPT-3 模型,但我们对它们了解甚少。
借助 GPT-3,OpenAI 展示了如果用户提供一些他们希望模型完成的任务示例,GPT 模型可以在特定的语言生成任务中表现得极为出色。
GPT-3.5
随着 GPT-3 模型在 API 中运行并吸引越来越多的用户,OpenAI 能够收集到一个非常大的用户输入数据集。
他们利用这些输入进一步改进了他们的模型。
他们使用了一种叫做人类反馈强化学习(RLHF)的技术。我不会在这里详细解释,但你可以在 OpenAI 发布的一篇博客文章中找到这些细节。
简而言之,得益于 RLHF,GPT-3.5 在遵循用户指令方面比 GPT-3 好得多。OpenAI 将这类 GPT 模型称为“instructGPT”。
使用 GPT-3.5,你可以“提示”模型执行特定任务,而无需给它任务的示例。你只需写出“正确”的提示以获得最佳结果。这就是“提示工程”变得重要的地方,也是为什么熟练的提示工程师正在获得令人难以置信的工作机会。
GPT-3.5 是当前用于驱动 ChatGPT 的模型。
GPT-4
GPT-4 于 2023 年 3 月发布。
我们几乎对其训练过程一无所知。
与 GPT-3/GPT-3.5 的主要区别在于 GPT-4 是双模态的:它可以接收图像和文本作为输入。
它可以生成文本,但不会直接生成图像。注意:GPT-4 可以生成生成图像的代码,或者从网络上检索图像。
在撰写这些文字时,GPT-4 仍处于“有限 beta”阶段。
ChatGPT
ChatGPT 只是一个具有聊天功能的用户界面。当你在 ChatGPT 中写东西时,生成答案的是一个 GPT-3.5 模型。
ChatGPT 的一个特点是,它不仅仅是像开箱即用的 GPT 模型那样接收用户的当前查询。为了作为聊天引擎正常工作,ChatGPT 必须跟踪对话:已经说了什么,用户的目标是什么,等等。
OpenAI 并没有透露它是如何做到这一点的。鉴于 GPT 模型只能接受有限长度的提示(我稍后会解释这一点),ChatGPT 不能简单地将所有对话回合串联在一起放在同一个提示中。这种提示可能会过大,GPT-3.5 无法处理。
如何使用 GPT 模型?
你可以轻松地在网上获取 GPT-2 模型并在计算机上使用它们。如果你想在你的机器上运行大型语言模型,你可能会对阅读我的教程感兴趣:
使用 PyTorch 和 Hugging Face 的 device_map
pub.towardsai.net](https://pub.towardsai.net/run-very-large-language-models-on-your-computer-390dd33838bb?source=post_page-----e02b093a495b--------------------------------)
对于 GPT-3 和 GPT-3.5,我们别无选择,只能使用 OpenAI 的 API。你首先需要在他们的网站上创建一个 OpenAI 账户。
一旦你有了账户,你可以开始在“playground”(这是 OpenAI 提供的实验模型的沙盒)中玩耍。只有在登录后,你才能访问它。
如果你想在你的应用程序中直接使用这些模型,OpenAI 和开源社区提供了许多语言的库,如Python、Node.js 和 PHP,以通过 OpenAI API 调用模型。
你可以在你的 OpenAI 账户中创建并获取你的 OpenAI API 密钥。注意:请保密此密钥。任何拥有它的人都可以消耗你的 OpenAI 额度。
每个模型有不同的设置,你可以进行调整。请注意,GPT 模型是非确定性的。如果你用相同的提示两次调用模型,很有可能会得到两个相似但不同的回答。
注意:如果你想减少相同提示下回答的变异性,可以将模型的“温度”参数设置为 0。副作用是,它也会显著减少答案的多样性,换句话说,生成的文本可能会更冗余。
你还需要注意“最大内容长度”。这是你的提示长度加上 GPT 生成的回答长度。例如,GPT-3.5-turbo 的“最大内容长度”是 4,096 令牌。
一个令牌不是一个单词。
令牌是 GPT 模型用于生成文本的最小单位。是的,GPT 模型并非真正的单词生成器,而是令牌生成器。令牌可以是一个字符、一个词的一部分、一个单词,甚至是某些语言中的词组。
OpenAI 在API 文档中给出了一个示例。
*"ChatGPT is great!"*
被编码成六个令牌:*["Chat", "G", "PT", " is", " great", "!"]*
。
一般而言,750 个英语单词约等于 1,000 个令牌。
我认为,管理“最大内容长度”是使用 OpenAI API 中最繁琐的部分。首先,没有简单的方法来知道你的提示包含多少个令牌。然后,你不能提前知道模型的回答将包含多少个令牌。
你需要猜测。只有当你有一定的模型经验时,才能猜对。我建议你多进行实验,以更好地评估在给定提示的情况下,回答可能有多长。
如果你的提示太长,回答将会被截断。
我不会在这里提供关于 API 的更多细节,因为它可能变得相当技术性。
GPT 模型的局限性
GPT 模型仅仅是基于网络训练的令牌生成器。它们受限于训练数据的内容,因此不能被认为是完全安全的。
自 GPT-3.5 起,OpenAI 已训练其模型以避免回答有害内容。为实现这一目标,他们使用了机器学习技术,因此这种“自我调节”无法 100%被信任。
这种自我调节可能对某个特定提示有效,但在仅仅改变一个词后,可能会完全失效。
我还建议阅读 OpenAI 产品的使用条款。在这份文档中,我认为 GPT 模型的局限性更为清晰。
如果你计划使用 API 构建应用程序,你应该特别注意这一点:
使用这些服务必须年满 13 岁。如果你未满 18 岁,必须获得父母或法定监护人的许可才能使用这些服务。如果你代表其他人或实体使用这些服务,你必须有权接受其条款。你必须提供准确和完整的信息来注册账户。你不得将访问凭证或账户提供给你组织以外的其他人,你对使用你的凭证发生的所有活动负责。
意大利暂时禁止使用 ChatGPT,因为它可能会生成不适合 18 岁以下人群的回答,以及其他原因。
面向 ChatGPT 包装器的繁荣“准备好迎接意大利”
如果你是一个开发者,并在 OpenAI API 的基础上构建应用程序,你必须检查用户的年龄。
OpenAI 还发布了一份使用政策清单,指出了所有禁止使用模型的情况。
结论
GPT 模型非常简单,其架构自 2018 年以来没有发生太大变化。但当你在大规模合适数据上训练一个简单模型,并使用合适的超参数时,你可以获得像 GPT-3 和 GPT-4 这样的极其强大的 AI 模型。
它们如此强大,以至于我们几乎没有完全探索其所有潜力。
尽管最近的 GPT 模型不是开源的,但通过 OpenAI 的 API 使用它们仍然很容易。你也可以通过ChatGPT来体验这些模型。
对开源大型语言模型的温馨介绍
原文:
towardsdatascience.com/a-gentle-introduction-to-open-source-large-language-models-3643f5ca774
开放语言模型
为什么每个人都在谈论美洲驼、羊驼、猎鹰和其他动物
·发表于 Towards Data Science ·11 分钟阅读·2023 年 8 月 11 日
–
作者提供的图像(通过 Midjourney 生成)
除非你过去一年一直在过隐居生活,否则你已经见证了 ChatGPT 革命,以及大家似乎无法停止使用它的现象。在这篇文章中,我们将探索它的替代品,深入了解开源模型的世界。这是系列文章《开放语言模型》的第一篇,对希望入门并了解开源大型语言模型的人很有帮助,以及如何使用它们和为何使用它们。
目录
— 我们为什么需要开源模型?
— 越大越好?训练大型语言模型
— 微调大型语言模型
— 最佳开源大型语言模型
— 在你的计算机上运行大型语言模型
— 限制
— 结论
什么是大型语言模型?
一个**大型语言模型(LLM)**是一个能够理解和生成自然语言的人工智能。核心是一种叫做变换器的神经网络,它通过预测句子中下一个词来工作。词汇大型描述了这些模型的广泛性质,因为它们可以拥有数十亿甚至万亿个参数。它们的不同之处在于能够专注于特定任务,如代码生成或翻译,或应用于一般的指令跟随聊天机器人。这些模型的一个开创性方面是它们支持零-shot和少-shot学习,因为它们展示了前所未有的能力来学习未经过明确训练的任务。[1]
我们为什么需要开源模型?
假设你使用 GPT API 创建了一个创新的应用程序,并迅速获得了关注。一切进展顺利,直到 OpenAI 改变了他们的行动计划。他们可能会停止服务、提高费用,甚至降低模型的能力——这已经在发生。[2]
目前,你唯一的解决方案是适应他们的新政策。此外,依赖第三方 API 会导致你的数据传输到他们的服务器。虽然 OpenAI 可能不会利用 GPT API 的数据进行模型训练,[3] 但部署你自己的语言模型可以保证你对这些操作的完全控制。即使这看起来是一个理想的计划,部署你自己的 LLM 也有其自身的限制和挑战,这些将在本文中讨论。
(左) 一只驼鹿。照片由 Sébastien Goldberg 拍摄 | (中) 一只美洲驼。照片由 Dušan veverkolog 拍摄 | (右) 一只羊驼。照片由 Adrian Dascal 拍摄。所有照片均来自 Unsplash。
越大越好?训练大型语言模型
如果你碰巧看到像LLaMA 65B这样的模型,你可能会想知道65B的含义。它简单地指的是模型中的参数数量。随着模型规模的增加,它需要更多的训练时间,并在推理时消耗更多的内存。与机器学习中的常见观点不同,复杂的模型可以更容易地泛化到不同的任务。有些模型的参数数量非常庞大:GPT-3 拥有 1750 亿,而 GPT-4 拥有超过 1 万亿。估计从零开始训练这些模型需要数百万美元。例如,谷歌的 PaLM 540B 在 2240 个 GPU 上进行训练[4]。相比之下,EfficientNet-B7 是最受欢迎的图像分类深度学习模型之一,仅有 6600 万个参数。
显然,这不是你可以在笔记本电脑上训练的模型。[5]
在 2022 年,谷歌声称:
随着模型规模的增加,性能在各个任务上得到提升,同时也解锁了新的能力。
在当前的 LLM 状态下,更多的参数通常意味着更好[4]。公司专注于构建更大的模型,但当前的开源趋势是创建更小、更高效的模型。虽然最受欢迎的开源模型通常有多达 70B 的参数,但在特定任务上,小型的、经过微调的模型表现可能优于更大的模型。此外,更大的模型在训练和推理时需要更多资源,部署起来也更具挑战性。
实际上,在一年内,连谷歌的观点也发生了变化。
开源模型更快、更可定制、更私密,而且在性价比方面更具优势。他们用 $100 和 13B 的参数做到了我们在 $10M 和 540B 的参数下都难以做到的事情。而且他们是在几周内完成的,而不是几个月。
从泄露的文档我们没有护城河,OpenAI 也没有 [6]*,可以看出他们承认开源模型的惊人演变,这些模型通过使用更小、更便宜的模型迅速赶上了。
ChatGPT 对此反应良好。
多亏了过去一年开源社区的杰出工作,现在有了可用且免费的替代品。在谷歌的 PaLM 发布不到一年后,LLaMA也发布了;在论文中,作者声称他们最大的模型LLaMA 65B在许多任务上超越了 GPT-3(175B)和 PaLM(540B)[5]。
微调大型语言模型
LLMs 因其在单一框架内处理多种语言任务的多功能性而赢得了声誉。然而,你的特定应用可能要求模型在单一任务上表现出色。为此,你可以使用针对你的任务的数据集(例如文本摘要)来微调一个预训练的模型。令人着迷的是,即使数据集较小,也能取得良好的结果。尽管模型最初是用数十亿个文本片段进行训练的,你可能只需要 500 到 1000 个示例就能显著提高性能。
一个来自Alpaca 数据集的示例。该模型经过微调,以跟随用户给出的指令。
一种流行的技术是指令微调。这种方法涉及使用示例来训练模型,说明它应该如何响应特定的指令。这一过程的结果是一个指令模型——这是基模型的增强版本,在遵循指令方面表现出色,而不仅仅是完成文本。指令模型的例子包括Alpaca和Vicuna。
最佳开源大型语言模型
2023 年 2 月,Meta 的LLaMA模型以不同规模进入开源市场,包括 7B、13B、33B 和 65B。最初,该模型仅对研究人员开放,使用的是非商业许可证,但不到一周时间,它的权重就被泄露了。这一事件引发了开源大型语言模型(LLMs)领域的革命,因为其训练代码在 GPL 3 许可证下自由获取。因此,已经发布了几种强大的微调变体。
第一个是Alpaca,由斯坦福大学发布。该模型利用 GPT 生成的指令进行了 52K 示例的微调。紧随其后的是Vicuna,令人惊讶的是,它在许多任务中超越了 Alpaca,达到了 90%的 ChatGPT 质量。其显著特点是它是在 ShareGPT 数据上进行微调的。
在这些强大模型的基础上,新增了GPT4All——它受到使 LLMs 易于访问的愿景的启发,提供了一系列对消费者 CPU 友好的模型,并附带了一个互动 GUI 应用程序。
WizardLM 也加入了这些杰出的基于 LLaMa 的模型。通过一种名为 Evol-Instruct 的新颖独特的方法,它在复杂指令数据上进行了微调,并显示出与 ChatGPT 相似的表现,平均达到了 97.8%。
尽管如此,并不是所有最近的模型都基于 LLaMA。像 MPT 这样的模型以其变体能够生成多达惊人的 65k 上下文长度而闻名——一次生成整本书!
Falcon 也加入了这个趋势,提供了 7B 和 40B 两种变体。令人惊讶的是,它在 OpenLLM 排行榜上超越了 LLaMA,这要归功于其高质量的训练数据集 RefinedWeb。
然而,Falcon 在 HuggingFace 的 OpenLLM 排行榜上的统治地位并没有持续太久。2023 年 7 月,Meta 揭开了其著名模型的继任者 LLaMa 2 的面纱。这个下一代模型将其前身的令牌限制翻了一番,并将其上下文长度增加到 4K 令牌。同时,llama-2-chat——一个针对对话应用优化的版本——产生了重大影响。撰写本文时,像 StableBeluga2、Airoboros 和 Guanaco 这样的 LLaMa 2 微调版本 是最强大的开源大语言模型,主导了 OpenLLM Leaderboard。
OpenLLM Leaderboard 上的前 10 名模型大多基于 LLaMa 2。
如果你对某个模型感到好奇并且想尝试,最简单的方法可能是访问 HuggingFace 模型页面,然后打开一个使用该模型的 Space。它们是简单的 Gradio 界面,允许你向模型发送输入并接收输出。由于模型运行在他们的服务器上并且请求量很大,你通常需要排队等待一段时间。不过,鉴于他们的服务质量,等待几分钟也不是什么大问题。
HuggingFace Spaces 使尝试语言模型变得非常简单。作者提供的图片。
硬件要求和优化
去年,你需要一个高端 GPU 才能在本地运行 LLM。即使是小型模型也需要大量内存,并且你必须将它们加载到显卡内存中。情况随着 llama.cpp** 的发布发生了变化**。最初是 Facebook 的 LLaMA 模型在 C/C++ 中的移植,现在支持许多其他模型。它允许将 LLM 从 PyTorch 转换为 GGML,他们的新格式,允许在 CPU 上进行快速推理。由于用 C++ 编译,推理是多线程的。
得益于这一惊人的工作,现在可以在你的桌面电脑上,甚至在 MacBook 上运行许多 LLM。
主要的限制是它们所消耗的内存。例如,一个 7B 模型的权重大约为 14GB。一个 65B 的模型将需要大约 130GB 的内存(RAM 和磁盘空间),这超过了我们大多数计算机的容量。幸运的是,有一种压缩它们的方法。
进入量化。
在每个机器学习模型中,参数是一个数字。这个数字通常表示为 float32,即 32 位(4 字节)表示。
由于每个参数占用两个字节的空间,一个具有 65B 参数的模型将占用大约46510⁹ = 260GB的内存。
模型量化指的是通过使用舍入将模型参数表示为较低精度的数字,从而减少模型权重的想法。关于量化的数学细节,HuggingFace 博客上有一篇很棒的文章。[7]
8 位整数量化。[7]
作为一种压缩过程,量化会导致一定的性能损失。随着**LLM.int8()**及新技术的引入,这种损失已被大大减少,使得它成为 LLMs 的必备技术。
Llama.cpp支持高达 4 位整数量化。使用Q4_0,可以将模型大小减少最多 4 倍。
量化对模型困惑度和文件大小的影响。图片由作者提供。数据来自 llama.cpp GitHub 仓库。
Q4_0将文件大小从 25.8GB 减少到 6.8GB,同时将困惑度减少了约 2%。在 16GB 内存的计算机上可以加载一个 13B 模型,而如果你有 64GB 内存,你甚至可以运行最大的 70B 模型。
具有不同量化技术的 LLaMa 2 模型。TheBloke的 HuggingFace 个人资料上有许多预训练模型。
另一个值得一提的量化技术是GPTQ,它可以将 VRAM 使用量减少多达 75%,同时保持准确性。[8] 它于 2023 年 3 月发布,使得第一次可以在单个 GPU 上运行 175B 模型。谈到消费级硬件,你可以用单张 RTX 3090 显卡运行最多 30B 的模型。
在你的计算机上运行大型语言模型
与 ChatGPT 最相似的体验是GPT4All应用程序,它是一个聊天界面,允许你与最喜欢的模型聊天。它不仅限于 GPT4All 模型,还支持许多最流行的模型。该应用程序可以在 Windows、Mac OS 和 Linux 上运行,并且完全开源。
在 Ryzen 5600 CPU 上进行 LLaMA-2 7B 推理。图片由作者提供。
另外,你可以使用像 text-generation-webui 或 openplayground** 这样的 GUI 工具。**虽然它们都提供了一个可以轻松生成文本的图形界面,但第一个可能是最完整的工具:它提供了许多功能,如聊天、训练、GPU 支持和 HTML/Markdown 输出等。第二个与 OpenAI 的 Playground 非常相似,是一个很好的工具,可以快速测试和比较不同参数的 LLM。
文本生成网页 UI 是一个使用 Gradio 构建的高级网页界面。
如果你想在 Python 中使用 LLM,有几个选项,比如 llama-cpp-python 或 HuggingFace Transformers 库,它提供了一种与任何 HF 模型交互的高级语法。它支持 PyTorch、Tensorflow 和 Jax 后端,可以与 HF 上找到的任何预训练模型一起使用,适用于文本、图像或音频。使用 transformers,与你喜欢的 LLM 生成文本就像编写两行代码一样简单:
from transformers import pipeline
pipe = pipeline("text-generation", model="meta-llama/Llama-2-7b-chat-hf")
在 HF hub 上成千上万的免费模型中,有不同的格式。Transformers 库需要 HF 格式的模型,而 llama.cpp 需要 GGML 模型。
限制
尽管我们已经展示了大多数模型不需要强大的计算机,可扩展性仍然是你计划构建允许数百或数千名用户与模型互动的系统时首要考虑的问题。使用 GPT 的一个优势是 OpenAI 提供了便宜的、高速限制的 API,可以轻松地扩展你的应用。
另一个限制是 安全性和管理。由于你可能需要对模型的输出负责,你必须特别小心模型生成的内容。商业 LLM 具有使用强化学习与人类反馈 (RLHF) 构建的先进审查过滤器,用于限制有害内容。
使用前记得检查 模型的许可证。开源 并不总是意味着模型可以用于商业用途。
结论
这篇文章展示了开源模型是快速改进的可行和免费的替代方案。今年这些模型的发展取得了令人难以置信的进展,现在甚至可以在你的笔记本电脑上运行它们。
如果你正在考虑为下一个项目使用开源 LLM,希望你觉得这篇文章对你有用。系列中的后续文章将深入探讨 开源语言模型 的不同方面和挑战。
下次见!
如果你喜欢这篇文章,加入 Text Generation ——我们的新闻通讯每周发布两篇文章,提供有关生成 AI 和大型语言模型的最新见解。
此外,你还可以在 LinkedIn找到我。
参考资料
[1] 斯坦福科学家发现,ChatGPT 确实变得更愚蠢了 (2023), (futurism.com)
[2] T. Brown 等人,语言模型是少样本学习者 (2020), arXiv.org
[3] M. Schade,您的数据如何用于提高模型性能 (2023), OpenAI 帮助中心
[4] Google AI 博客,Pathways 语言模型 (PaLM): 扩展到 5400 亿参数以实现突破性性能 (2022), ai.googleblog.com
[5] A. Fan 等人,LLaMA: 开放且高效的基础语言模型 (2023), arXiv.org
[6] D. Patel 和 A. Ahmad,Google: 我们没有护城河(OpenAI 也没有) (2023), semianalysis.com
[7] Y. Belkada 和 T. Dettmers,使用 Hugging Face Transformers、Accelerate 和 bitsandbytes 对变换器进行大规模 8 位矩阵乘法的温和介绍 (2022), Hugging Face 博客
[8] E. Frantar 等人,GPTQ: 生成预训练变换器的准确后训练量化 (2023), arXiv:2210.17323v2 [cs.LG]
可调整神经网络的温和介绍(第一部分)
什么是可调整神经网络及其背景
·
关注 发表在 Towards Data Science ·15 分钟阅读·2023 年 11 月 21 日
–
介绍
几何深度学习作为深度学习的一个分支,旨在扩展传统的 AI 框架,如卷积神经网络,以处理表示为图、流形或点云的三维或二维几何对象。通过直接将几何关系和空间依赖性整合到学习框架中,几何深度学习利用数据的固有结构特性,消除了对内存密集型数据增强技术的需求。出于所有这些原因,几何深度学习可以被视为在计算机视觉、自然语言处理等领域处理复杂数据场景的有价值工具。关于任务类型和转换类型,迄今已提出了大量新的 CNN 架构,如“球形神经网络” (链接), “图神经网络” (链接) 和 “可转向神经网络”。
可转向神经网络 因其将常规卷积神经网络(CNNs)的能力扩展到新的领域而引起了广泛关注。这些网络可以被视为 CNNs 的演变,其中核被条件化以满足特定约束条件。虽然 CNNs 在对平移的等变性方面表现出色,但可转向神经网络通过提供增强的灵活性和捕获更广泛的转换,如旋转,而更进一步。
本教程 将介绍“可转向神经网络”(S-CNNs)的简介,试图传达对其背后数学概念的直观理解以及如何设计这些网络的逐步解释。本第一篇文章作为介绍可转向神经网络的起点,解释其目的并深入探讨支持 S-CNNs 的概念和形式化。第二篇文章(这里)在高层次上讨论了可转向滤波器的设计和整体可转向网络。
本工作旨在填补当前科学文献与更广泛数据科学受众之间的差距。它非常适合技术专业人士以及这一新的机器学习分支的研究人员。
来自论文[3]的一个简单可转向神经网络的示例。可以看到输入图像的旋转效果反映在网络输出响应中。
以下论文作为参考:
[1] “3D 可转向 CNN:在体积数据中学习旋转等变特征”,Weilier 等,(link);
[2] “可转向 CNN”,Cohen 等,(link);
[3] “学习用于旋转等变 CNN 的可转向滤波器”,Weilier 等,(link)
[4] “通用 E(2)-等变可转向 CNN” Weilier 等,(link)
[5] “适用于局部尺度不变卷积神经网络的尺度可转向滤波器”,Ghosh 等,(link)
[6] “构建 E(n)-等变可转向 CNN 的程序。” Cesa 等,(link)
什么是可转向神经网络:
可转向神经网络得名于它们使用的特定类型的滤波器。这些滤波器称为 g-可转向滤波器,它们的灵感来自于在图像识别领域中用于边缘检测或定向纹理分析的可转向滤波器,这些滤波器在90 年代初获得了广泛的应用。可转向通常指的是可操控的、可管理的、能够被控制的。按照这种惯例,可转向滤波器的响应是可定向的,并且可以适应输入的特定方向(例如一张图像)。可转向性与另一个非常重要的属性相关,这就是等变性。在等变滤波器中,如果滤波器的输入经过了一个精确且明确的几何变换 g(平移、旋转、移动),则输出(即输入与滤波器卷积的结果)也会经过相同的变换 g。通常,等变性并不要求变换(输入和输出的变换)是相同的。这个概念将在下一个段落中得到更好的阐述,但目前这使我们能够提供对可转向滤波器和可转向 CNN 的初步定义。
一个 可转向 CNN 滤波器 可以定义为一个其内核结构为不同可转向滤波器的串联的滤波器。这些滤波器在 卷积操作 相对于一组定义明确的几何变换方面显示出等变性特性。
正如我们稍后将看到的,卷积操作上的等变性条件导致对内核结构及其权重施加特定的约束。从这个定义中,现在可以定义什么是可转向 CNN:可转向神经网络是由一系列可转向滤波器组成的神经网络。
S-CNN 的用途:
普通 CNN 的优势在于其对平移的等变性。然而,可导神经网络更加灵活,可以展示其他类型的变换,例如旋转。在旋转等变问题中,未经修改的 CNN 被迫学习相同滤波器的旋转版本,从而引入了冗余的自由度,并增加了过拟合的风险。
因此,可导 CNN 网络可以通过直接整合输入处几何变换的信息,优于经典 CNN。这一特性使得 S-CNN 在处理具有几何描述和表示的输入(如图像、流形或向量场)时特别有用。
可能的实际应用例如:
-
挑战性的 2D 图像分割: 给定输入显微镜图像预测细胞边界。
-
3D 模型分类: 对 3D 物体进行分类和识别。
-
3D 化学结构分类: 预测给定化学结构的分子 3D 化学结构。一个可能的例子是根据氨基酸序列预测其空间偏好,具体见论文的第 5.4 节 [2]。
3D 可导神经网络在 3D 物体识别中的应用示例。输入物体(在顶部)以及两个不同隐藏层特征图的表示。摘自 Link
初步定义和背景
在介绍了可导神经网络及其应用后,让我们深入探讨它们背后的理论。本节提供了等变性和可导性的更正式解释,提供了理解后续文章中可导滤波器构造所需的基本定义和正式框架。
本文依赖于对映射和几何变换的理解,更多信息请参考这篇 文章。
1. 等变性:
等变性是对称问题中特别感兴趣的特性。如前所述,在等变模型中,当输入经过变换作用时,输出也会受到相应作用,从而使得变换的应用可以在模型应用之前或之后进行,而整体行为不发生变化。在日常环境中有许多等变性的例子。例如,驾驶时,当转动方向盘时,汽车的转向方向与汽车所指方向是等变的。形式上,如果我们有一个映射 𝛙: X → Y,其中 X⊂ℝᵈ 和 Y⊂ℝᵈ¹,以及 g,一个属于群体 G 的几何变换,𝛙 对 G 是等变的,如果:
Eq.1: 表示𝛙对 g 的等变性的数学方程。
其中Π₀(g) : X → X’和Π₁(g): Y→ Y’是由应用g到 x 确定的两个线性映射(例如,通常是通过乘法应用的矩阵)。下图提供了一个来自论文[2]的视觉示例。在图像中,g是旋转,具体为“旋转-90°”,因此被称为r。*Π₀®在领域𝛙(=X)中操作,而Π₁®*在𝛙(=Y)的值域中工作。
如果X=ℝ²,2 维笛卡尔空间,且 r 是“顺时针旋转 90°”的变换,则矩阵*Π₀®*将等于θ=π/2 的 2x2 欧拉矩阵。
应注意,如果𝛙对 G 是等变的,那么施加变换后再计算映射会产生与先计算映射再施加变换相同的结果,这一属性以前称为交换性。
Fig2A: 函数Ѱ对变换 r 等变的视觉示例。摘自文章[2]。
此时还值得提到一个特例。 不变性,一种特殊类型的等变性,其中X=X’和Y=Y’。无论输入如何变换,输出始终保持不变。从深度学习的角度来看,不变滤波器例如在物体识别中可能有用:无论输入图像如何旋转,滤波器的输出始终保持不变。需要注意的是,X和Y的空间可能不具有相同的维度,例如,如果我们试图确定图片中汽车的方向(Y作为 2 维向量)而X作为像素的 2 维数组,则变换*Π₁(g)和Π₀(g)*将不同,因为它们适用于不同的空间,即使它们共享相同的 g。
2. 可操控滤波器:
与汽车的可操控性相比,可操控滤波器稍微难以直观理解。然而,两者都共享实现对特定参数一致和可预测响应的基本目标——这种响应与滤波器本身的固有变换密切相关。
一个直观的例子可能如下:想象一下屋顶上的风向标,显示风的方向。与其为每种可能的风向安装单独的传感器(这是不切实际的),不如安装一个可以旋转以与当前风向对齐的风向标。可转向滤波器就像一个风向标,它根据输入信号中编码的方向自适应,而无需为每种可能的输入方向使用独立的传感器。同样,在图像处理中,可转向滤波器适应图像中的不同特征或方向,而无需为每种可能的输入方向使用独立的滤波器。这种方法为建模系统提供了智能和有效的方法。在机器学习的背景下,它使我们能够专注于构建有价值的模型,而不必担心增强或增加额外的权重以处理不同的方向。
尽管可转向性可以普遍应用于任何一组变换,我们将在此使用旋转来更正式地介绍这个概念。
让 𝛙: ℝᵈ →ℝᵈ¹ 成为其核函数为 k 的卷积映射。
对于 x∈ℝⁿ,给定一个依赖于 x 的输入信号 f(x) ∈ ℝᵈ,并且输出信号 f₁(x) ∈ ℝᵈ¹,我们可以写成:f₁(x)= 𝛙(f(x)),这意味着 f₁(x)= k(x) ∗ f(x)。
如果对旋转的转向滤波器定义如下:
(1) 每个输出元素的卷积核 k(x) 可以表示为基函数 ψⱼ(x) 的和,其中 j=1,…M*。
(2) 通过任意角度 θ 旋转滤波器的 g_θ 可以用每个基函数的旋转表示(对于每个 θ 均适用)。数学上来说,这意味着:
Eq.2: 可转向滤波器的定义
由于这一特性,可以通过修改 wⱼ 的值来唯一定向滤波器对输入的响应。我们来举个例子。
在二维空间中,一个可定向单个可转向滤波器的最简单的例子是其核函数为 二维高斯 的方向导数。在这种情况下,k: ℝ² →ℝ,且 x = (x₁,x₂) ∈ ℝ²:
Eq.3: 二维高斯 的方向导数(上)和函数 k R² →R 在 gθ 下的转换。
在接下来的几行中,我们将展示该滤波器按上述方式是可转向的。
从理论上我们知道,鉴于 k 的值域是 ℝ,我们可以将旋转后的滤波器写成 Eq.3(有关更多信息,请参见下一节中的 Eq.3)。
通过推导这个方程,我们可以展示其可转向性:
Eq.5: 二维高斯 的方向导数可转向的数学证明
在这种情况下,我们应用了变换 g_θ: ℝ²→ℝ²,并且它由二维欧拉矩阵表示(见下文诱导表示)。如果我们计算 k(g_θ ⁻¹***(x₁,x₂)),** 我们可以通过一些代数运算看到,这种冲激滤波器的通用旋转版本可以表示为两个基函数 ѱ₁(x₁,x₂)* 和 ѱ₂*(x₁,x₂)* 的线性组合,系数由 θ 参数化。
如下方程(方程 6)所示,由于卷积的线性特性,输入函数 f 与θ旋转的冲激响应 **g_θ(k(x,y))=**k_θ 的卷积始终可以表示为 f 与 k 的单一基函数 ѱ₁、ѱ₂ 的卷积的线性组合。
方程 6:一个可转向滤波器与 f 的卷积。
这个公式突出了 可转向滤波器在神经网络中的力量。
通过引入这些滤波器,我们有可能构造一个可转向的核,它根据输入的方向“调整”其响应。每个基函数像一个多功能工具,允许网络使用学习到的权重‘w₁’和‘w₂’来高效地混合这些函数,以准确响应不同的方向。例如,当网络遇到具有不同方向的数据,如图像中的旋转物体时,它配置这些权重以使核的响应与输入数据的方向对齐。这种适应性提高了效率和效果,从而在参数更少的情况下达到相同或更好的结果。因此,这种方法可以作为使用可转向属性处理各种输入方向的更强大的 CNN 的基础。
在下一篇文章中,我们将进一步探讨这个问题,并了解如何使用可转向滤波器的概念来构建等变滤波器。
然而,在深入之前,一些定义将提供清晰度并帮助我们的讨论。因此,在下一段中我们引入了一些关于卷积的形式化内容。
3. 形式化:
在这一部分,我们试图给读者提供一个所有分析元素的示意性解释。这种形式化将允许我们更正式地定义 CNN 及其在输入层操作的几何变换。这将使我们在下一篇 文章 中理解可转向 CNN 的工作原理。
元素:
-
一个空间 S:分析发生的空间。虽然 S 可以扩展到任意数量的维度,但最容易在二维或三维空间中进行可视化。例如,如果我们考虑一幅图像,初始空间是二维的,对应于像素的坐标平面(ℤ²)。如果我们考虑一个“3D 物体”,那么空间 S 是三维的,ℤ³。因此,一个点 x∈S 确定了一个位置。
-
一个输入函数 f: 函数 f: S → F₀ = ℝ ͨ 描述了我们几何空间中的输入(它可以是流形或向量场)。这可以看作是从空间 S 到 ℝ ͨ 的一个函数,其中每个位置 x 与“特征” f(x) 相关联,也称为 x 点的 f 的纤维。举些例子,一个灰度图像可以看作是一个函数 f: ℝ² → ℝ,S=ℝ² 且 c=1。如果考虑一个彩色的 3D 流形,函数将是 f: ℝ³→ ℝ³,其中每个位置分配一个 RGB 颜色,S=ℝ³,c=3\。
实际上,函数 f 通常表示为一些采样空间上的纤维的打包结构;对于标准格式的图像,纤维将水平和垂直地规则分布(即像素)。函数 f 构成了神经网络的输入层(见图 2A,图 2B)。从现在起,这个起始层将被称为 F₀。
-
一组变换 G: 一旦分析对象被适当地定义,我们可以定义网络应该保持等变性的变换集。单个变换 g∈G 总是可以被描述为与应用它的数学空间相关的函数。给定输入函数 f:S→ℝ ͨ, 可以表征 π(g): ℝ ͨ → ℝ ͨ,作为“g 在 ℝ ͨ 中的诱导变换”。*函数 *f 存在于 ℝ ͨ 中,但变换 g 操作在 S 空间中。π(g) 描述了 f(在 ℝ ͨ 中)在应用 g(在 S 中)下的变换。考虑 g 作为由两个组件 r(旋转)和 t(平移)组成的旋转-平移,一般来说,输入函数 f(x) 在变换 g 下的变换如 Eq.7 所述\。
在下图中,如果 f 是一个向量场,π(g) 是一个 cxc 维度的矩阵,而**,** 如果 f 是一个标量场(f: ℝ² → ℝ),π® = 1。
所考虑的变换组 G 通常是旋转(在这种情况下我们将讨论SO(2) 网络)或旋转 + 平移(在这种情况下我们将讨论 SE(2) 网络)。类似地,在三维空间中,考虑 3D 刚体运动(SO(3) 或 SE(3))。
图 2B: 变换 g 对标量场(左)或向量场(右)的应用的图形表示。摘自论文 [3]。
Eq.7: f 如何通过变换 g 应用于 x 而被变换
-
特征图: 根据第二点给出的 f 定义,神经网络每一层的输出可以看作是函数 f ₙ 在初始空间 S 上的应用结果。形式上可以表示为从 S 到对域空间 Fₙ 的函数,( f : S → Fₙ),其中 Fₙ=ℝ ͨ ʿⁿ ʾ 和 cⁿ 是层 n 的特征数量。如果以图 2B 为例,我们可以看到初始信号(输入)可以看作是函数 f : S=ℝ² → F₀= ℝ³。
f₁: S=ℝ² → F₁= ℝ²。
-
NN 滤波器 φn: 滤波器可以定义为两个连续层之间的映射,如**φ*😗Fₙ→ Fₙ₊₁。将这种滤波器应用于一层意味着与相应的内核k进行卷积。在这种情况下如何定义卷积对理解可导 NN 至关重要。因此,我们在下面专门讨论了这一点。
NN 滤波器和卷积
在这种情况下,内核可以看作是一个函数 k: S → ℝ ͨ ʿⁿ ʾ ˟ ͨ ʿⁿ⁺ ¹ ʾ,其中 S 中的每个位置都连接到一个维度为 cʿⁿ ʾ ˟ cʿⁿ⁺ ¹ ʾ 的矩阵。为了清晰起见,cⁿ 和 cⁿ ⁺ ¹ 分别是 Fₙ 和 Fₙ₊₁ 的维度(特征数量)。
我们可以定义卷积如下:
Eq.8: 上方:连接层 n 和层 n+1 的关系。下方:空间 S 中的卷积定义
上面的方程 Eq.8 代表连接层 n 和 n+1 的函数;下面的是 n 维空间 S 中的卷积定义。函数 σ*(x)* 代表应用于卷积输出的非线性函数。
在图 2B 中,可以看到在离散域中,内核与输入层之间的卷积是如何计算的。我们用一个灰度图像 f ₀: ℝ² -> ℝ 来说明这一点。我们可以应用第二部分中讨论的滤波器,这是一个具有函数的可导滤波器。
k(x₁, x₂) 是一个定义为 k: ℝ² -> ℝ¹˟¹=ℝ 的 2D 高斯滤波器。
在这种情况下,将滤波器 k 应用于f₀ 是经典的 2D 卷积,可以表示为:*
Eq.9: 卷积的定义
不同的是,在图 2B 中,你可以看到另一个例子,其中 f ₀: ℝ²-> ℝ³(例如 rgb 图像)和 f₁: ℝ²-> ℝ² 以及 k₀: ℝ²-> ℝ³˟ ²。
图 2B: 如上定义的滤波器卷积的视觉示例,S=R²。F⁰是信号 f⁰存在的输入空间,在此案例中是 R³。可以注意到,卷积操作已被相关操作替代,如[4]中所建议。
综合我们迄今讨论的所有要点,可以在这一形式化框架内可视化神经网络。每个单独的特征图可以被解释为一个函数 f: S → Fₙ,其中 Fₙ= ℝʿⁿ ʾ 和 f₀(x) 代表网络的输入。滤波器的应用涉及与其在 Eq.8 中定义的核函数卷积。值得注意的是,到目前为止,主要的创新在于将 f 作为在位置空间 S 中操作的函数的几何表示,以及在这一空间内卷积的定义。
以下是我们提供的神经网络在这一背景下的表示:
Eq.10: 使用上述形式化表达的神经网络的符号表示。
我们将在下一篇文章中了解这种形式化定义如何帮助我们设计可引导的 CNN 滤波器。
结论
在我们《可引导神经网络》教程的初始部分,我们已经建立了可引导神经网络、等变性和可引导滤波器的基本概念。还介绍了一个数学框架,为理解这些概念提供了严格的基础。等变性在变换下保持行为不变,而可引导滤波器能够智能地适应输入的方向。这一基础工作为设计等变 CNN 滤波器铺平了道路,增强了边缘检测和基于方向的识别。下一篇文章将利用这些概念更深入地探讨可引导 CNN 滤波器的机制,完成我们对这一强大神经网络范式的探索。
✍️ 📄. 关于作者:
1️⃣ Matteo Ciprian,机器学习工程师/研究员
-
帕多瓦大学电信工程硕士。当前从事传感器融合、信号处理和应用 AI 领域的工作。具有与 AI 在电子健康和可穿戴技术中的应用相关的项目经验(包括学术研究和企业领域)。专注于开发异常检测算法,以及推进深度学习和传感器融合技术。
对哲学充满热情。YouTube 内容创作者。
🔗 链接: 💼 Linkedin
📹 Youtube
👨💻Instagram
2️⃣ Robert Schoonmaker,信号处理/机器学习研究员
-
杜伦大学计算凝聚态物理博士。专注于应用机器学习和非线性统计,目前研究 GPU 计算方法在合成孔径雷达及类似系统中的应用。经验包括开发用于传感器融合和定位技术的对称机器学习方法。
🔗 链接: 💼 Linkedin
《可操控神经网络简介(第二部分)》
如何构建可操控滤波器和可操控 CNN
·
关注 发表在 Towards Data Science · 10 分钟阅读 · 2023 年 11 月 21 日
–
1) 介绍
本文是**《可操控神经网络简介》**教程的第二部分,也是最后一部分。它接续了第一部分的内容(在这里)。
第一篇文章提供了 Steerable 神经网络(S-CNNs)的简明概述,解释了它们的目的和应用。它还深入探讨了基础形式主义和关键概念,包括等变性和 Steerable 滤波器的定义。尽管下一段落将对形式主义进行快速回顾,但我们建议您阅读第一篇文章以全面了解。
在本教程的最后部分,我们希望提供一个关于如何构建 Steerable Filter 的指南,并在最后,介绍如何组合一个 Steerable 神经网络。
快速回顾术语:
图 3A:按照形式主义表示的神经网络。
-
S:输入域空间。对象存在的空间(通常为ℝ³或ℝ²)。
-
fₙ****😗** 一个映射/函数,fₙ:* S → ℝ ͨ ʿⁿ ʾ(Fₙ)*,*描述了 NN 的第 n 个特征映射。请注意,f⁰是描述输入(输入层)的函数,而对于 n>0,fₙ描述了第 n 个特征映射。
-
Fₙ**= ℝ ͨ ʿⁿ ʾ**,它是描述fₙ*的值域。
-
Φₙ**: Fₙ→ F** ₙ₊₁***,* 第n 个 NN 的滤波器可由核函数kⁿ: S →* ℝ ͨ ʿⁿ ʾ ˟ ͨ ʿⁿ⁺ ¹ ʾ *来描述。卷积的定义如上第二个方程式所示。
-
G:变换的组(单一元素g)。
综合考虑所有这些概念,我们已经能够定义如下卷积:
2) Steerable CNN 滤波器的设计
图 3A:等变 CNN 滤波器的视觉示例。给定对 S 的变换 g 及由Π₀(g)给定的输入信号 f 的旋转,f₁由Π₁(g)旋转。
2.1 问题的形式化
我们可以声明,如果对于每个g 在 G,0,当输入函数 f₀变换为Π₀(g)时,那么第 n 层的输出函数将变换为Πn(g),则 n 层的 CNN 对于一组变换 G 是等变的。
使这个陈述成立的一个充分条件是每个连续的层对其直接输入的变换具有等变性(见图 3A)。网络的等变性是通过归纳来实现的。根据第二篇文章中给出的定义,如果滤波器Φ满足以下条件,则Φ是等变的:
方程 0:等变性的定义
现在可以宣称 steerable 神经网络理论的主要结果。
设k连接层fₙ和f的核心函数,使得fₙ₊₁ = k f ₙ*.*
卷积k* f ₙ对于变换 g 是等变的,当且仅当:
或更简单
Eq.1: 关于变换 g 的核等变性的必要且充分条件。
在更广泛的文献中[2,3],遵循此约束的核被称为 g-可导核。由于核约束以线性方式操作,它生成的解构成了标准 CNN 中通常使用的无约束核的向量空间中的一个线性子空间。经过更仔细的审查,这一定义与最后一篇文章第 2 段中介绍的可导滤波器的概念非常契合 这里。在实际操作中,为了获得此工作,我们需要一个此核子空间的基,记作{k_1, …k_D},它符合方程(1)。这个基的大小,记作 D,可以计算为 D = cʿⁿ ʾ ˟ cʿⁿ⁺¹ʾ。核*k(x)*随后通过这个基的线性组合得出,网络在过程中学习权重:
Eq.2: 方程(1)的线性使得解等于以下线性组合。
在训练场景中,我们的方法涉及将输入层和输出层的大小设置为特定的值,即 cʿⁿ ʾ和 cʿⁿ⁺¹。然后,根据我们寻求等变的变换,解决方程并确定一个核基。随后,在训练过程中,我们学习与这些核相关的权重。
2.2 解方程
方程(1)中呈现的约束的解远非简单。它依赖于三个主要元素:
-
空间 S,无论是 S= ℝ³还是 S= ℝ²。
-
群体 G。
-
层的输入输出维度:cʿⁿ ʾ 和 cʿⁿ⁺ ¹ ʾ*。
更具体地说,我们可以说群 G 的选择定义了网络的类型。具体来说,我们主要对以下类型的网络感兴趣:
-
SO 网络:对特殊正交群(SO)中的旋转具有等变性。
-
SE 网络:对特殊欧几里得群(SE)中的旋转和平移具有等变性。
-
E 网络:对欧几里得群(E)中的旋转、平移和反射具有等变性。
如果我们在 2D 输入域中操作,我们有 SO(2),SE(2)和 E(2)网络 [4]。相反,对于 3D 输入域,我们使用 SO(3),SE(3)和 E(3)网络[1],并且这可以扩展到任何 E(n)空间 [6]。
将这项工作扩展到其他空间和对称性是一个持续的研究领域,感兴趣的读者可以调查被称为 Hilbert 空间和 Green 函数的数学研究领域,这里不在本文讨论范围之内。
然而,可以看到在 SE(n) 网络的情况下,方程式 1 的一般解是 S=ℝⁿ中的一个谐波基函数。在上面的图像中(图 3B),可以看到ℝ²左侧的谐波函数和ℝ³中的谐波函数。
图 3B:二维(左)和三维(右)中的谐波函数基。这些基分别构成 SE(2) 和 SE(3) 网络中可操控等变滤波器的基。
考虑一个更具滤波器设计场景的情况,在下图 Fig 3C 中,我们可以看到如何为输入层 f ₀: ℝ²->ℝ³ 和输出层 f₁: ℝ²->ℝ² 构建一个 SO2 可操控等变核。
核是一个函数 k: ℝ²->ℝ³ˣ²。矩阵的每个单独元素是通过对在位置*(x₁,x₂)*采样的 D 基的线性加权组合得到的函数。我们可以查看上面的示例位置 x=(1,2)。
接下来,我们将展示一些这个方程的简单解,考虑 S=ℝ² 和 G 作为旋转变换的群体,包含 SO2 网络。
图 3C:使用 6 个谐波函数的基构建的 3x2 可操控核的可视化表示。
2.3 实际解决方案
- Case1A: SO2 网络,k: S=ℝ² → ℝ
假设实际情况下输入为灰度图像,我们想要构建一个可操控滤波器来处理它。首先,我们必须决定输出层的维度(特征数量)。为了简便起见,假设维度为 1。
在这个设置中,我们有一个输入函数 f: ℝ²-> ℝ 和一个类似的输出函数 f₁: ℝ²-> ℝ。因此,核函数是 k: ℝ² -> ℝ。我们希望我们的 CNN 层对一个变换群体 G 是等变的,G 代表了角度 theta 在 [0,2π) 范围内的旋转(SO 网络)。对于这个问题,核函数的基需要使用方程式 1。由于 f 和 f¹ 都是标量,Pi_out = 1 和 Pi_in = 1。结果是 *k[****g_***θ(x)] = k[x],如方程式 3 中所写。
如果 x = (x₁, x₂) 在 ℝ² 中,g(theta) 与 2D 欧拉矩阵对齐。
方程式 3:在 k: S=ℝ² → ℝ 的情况下重写方程式(1)
很容易看出,这可以通过在(x₁, x₂)中的每个各向同性函数来解决。具体来说,这可以通过一维的各向同性(旋转不变)核来解决。(即 k(x₁, x₂) = x₁² + x₂²)
Case 2: SO2 滤波器,k: ℝ² → ℝ²
现在考虑一个更复杂的情况。输入函数为f: ℝ² → ℝ²,输出层为函数f ₁: ℝ² → ℝ²。因此,内核可以写为函数k: S= ℝ² → ℝ² ˣ ²; 换句话说,对于ℝ²中的每个位置 x,我们有一个二维矩阵 2x2(见下方方程)。我们想要构建 S02 滤波器,因此需要考虑的变换群再次是 G={g(θ)} ={r(θ)},θ ∈ 0,2Π[. 由于ℝ²是f和f ₁的值域,Π_out=Π_θ* 和 Π_in=Π_θ,其中Π_θ是ℝ²中的欧拉矩阵。考虑到所有这些条件,我们可以以下述方式重写 Eq.1:
Eq.(5): 为 SO2 内核 k 重写 Eq.(1): S=ℝ² → ℝ²。
欲更全面地理解该方程的解及更多见解,请参考论文[4]中的附录部分。
2.4 网络非线性
到目前为止,我们仅考虑了相对于卷积操作的等变性,而没有考虑由函数σ*(f(x))😗 **ℝ=**ℝ ͨ→ℝ ͨ’给出的非线性部分。论文[1]的第 4.3 节和论文[4]的第 2.6 节对此进行了广泛讨论。
给定函数 f(x),等变性条件可以总结如下:
Eq.(5): 激活函数的等变性条件。
如相关的 YouTube 讲座中所提到的这里,可以通过利用所谓的基于范数的激活函数,如σ*(u) =* σ*(||u||)*,来创建满足该标准的激活函数。其动机在于标量范数是透明不变的,因此对其应用任何非线性函数将产生不变的输出。为了证明这一点,当我们将此公式应用于上述条件时,会得到以下方程:
Eq.(6): 将 Eq.(5)重写为基于范数的函数。
如果‘g’属于 E 变换群,则范数保持不变。因此,当Π’(g)等于单位矩阵时,该方程在普遍情况下是有效的。这意味着特别设计的激活函数始终具有旋转不变性。例如,Norm-ReLUs,其定义为η(|f(x)|) = ReLU(|f(x)| − b)
研究论文和讲座中提出了额外的非线性激活函数,如非门控激活函数。我们建议读者查阅这些来源以获取进一步的解释。
3) 设计一个可操控的 CNN
图 3D: 可操控 CNN 的架构,如[3]中所述。注意第 2 层中使用的可操控滤波器与 G 卷积的结合。
在上一节中,我们掌握了构建单个可引导滤波器的基础知识。在本节中,我们将深入探讨如何将这些滤波器有效地整合以建立一个功能全面的可引导神经网络。
在上图中,我们可以看到一篇论文中的示例[3]。我们特别关注第 2 层,其中使用了可引导滤波器。
在这里,每个水平表示都是一个可引导滤波器——由加权谐波函数组成——它产生一个不同的输出,表示为单个 fⁿ。观察其结构,很明显虽然谐波函数在各个滤波器中保持一致,但它们的方向在每个滤波器之间有所变化。这是 G-卷积技术的一个典型特征,这是一种复杂的方法,有助于构建对变换不变的网络(你可以在这里找到更多关于该技术的信息)。该网络利用最大池化的强大功能,将来自可引导滤波器阵列中的最强响应传递到下一层。这种选择性传输的原则确保了最强的特征在网络中传递和增强。这种方法与其他工作中实现的方法类似,例如参考文献[5]成功构建了一个尺度不变的可引导网络。这种可引导 CNN 的架构受益于这种技术,因为它自然地结合了尺度和旋转不变性,从而增强了网络以更抽象但更强大的方式识别模式和特征的能力。无论如何,从图片中可以看出,最终结果是一个对旋转不变的网络。
图 3E:在旋转图像上应用 2D 可引导滤波器的视觉示例(原始图像可以在这里找到)
关于可引导神经网络设计的优秀逐步解释可以在此链接中找到,该链接包含在 Github 库*“e2cn”(link)。在该库中,可以找到设计 SE2 可引导网络的 PyTorch 代码。关于 SE3 网络的有用代码可以在此链接中找到,而关于 3D 等变网络的快速课程已在这里发布。
文献:
[1] “3D Steerable CNNs: Learning Rotationally Equivariant Features in Volumetric Data”,Weilier et al.,(link);
[2] “Steerable CNNs”,Cohen et al. ( link);
[3] “学习旋转等变 CNN 的可调节滤波器”,Weilier 等人 (link)
[4] “通用 E(2)-等变可调节 CNN”,Weilier 等人 (link)
[5] “适用于局部尺度不变卷积神经网络的尺度可调节滤波器”,Ghosh 等人 (link)
[6] “构建 E(n)-等变可调节 CNN 的程序。” Cesa 等人 (link)
✍️ 📄. 关于作者:
1️⃣ Matteo Ciprian,机器学习工程师/研究员
-
硕士学位,电信工程,帕多瓦大学。当前在传感器融合、信号处理和应用人工智能领域工作。参与与人工智能在电子健康和可穿戴技术中的应用相关的项目(学术研究和企业领域)。专注于开发异常检测算法,以及推进深度学习和传感器融合技术。
对哲学充满热情。YouTube 内容创作者。
🔗 链接: 💼 Linkedin
📹 Youtube
👨💻 Instagram
2️⃣ Robert Schoonmaker,信号处理/机器学习研究员
-
博士学位,计算凝聚态物理,杜伦大学。专注于应用机器学习和非线性统计学,目前正在研究 GPU 计算方法在合成孔径雷达及类似系统中的应用。经验包括开发用于传感器融合和定位技术的对称机器学习方法。
🔗 链接: 💼 Linkedin
对分析流处理的温和介绍
为工程师及其他相关人员构建心理模型
·
关注 发表在 Towards Data Science · 17 分钟阅读 · 2023 年 3 月 31 日
–
流处理可以被温柔而细致地处理,也可以被狂野而几乎失控地处理!你可以判断你更愿意拥抱哪种未来。来源:@psalms 原始照片
介绍
在许多情况下,将数据流中或实时可用的数据进行处理,可以将由于数据流量和规模而导致的庞大数据问题,转化为更可管理的问题。通过更频繁地处理较小的数据集,你可以有效地解决那些可能因成本和时间限制而难以处理的数据问题。
从批处理思维转变为流处理思维虽然也可能很棘手,所以让我们从小做起,逐步构建。
从庞大的数据回到大数据
假设你负责构建一个分析应用程序,该应用程序必须处理大约10 亿个事件(1,000,000,000)每天。虽然这在开始时可能感觉难以实现,但由于数据的巨大规模,通常有助于退后一步,思考应用程序的意图(它做了什么?)和你正在处理的内容(数据是什么样的)?问问自己事件数据是否可以被分解(划分和分区)并作为流处理操作(即流内)并行处理,还是必须通过多个步骤串行处理?无论哪种情况,如果你将应用程序的视角修改为查看有限的时间窗口,那么你现在只需要创建一个可以摄取和处理仅每秒 11,500 个事件(k)(或者如果事件流是恒定的,则每分钟约 695k 个事件)的应用程序,这是一个更容易理解的数字。
虽然这些数字仍然可能显得难以触及,但这正是分布式流处理真正发挥光芒的地方。从本质上讲,你是在减少问题的视角或范围,以在时间上跨越分区数据集实现目标。虽然并非所有问题都能在流处理中解决,但许多问题确实适合这种处理模式。
注意*:本章是我书中的一部分* “现代数据工程与 Apache Spark:构建关键流应用程序的实用指南”。本书带你从简单的脚本编写,到应用程序的构建,最后到部署和监控你的关键 Apache Spark 应用程序。
本章学习内容
本章将作为流处理的温和介绍,为我们直接进入第十章构建自己的端到端结构化流应用程序做好准备,而无需回顾和讨论许多决策过程背后的理论。
到本章结束时,你应该能高层次地理解以下内容:
-
如何减少流数据问题中的时间问题
-
时间、时间戳和事件视角的问题
-
从批处理到流处理思维模型的不同处理模式
流处理
流数据是非静态的。事实上,你可以将其视为活跃的(即使只是短时间)。这是因为流数据是捕捉当前时刻的事件和动作的数据。让我们来看一个实际的,尽管是理论上的例子,它从一个简单的传感器数据流开始。把你最后一次访问的停车场(或停车库)固定在你的脑海中。
用例:实时停车可用性
停车是个噩梦:大多数停车基础设施的问题,或客户的常见痛点,往往是在确保按时到达的情况下找到一个可用的车位。照片来自 Unspash 和 @ryansearle
想象一下,你刚刚找到一个停车位,感谢一些有用的标志指引你到一个空车位。现在假设这一切都因为来自连接的本地停车传感器网络的数据。传感器的唯一目的是用于识别此时此刻可用停车位的数量。
这是一个实时数据问题,实时准确性既可以测量,也可以由停车结构的用户实际感受到。这些能力的实现始于系统场景的声明。
产品推广:“我们希望创建一个系统来跟踪所有可用停车位的状态,识别何时有车辆停车、车辆在特定车位停留多久,并且尽可能地自动化这个过程”
优化这样的系统可以从一个简单的传感器开始,该传感器位于每个停车位(与传感器.id / 车位.id 参考关联)。每个传感器负责以事件的形式发出数据,包含车位标识符、时间戳和简单的位(0 或 1),以表示车位是空的还是被占用。然后,这些数据可以编码为紧凑的消息格式,如示例 9–1,并定期从每个停车位高效地发送。
示例 9–1. 为了清晰起见,展示了一个示例传感器事件(封装在Google Protocol Buffer消息格式中)。
message ParkingSensorStatus {
uint32 sensor_id = 1;
uint32 space_id = 2;
uint64 timestamp = 3;
bool available = 4;
}
在一天的正常交通流中,传感器的状态(停车位的可用性)会根据车辆到达或离开每个车位而开或关(二进制状态)。由于每个驾驶员的动态日程,这种行为是不可预测的,但随着规模的扩大,模式总是会显现出来。
利用收集的传感器数据提供的实时状态,可以轻松构建实时的、现实生活中的(IRL)“报告”,以便更新驾驶员停车结构的活动状态:停车基础设施是否已满,如果未满,则车库中目前有 X 个可用车位。
传感器数据的作用
这些数据可以帮助自动化人类决策过程,甚至可以通过简单的网络服务在线提供,以便实时跟踪状态,因为最终驾驶员只是想尽快停车,而不是浪费时间!此外,这些数据还可以用于跟踪每个传感器上次检查的时间(刷新),这可以用来诊断故障传感器,甚至跟踪传感器离线或故障的频率。
现在,技术更先进的车库甚至能够通过指示牌和提示引导驾驶员到结构内的空车位。这既减少了车库间的交通和拥堵,又提高了顾客满意度,所有这一切只需捕捉实时的传感器数据流并进行近实时处理。
高峰定价和数据驱动决策
根据从这些传感器事件流中收集的时间(时间戳)信息,一个精明的车库运营者可以利用先前的趋势来实时减少或增加每日或每小时的价格,这取决于对车位的需求,考虑到当前的可用性(剩余车位数)。通过优化定价(在现实限制范围内),运营者可以找到一个完美的阈值,使得每小时/每日价格能使车库更多时间达到满员。换句话说,“以什么价格大多数人愿意停车且车位不会空置?”。
这是一个优化问题的例子,源于实时传感器数据的收集。组织越来越常见地查看如何重用数据来同时解决多个问题。物联网(IOT)用例只是你在编写流应用程序时可能处理的众多数据流中的一种。
在书中早些时候,我们讨论了“创建一个系统,可以获取咖啡店的占用信息,以告知人们离他们最近的店铺是否有适合他们人数的座位”。在故事中的那个阶段,我们只是创建了一个合成表来展示这个例子,但这也是一个可以用传感器或简单的签到系统解决的问题,该系统发出相关事件数据,通过我们的流数据管道可靠地传递下游。
这里讨论的两个例子(停车基础设施和咖啡帝国扩展)都采用了基本的分析(统计),并可以从简单的机器学习中受益,以发现新的行为模式,从而实现更优的操作。在我们过于深入之前,先休息一下,深入了解流数据网络提供的功能。
时间序列数据和事件流
从一个关于固定视图或时间点的静态数据思维方式,转变为一个将数据视为在时间中流动的视角,涉及到对许多视图和时间点中无限数据流的解释,这是一个视角上的练习,但起初可能具有挑战性。通常,当你考虑流式系统时,连续事件流的概念会浮现出来。这是一个更常见的用例,可以作为对流数据概念的温和引入。例如,图 9–1中所示的抽象时间序列。
图 9–1:事件发生在精确的时间点,可以单独收集和处理(t1->t4),也可以在时间窗口(w1)中聚合。图片来源:作者(Scott Haines)
正如你所看到的,数据本身在不同的状态下存在,取决于系统(或应用程序)应用的视角或观点。每个事件(T1->T4)只在其狭窄的参考范围内理解发生了什么,或者换句话说,事件捕捉了时间的有限(相对)视角。当一系列事件在一个有界集合(窗口)中一起处理时,你会得到一系列数据点(事件),这些数据点要么完全实现的概念,要么部分实现的概念。当你缩小视角并查看整个时间线时,你可以更准确地描绘从第一个事件到最后一个事件发生的故事。
让我们进一步探讨这个想法。
事件是否独立存在?
考虑这个简单的事实。你的事件数据存在于完整的概念中,或作为部分概念或思想。我发现将数据视为一个随时间变化的故事有助于赋予这些数据字节以生命。因此,每个数据点负责帮助构建一个完整的故事,作为一系列随时间展开或呈现的交织思想和观念。
数据组合是一个有用的视角,适用于你在采用分布式数据视图时。我还发现它在构建和定义新的分布式数据模型时很有用,同时在处理大规模的现实世界数据网络(织物)时也很有帮助。作为一种组合来看待,这些事件汇聚在一起讲述了一个特定的故事,这些基于事件的痕迹可以揭示事物发生的顺序,并通过每次事件的时间戳得到极大的增强。没有时间的事件描绘了一个平面的发生过程,而时间的加入赋予了你关于动量或速度的概念,或者在事件之间的时间延续和拉伸,或者对于一系列数据点的完整过程。理解数据在许多管道和数据通道中流动的行为对数据操作至关重要,并且需要可靠的监控以保持数据在最佳速度下流动。
让我们来看一个用例,其中时间维度帮助讲述一个现实世界场景的更好故事。
用例:追踪客户满意度
一家安静的咖啡店,每杯咖啡都倾注爱心。照片由 Nafinia Putra 提供,来源于 Unsplash
设想你是一个数据工程师,与一个名为“CoffeeCo”的虚拟咖啡帝国的数据应用特性团队合作,讨论的是哪些数据能够描述客户满意度随时间变化的故事(时间序列分析)。
如果我告诉你两位客户进入我们的咖啡店,点了饮料并带着饮料离开了店里。你可能会问我为什么要告诉你这些,因为这正是咖啡店里发生的事情。如果我告诉你这两个咖啡订单是在差不多同时 下的,并且故事中的第一位客户在咖啡店待了不到五分钟。如果我告诉你这是一个工作日,且这个故事发生在早高峰时段?如果我告诉你第二位客户,恰好排在第一位客户之后,在咖啡店里待了三十分钟?你可能会问这个客户是否在读报纸或者使用设施。这两个问题都是合理的。
如果我告诉你第二位客户因为在四步咖啡生产线的第 3 步和第 4 步之间发生错误而等待,那么我们将更好地理解如何在未来简化客户体验。
四个步骤是:
1. 客户订单: {customer.order:initialized}
2. 付款完成 {**customer.order:payment:processed}
3. 订单排队: {customer.order:queued}
4. 订单完成: {customer.order:fulfilled}
无论错误是在自动化过程中,还是由于现实世界系统的故障(如打印机卡纸、咖啡师漏单或其他任何原因),结果是客户需要插手(人工干预),并通知操作系统(咖啡生产线)“似乎有人忘记制作我的饮料”。
此时讨论可能会转向如何处理客户的情绪反应,这些反应可能在积极和消极之间大幅波动:从乐于助人(1),到轻微的挫折(4),再到对咖啡生产线的延迟和故障的明显愤怒(10)。但通过分析一个假设的用例,我们现在对如何利用捕捉好数据的艺术有了更深入的了解。
事件时间、事件捕捉顺序和事件间的延迟都讲述了一个故事
如果不了解从第一个事件(customer.order:initialized)到终端事件(customer.order:fulfilled)之间的时间经过了多久,或者每个步骤通常需要多长时间来完成,我们将无法对体验进行评分或真正理解发生了什么,基本上就会在系统中创造出对异常延迟或故障的盲点。了解客户通常等待不同大小订单的时间的统计数据(平均值、中位数和 99 百分位数)是有益的,因为这些历史数据点可以通过自动化用于预先解决问题,例如当一个订单的处理时间比预期的要长时。这可能意味着客户的不满和终身客户之间的差别。
这是公司请求客户反馈的主要原因之一——无论是对体验的好评/差评,奖励基于应用程序的参与(用你的积分换取免费商品和服务),还是跟踪实时反馈,比如“你的订单比预期的时间长,这里有$2 折扣下次咖啡使用。只需使用应用程序兑换”。这些通过现实世界互动收集和捕获的数据,以事件形式编码,并为你的利益处理,最终是值得的,如果它积极影响公司的运营和声誉。只要确保遵循数据隐私规则和法规,最终不要让客户感到不适。
这个小小的思想实验旨在揭示事件数据中捕获的细节(以及数据故事随时间的演变)可以是一个游戏规则改变者,并进一步说明时间是赋予这些旅程动力或速度的维度。只有一个时间的问题。
时间的麻烦
虽然事件发生在精确的时间点,但时间的问题在于它也受时间和空间(位置)问题的影响。爱因斯坦利用他的相对论理论在宇宙尺度上解释了这个问题,但在更局部的尺度上这也是一个问题。例如,我有家人住在美国不同的地方。在大家的时间协调上可能会有困难,这种情况发生在简单的事件中,比如远程视频聊天或在现实世界中聚会。即使一切都已协调好,人们也习惯于稍微迟到。
从我的家庭或一般人的角度看,关于事件的中心协调问题,你会开始看到这个问题不仅仅是跨时区(东部/中部或西海岸)的同步问题,而是如果你更仔细地看,时间相对于我们的本地/物理空间,会受到一定的时间漂移或时钟偏差的影响。
以现代数字时钟为例。它作为一个过程运行在你的智能手机、手表或众多“智能”连接设备中。保持不变的是时间始终明显同步(即使漂移的范围是毫秒级)。许多人仍然使用模拟的非数字时钟。这些设备的准确度从高端手表(“计时器”)的极度准确到需要每几天重新校准的便宜时钟不等。
关键点在于,两个系统在精确时间上达成一致的情况很少,就像两个人在时间和空间上协调类似问题一样。因此,必须使用一个中央参考点(视角)来同步跨多个时区运行的系统的时间。
时间校正
在任何现代云基础设施中运行的服务器都利用一个称为网络时间协议(NTP)的过程来校正时间漂移的问题。ntp过程负责使用一个可靠的中央时间服务器同步本地服务器时钟。这个过程将本地时间校正到与协调世界时间(UTC)相差几毫秒。这是一个重要的概念,因为在大型网络中运行的应用程序,产生事件数据,将负责创建时间戳,而这些时间戳需要非常精确,以便分布式事件能够对齐。还有一个狡猾的问题是夏令时(每 6 个月增加或减少一小时),因此,协调跨时区以及跨本地日期时间语义(全球)的系统数据需要从这个中央、同步的视角来看待时间。
我们已经从理论上看过时间如何与事件驱动的数据相关,但为了全面了解背景,我们还应该考虑时间如何与数据在系统(无论是流式还是其他)中需要被捕获和处理的优先级相关。
优先级排序事件处理模式
你可能对这句名言很熟悉:“时间就是生命。”这句话的意思是某事很重要,是首要任务。解决问题的速度至关重要。这种优先级感可以作为一个工具或定义指标,用来为实时、接近实时、批处理或*最终处理(按需处理)*的数据处理方式辩护。这四种处理模式以不同的方式处理时间,通过对数据问题施加窄或宽的焦点来应对。这里的范围是基于一个过程必须完成的速度,这反过来限制了工作的复杂性作为时间的一个因素。可以把这些处理风格看作是以截止日期驱动的,完成一个动作的时间是有限的。
实时处理
实时系统的期望是,从上游系统发出事件的时间到该事件被处理并可用于分析和洞察的时间,端到端延迟应在毫秒到低秒级别内。这些事件直接写入事件流处理服务,如 Apache Kafka,在正常情况下,允许监听者(消费者)一旦写入事件就能立即使用。真正的实时系统有许多典型用例,包括物流(如停车位示例以及在咖啡馆找到桌子),以及影响整个业务的过程,如欺诈检测、主动网络入侵检测或其他坏演员检测,其中检测的平均时间(毫秒/秒到检测)越长,可能会导致声誉、财务或两者的灾难性后果。
对于其他系统而言,运行在近实时状态是完全可以接受的。考虑到解决难题需要时间,实时决策需要高效、预计算或低延迟的答案。这确实是纯内存流处理。
近实时处理
近实时是大多数人在考虑实时时的想法。这里发生的模式类似于你刚刚在实时部分阅读的,唯一的区别是端到端延迟的期望放宽到高秒级别到几分钟。对于大多数系统而言,没有真正的理由对每个到达的事件做出立即反应,因此,虽然时间仍然很重要,但数据可用性的 SLA 优先级会有所延长。
操作仪表板和度量系统通常更新迅速(每 30 秒—5 分钟刷新图表和检查监控),足够快以捕捉问题,并给出接近现实的表示。对于所有其他数据系统,你会有批处理或按需处理的概念。
批处理
我们在前两章中涵盖了批处理和周期性调度,但为了明确,将数据从可靠的真实数据源(数据湖或数据库)推送到其他连接系统的周期性作业,一直以来都是世界数据处理的主要方式。
这背后的简单原因是成本。这涉及到操作成本和维护大型流媒体系统的人力成本。
流处理系统要求全天候访问从 CPU 和 GPU 到网络 IO 和 RAM 的可变数量资源,期望这些资源不会短缺,因为流处理中的延迟(阻塞)可能会迅速积累。另一方面,批处理在长期维护上可能更容易,只要数据的消费者理解从数据首次发出到下游数据可用之间始终存在间隔。
最后需要考虑的是按需处理(或即时处理)。
按需或即时处理
说实话,有些问题(即查询)问得非常少,或者以一种不适合任何预定义模式的方式提出。
例如,自定义报告任务和探索性数据分析是两种适合这些范式的数据访问风格。大多数情况下,回答这些查询的后端数据直接从数据湖中加载,然后使用共享计算资源或隔离计算集群进行处理。为这些查询提供的数据可能是其他实时或接近实时系统的副产品,这些系统被处理和存储用于批处理或历史分析。
使用这种模式,数据可以解冻,并通过将记录从较慢的对象存储(如 Amazon S3)导入内存,或通过快速访问的固态硬盘(SSD),或者根据数据的大小、格式和布局,直接从云对象存储中查询。这种模式可以轻松委托给 Apache Spark,使用SparkSQL。这使得通过像 Apache Zeppelin 这样的工具进行临时分析成为可能,或通过 JDBC 绑定直接在应用程序中使用Apache Spark thrift-server和 Apache Hive Metastore。
这四种处理方式之间的区别在于时间。
回到视角和观点的概念,每种方法或模式都有其时间和地点。流处理处理的是在特定时间点捕获的事件,正如我们在本章前半部分讨论的那样,我们如何关联时间,以及我们如何捕捉和测量一系列事件(作为数据),共同绘制了当前发生的情况或过去发生的情况的画面。在我们对流处理的温和介绍中,重要的是还要讨论流处理的基础。在下一节中,我们将讨论处理连续、无界数据流的一些常见问题和解决方案。因此,讨论数据作为核心支柱并从那里扩展开来是有意义的。
希望你喜欢第九章的前半部分。如果你想继续阅读第二部分,它在下面有链接。👇
构建可靠分布式系统的架构基础。
如果你想了解更多,请查看我的书!
现代数据工程与 Apache Spark:构建关键任务流处理的实用指南
亚马逊网站:现代数据工程与 Apache Spark:构建关键任务流处理的实用指南…
一个好的描述就是你所需要的一切
如何使用少量学习来提高文本分类性能
·
关注 发表在Towards Data Science ·7 分钟阅读·2023 年 8 月 10 日
–
由Patrick Tomasso拍摄,来源于Unsplash。
我已经使用大语言模型(LLMs)一段时间了,既用于个人项目,也作为日常工作的组成部分。像许多人一样,我对这些模型的强大能力感到兴奋。然而,重要的是要知道,尽管这些模型非常强大,但仍然可以针对各种任务进行改进。
而且不,我不会写关于微调 LLMs的内容,因为这可能会很昂贵,并且通常需要一台好的 GPU 设备。事实上,我将展示一种非常简单的使用少量学习来改善你的模型的方法。
少样本学习是一种机器学习技术,其中模型通过仅使用少量示例(通常每类仅 1–5 个示例)来解决新任务。少样本学习有一些关键点:
-
从小数据中学习归纳:少样本学习方法旨在学习能够从少量示例中很好地归纳模型,这与传统的深度学习方法(需要数千或数百万个示例)形成对比。
-
迁移学习:少样本学习方法利用从解决先前任务中获得的知识,并将这些知识转移到帮助更快地学习新任务和更少的数据。这种迁移学习能力是关键。
-
学习相似度度量:一些少样本学习技术专注于学习示例之间的相似度度量。这允许将新示例与现有标记示例进行比较以进行预测。
但如何在分类问题中使用少样本学习以提高模型性能?让我们通过一个例子来演示。
数据和准备
我通过从 HuggingFace 获取数据来开始我的分析。数据集名为financial-reports-sec(该数据集具有 Apache 许可证 2.0 并允许商业使用),根据数据集作者的说法,它包含了 1993–2020 年美国上市公司向 SEC EDGAR 系统提交的年度报告(10-K 文件)。每份年度报告(10-K 文件)被分为 20 个部分。
当前任务的两个相关属性对数据很有用:
-
句子:来自 10-K 文件报告的摘录
-
部分:标记句子所属的 10-K 文件部分
我关注了三个部分:
-
业务(项目 1):描述公司的业务,包括子公司、市场、最近事件、竞争、法规和劳动力。在数据中用 0 表示。
-
风险因素(项目 1A):讨论可能影响公司的风险,例如外部因素、潜在故障和其他警告投资者的披露。用 1 表示。
-
属性(项目 2):详细说明重要的实物资产。不包括知识产权或无形资产。在数据中用 3 表示。
对于每个标签,我抽取了 10 个示例(不放回)。数据结构如下:
现成预测
一旦数据准备好,我所要做的就是制作一个分类器函数,该函数从数据框中获取句子并预测标签。
Role = '''
You are expert in SEC 10-K forms.
You will be presented by a text and you need to classify the text into either 'Item 1', 'Item 1A' or 'Item 2'.
The text only belongs to one of the mentioned categories so only return one category.
'''
def sec_classifier(text):
response = openai.ChatCompletion.create(
model='gpt-4',
messages=[
{
"role": "system",
"content": Role},
{
"role": "user",
"content": text}],
temperature=0,
max_tokens=256,
top_p=1,
frequency_penalty=0,
presence_penalty=0)
return response['choices'][0]['message']['content']
我在这里使用 GPT-4,因为这是迄今为止 OpenAI 最强大的模型。我还将温度设置为 0,以确保模型不会偏离轨道。真正有趣的部分是如何定义角色——这就是我可以指导模型做我想要它做的事情的地方。角色指示模型保持专注并提供我所期望的输出。为模型定义一个清晰的角色有助于生成相关的、高质量的响应。这个功能中的提示是:
你是 SEC 10-K 表格的专家。
你将收到一段文本,你需要将文本分类为‘第 1 项’,‘第 1A 项’或‘第 2 项’。
文本只属于提到的一个类别,因此只返回一个类别。
在对所有数据行应用分类功能后,我生成了一个分类报告来评估模型的性能。宏平均 F1 分数为 0.62,表明该多类问题的预测能力相当强。由于所有 3 个类别的示例数量均衡,宏平均和加权平均值收敛到相同的值。这个基准分数反映了在任何额外调整或优化之前,预训练模型的开箱即用的准确性。
precision recall f1-score support
Item 1 0.47 0.80 0.59 10
Item 1A 0.80 0.80 0.80 10
Item 2 1.00 0.30 0.46 10
accuracy 0.63 30
macro avg 0.76 0.63 0.62 30
weighted avg 0.76 0.63 0.62 30
描述是你所需要的(少样本预测)
如前所述,少样本学习就是通过少量好的示例来推广模型。为此,我通过描述第 1 项、第 1A 项和第 2 项是什么(基于维基百科)来修改了我的类别:
Role_fewshot = '''
You are expert in SEC 10-K forms.
You will be presented by a text and you need to classify the text into either 'Item 1', 'Item 1A' or 'Item 2'.
The text only belongs to one of the mentioned categories so only return one category.
In your classification take the following definitions into account:
Item 1 (i.e. Business) describes the business of the company: who and what the company does, what subsidiaries it owns, and what markets it operates in.
It may also include recent events, competition, regulations, and labor issues. (Some industries are heavily regulated, have complex labor requirements, which have significant effects on the business.)
Other topics in this section may include special operating costs, seasonal factors, or insurance matters.
Item 1A (i.e. Risk Factors) is the section where the company lays anything that could go wrong, likely external effects, possible future failures to meet obligations, and other risks disclosed to adequately warn investors and potential investors.
Item 2 (i.e. Properties) is the section that lays out the significant properties, physical assets, of the company. This only includes physical types of property, not intellectual or intangible property.
Note: Only state the Item.
'''
def sec_classifier_fewshot(text):
response = openai.ChatCompletion.create(
model='gpt-4',
messages=[
{
"role": "system",
"content": Role_fewshot},
{
"role": "user",
"content": text}],
temperature=0,
max_tokens=256,
top_p=1,
frequency_penalty=0,
presence_penalty=0)
return response['choices'][0]['message']['content']
现在的提示是:
你是 SEC 10-K 表格的专家。
你将收到一段文本,你需要将文本分类为‘第 1 项’,‘第 1A 项’或‘第 2 项’。
文本只属于提到的一个类别,因此只返回一个类别。
在你的分类中考虑以下定义:
第 1 项(即业务)描述了公司的业务:公司是谁,做什么,拥有哪些子公司,以及运营的市场。
它还可能包括近期事件、竞争、法规和劳动问题。(一些行业受到严格监管,具有复杂的劳动要求,这些都对业务产生重大影响。)
本节中的其他主题可能包括特殊操作成本、季节性因素或保险问题。**
第 1A 项(即风险因素)是公司列出可能出现问题的部分,包括可能的外部影响、未来未能履行义务的可能性以及其他风险,以充分警示投资者和潜在投资者。
第 2 项(即属性)是列出公司重要属性、实物资产的部分。这仅包括实物类型的财产,而不包括知识产权或无形财产。
如果我们在这些文本上运行,就会得到以下性能:
precision recall f1-score support
Item 1 0.70 0.70 0.70 10
Item 1A 0.78 0.70 0.74 10
Item 2 0.91 1.00 0.95 10
accuracy 0.80 30
macro avg 0.80 0.80 0.80 30
weighted avg 0.80 0.80 0.80 30
宏平均 F1 现在是 0.80,也就是我们的预测提升了 29%,这仅仅是通过提供每个类别的良好描述。
最终,你可以查看完整的数据集:
实际上,我提供的示例为模型提供了具体的学习实例。通过查看多个示例,模型可以推断出模式和特征,开始注意到标志总体概念的共同点和差异。这有助于模型形成更为稳健的表示。此外,提供示例本质上充当了一种弱监督形式,引导模型朝着期望的行为发展,而不依赖于大型标记数据集。
在少量样本功能中,具体示例帮助引导模型关注应注意的信息和模式。总之,具体示例对少量样本学习至关重要,因为它们为模型提供了建立新概念初步表示的锚点,然后可以在提供的少量示例中进行细化。通过特定实例的归纳学习帮助模型形成抽象概念的细致表示。
如果你喜欢阅读这些内容并希望保持联系,你可以在我的LinkedIn上找到我,或通过我的网页:iliateimouri.com
注意:所有图片,除非另有说明,均由作者提供。
《生产就绪的 RAG 应用的 12 种调整策略指南》
如何通过这些“超参数”和调整策略来提升你的检索增强生成(RAG)管道的性能
·
关注 发布于 Towards Data Science · 10 min 阅读 · 2023 年 12 月 6 日
–
《检索增强生成(RAG)应用的调整策略》
数据科学是一门实验性科学。它以“无免费午餐定理”开始,该定理指出,没有一种万能的算法可以适用于所有问题。这导致数据科学家使用实验跟踪系统来帮助他们调整机器学习(ML)项目的超参数,以实现最佳性能.
本文从数据科学家的角度审视了检索增强生成(RAG)管道。讨论了可以实验的潜在“超参数”以提高 RAG 管道的性能。类似于深度学习中的实验,在深度学习中,例如,数据增强技术不是超参数,而是一个可以调整和实验的控制旋钮,本文还将涵盖可以应用的不同策略,这些策略本身不一定是超参数。
## 检索增强生成(RAG):从理论到 LangChain 实现
从原始学术论文的理论到使用 OpenAI、Weaviate 和 LangChain 的 Python 实现
towardsdatascience.com
本文涵盖了按相关阶段排序的“超参数”。在 RAG 管道的吞吐阶段,你可以通过以下方法实现性能提升:
-
数据清洗
-
数据分块
-
嵌入模型
-
元数据
-
多索引
-
索引算法
在推理阶段(检索和生成),你可以调整:
-
查询转换
-
检索参数
-
高级检索策略
-
重新排序模型
-
大型语言模型(LLMs)
-
提示工程
请注意,本文涵盖了 RAG 的文本使用案例。对于多模态 RAG 应用,可能需要不同的考虑因素。
吞吐阶段
吞吐阶段是构建 RAG 管道的准备步骤,类似于 ML 管道中的数据清洗和预处理步骤。通常,吞吐阶段包括以下步骤:
-
收集数据
-
数据分块
-
生成数据块的向量嵌入
-
在向量数据库中存储向量嵌入和数据块
RAG 管道的吞吐阶段
本节讨论了在推理阶段可以应用和调整的有影响力的技术和超参数,以提高检索到的上下文的相关性。
数据清洗
像任何数据科学管道一样,你的数据质量对 RAG 管道的结果有着重要影响[8, 9]。在进行以下步骤之前,确保你的数据符合以下标准:
-
清洁:应用至少一些自然语言处理常用的基本数据清理技术,如确保所有特殊字符都正确编码。
-
正确:确保你的信息一致且事实准确,以避免信息冲突让你的 LLM 感到困惑。
分块
在 RAG 管道中,分块你的文档是对外部知识源的一个关键准备步骤,可能会影响性能[1, 8, 9]。这是一种生成逻辑上连贯的信息片段的技术,通常通过将长文档拆分成较小的部分(但也可以将较小的片段合并成连贯的段落)。
你需要考虑的一个方面是分块技术的选择。例如,在LangChain 中,不同的文本分割器根据不同的逻辑拆分文档,如按字符、标记等。这取决于你拥有的数据类型。例如,如果你的输入数据是代码与 Markdown 文件,你将需要使用不同的分块技术。
理想的块长度(**chunk_size**
**)**取决于你的使用场景:如果你的使用场景是问答,你可能需要较短的具体块;但如果你的使用场景是摘要,你可能需要较长的块。此外,如果块太短,可能包含的上下文不够。另一方面,如果块太长,可能包含过多的无关信息。
此外,你还需要考虑块之间的“滚动窗口”(**overlap**
**)**以引入一些额外的上下文。
嵌入模型
嵌入模型是你检索的核心。嵌入的质量对检索结果有着重大影响[1, 4]。通常,生成的嵌入维度越高,嵌入的精度也越高。
关于可用的替代嵌入模型,你可以查看MASSIVE TEXT EMBEDDING BENCHMARK (MTEB)排行榜,该排行榜涵盖了 164 种文本嵌入模型(截至本文撰写时)。
[## MTEB 排行榜 - 由 mteb 提供的 Hugging Face 空间
发现社区制作的精彩 ML 应用
huggingface.co](https://huggingface.co/spaces/mteb/leaderboard?source=post_page-----7ca646833439--------------------------------)
虽然你可以直接使用通用的嵌入模型,但在某些情况下,对你的嵌入模型进行微调可能更有意义,以避免之后出现领域外的问题 [9]。根据 LlamaIndex 进行的实验,微调你的嵌入模型可以导致检索评估指标性能提高 5–10% [2]。
请注意,并非所有嵌入模型都可以微调(例如,OpenAI 的 [text-embedding-ada-002](https://platform.openai.com/docs/guides/fine-tuning)
目前不能进行微调)。
元数据
当你将向量嵌入存储在一个向量数据库中时,一些向量数据库允许你将它们与元数据(或未向量化的数据)一起存储。用元数据注释向量嵌入对搜索结果的额外后处理可能是有帮助的,例如元数据过滤 [1, 3, 8, 9]。例如,你可以添加元数据,如日期、章节或子章节参考。
多重索引
如果元数据不足以提供额外的信息以逻辑地分隔不同类型的上下文,你可能需要尝试多重索引 [1, 9]。例如,你可以为不同类型的文档使用不同的索引。注意,在检索时你需要进行一些索引路由 [1, 9]。如果你对元数据和分离集合有更深入的兴趣,你可能想了解更多关于原生多租户的概念。
索引算法
为了在大规模上实现快速相似性搜索,向量数据库和向量索引库使用近似最近邻(ANN)搜索而不是 k-最近邻(kNN)搜索。顾名思义,ANN 算法近似最近邻,因此可能不如 kNN 算法精确。
你可以尝试不同的 ANN 算法,例如Facebook Faiss(聚类)、Spotify Annoy(树)、Google ScaNN(向量压缩)和HNSWLIB(邻近图)。此外,这些 ANN 算法中的许多都有一些参数可以调整,例如 HNSW 的ef
、efConstruction
和maxConnections
[1]。
此外,你可以为这些索引算法启用向量压缩。类似于 ANN 算法,向量压缩会导致一定的精度损失。然而,根据向量压缩算法的选择及其调整,你也可以对此进行优化。
然而,在实践中,这些参数通常由向量数据库和向量索引库的研究团队在基准测试实验期间进行调整,而不是由 RAG 系统的开发人员进行调整。不过,如果你想通过调整这些参数来挤出最后的性能提升,我推荐这篇文章作为起点:
[## 关于 RAG 评估的概述 | Weaviate - 向量数据库
了解 RAG 评估中的新趋势及当前的最新技术。
weaviate.io](https://weaviate.io/blog/rag-evaluation?source=post_page-----7ca646833439--------------------------------#indexing-knobs)
推理阶段(检索与生成)
RAG 管道的主要组成部分是检索和生成组件。本节主要讨论提高检索(查询转换、检索参数、高级检索策略和重新排序模型)的策略,因为这是两个组件中影响更大的部分。但它也简要涉及一些提高生成(LLM 和提示工程)的策略。
RAG 管道的推理阶段
查询转换
由于在 RAG 管道中检索附加上下文的搜索查询也被嵌入到向量空间中,因此其措辞也会影响搜索结果。因此,如果你的搜索查询没有产生令人满意的结果,你可以尝试各种查询转换技术 [5, 8, 9],例如:
-
改写: 使用 LLM 改写查询并重试。
-
假设文档嵌入(HyDE): 使用 LLM 生成对搜索查询的假设响应,并将两者用于检索。
-
子查询: 将较长的查询拆分为多个较短的查询。
检索参数
检索是 RAG 管道的一个重要组成部分。首先需要考虑的是语义搜索是否足够满足你的用例,还是你想尝试混合搜索。
在后一种情况下,你需要尝试对混合搜索中的稀疏和密集检索方法的加权进行实验[1, 4, 9]。因此,调整参数**alpha**
,即控制语义搜索(**alpha = 1**
)和基于关键词的搜索(**alpha = 0**
**)之间加权的参数,将变得必要。
如何通过将传统的基于关键词的搜索与现代向量搜索相结合来找到更相关的搜索结果
此外,检索的搜索结果数量将发挥重要作用。检索的上下文数量将影响所用上下文窗口的长度(见 Prompt Engineering)。此外,如果你使用的是重排序模型,你需要考虑输入模型的上下文数量(见 Re-ranking models)。
注意,虽然用于语义搜索的相似度度量是一个可以更改的参数,你不应随意实验,而是应根据所用的嵌入模型设置(例如,[text-embedding-ada-002](https://platform.openai.com/docs/guides/embeddings/what-are-embeddings)
支持余弦相似度或 [multi-qa-MiniLM-l6-cos-v1](https://huggingface.co/sentence-transformers/multi-qa-MiniLM-L6-cos-v1#technical-details)
支持余弦相似度、点积和欧几里得距离)。
高级检索策略
本节技术上可以作为一篇独立的文章。为了本概述,我们将尽量简洁。有关以下技术的详细说明,我推荐这个 DeepLearning.AI 课程:
www.deeplearning.ai/short-courses/building-evaluating-advanced-rag/?source=post_page-----7ca646833439--------------------------------
[## 构建和评估高级 RAG 应用
学习句子窗口检索和自动合并检索等方法,提高你的 RAG 流水线的性能……
本节的基本思想是检索的块不一定要与生成所用的块相同。理想情况下,你会为检索嵌入较小的块(见 Chunking),但检索更大的上下文。[7]
-
句子窗口检索: 不仅检索相关句子,还要检索在检索句子之前和之后的适当句子。
-
自动合并检索: 文档以树状结构组织。在查询时,可以将分开但相关的小块合并成一个更大的上下文。
重排序模型
虽然语义搜索根据与搜索查询的语义相似性检索上下文,但“最相似”并不一定意味着“最相关”。重排序模型,如 Cohere’s Rerank 模型,可以通过计算每个检索上下文对查询的相关性分数来帮助消除不相关的搜索结果 [1, 9]。
“最相似”并不一定意味着“最相关”
如果你使用的是重排序模型,你可能需要重新调整搜索结果数量以供重排序模型输入,并决定你希望将多少个重排序的结果输入到 LLM 中。
与嵌入模型一样,你可能还想尝试对重排序模型进行微调以适应你的特定用例。
LLMs
LLM 是核心组件,用于生成响应。类似于嵌入模型,根据你的要求(如开放 vs. 专有模型、推理成本、上下文长度等),你可以选择不同的 LLM。[1]
与嵌入模型或重排序模型一样,你可能想要尝试对 LLM 进行微调以适应你的特定用例,以融入特定的措辞或语气。
提示工程
你如何表述或工程化你的提示将显著影响 LLM 的完成质量[1, 8, 9]。
Please base your answer only on the search results and nothing else!
Very important! Your answer MUST be grounded in the search results provided.
Please explain why your answer is grounded in the search results!
此外,在提示中使用少量示例可以提高完成的质量。
如检索参数中提到的,输入提示的上下文数量是你应该尝试的一个参数[1]。虽然随着相关上下文的增加,你的 RAG 管道性能可能会提高,但你也可能会遇到“在中间迷失”[6]效应,即如果相关上下文被放置在许多上下文的中间,LLM 可能不会将其识别为相关。
摘要
随着越来越多的开发者获得原型开发 RAG 管道的经验,讨论将 RAG 管道带到生产就绪性能的策略变得越来越重要。本文讨论了不同的“超参数”和在 RAG 管道的相关阶段中可以调整的其他参数:
本文涵盖了摄取阶段中的以下策略:
-
数据清理:确保数据是干净和正确的。
-
分块:选择分块技术、分块大小(
chunk_size
)和分块重叠(overlap
)。 -
嵌入模型:选择嵌入模型,包括维度,以及是否进行微调。
-
元数据:是否使用元数据及其选择。
-
多索引:决定是否对不同的数据集合使用多个索引。
-
索引算法:选择和调整 ANN 和向量压缩算法,通常不由从业者进行调整。
以及在推理阶段(检索和生成)中的以下策略:
-
查询转换:尝试重新表述、HyDE 或子查询。
-
检索参数:选择搜索技术(如果启用了混合搜索,则为
alpha
)和检索结果的数量。 -
高级检索策略:是否使用高级检索策略,如句子窗口或自动合并检索。
-
Re-ranking models:是否使用重新排序模型、选择重新排序模型、输入到重新排序模型中的搜索结果数量以及是否对重新排序模型进行微调。
-
LLMs:选择 LLM 和是否对其进行微调。
-
Prompt engineering:尝试不同的措辞和少量示例。
享受了这个故事吗?
免费订阅 以获取我发布新故事时的通知。
[## 每当 Leonie Monigatti 发布新内容时获取电子邮件通知。
每当 Leonie Monigatti 发布新内容时,获取电子邮件通知。注册后,如果你还没有 Medium 账户,将会创建一个…
medium.com](https://medium.com/@iamleonie/subscribe?source=post_page-----7ca646833439--------------------------------)
在 LinkedIn,Twitter,以及 Kaggle上找到我!
参考文献
文献
[1] Connor Shorten 和 Erika Cardenas(2023)。Weaviate 博客。RAG 评估概述(访问日期:2023 年 11 月 27 日)
[2] Jerry Liu(2023)。LlamaIndex 博客。使用合成数据对 RAG 进行嵌入微调(访问日期:2023 年 11 月 28 日)
[3] LlamaIndex 文档(2023)。为生产构建高性能 RAG 应用程序(访问日期:2023 年 11 月 28 日)
[4] Voyage AI(2023)。嵌入推动 RAG 的质量:Chat.LangChain 的案例研究(访问日期:2023 年 12 月 5 日)
[5] LlamaIndex 文档(2023)。查询转换(访问日期:2023 年 11 月 28 日)
[6] Liu, N. F., Lin, K., Hewitt, J., Paranjape, A., Bevilacqua, M., Petroni, F., & Liang, P.(2023)。《迷失在中间:语言模型如何使用长上下文》。arXiv 预印本 arXiv:2307.03172。
[7] DeepLearning.AI(2023)。构建和评估高级 RAG 应用程序(访问日期:2023 年 12 月 4 日)
[8] Ahmed Besbes(2023)。Towards Data Science。为什么你的 RAG 在生产环境中不可靠(访问日期:2023 年 11 月 27 日)
[9] Matt Ambrogi(2023 年)。面向数据科学。提高检索增强生成系统性能的 10 种方法(访问日期:2023 年 11 月 27 日)
图片
除非另有说明,所有图片均由作者创作。