超越流失预测和流失提升
原文:
towardsdatascience.com/beyond-churn-prediction-and-churn-uplift-45225e5a7541
因果数据科学
如何在存在流失的情况下最佳地针对政策
·发表于 Towards Data Science ·阅读时间 12 分钟·2023 年 7 月 25 日
–
封面,图片由作者提供
数据科学中的一个非常常见的任务是流失预测。然而,预测流失往往只是一个中间步骤,很少是最终目标。通常,我们实际关心的是减少流失,这是一个独立的目标,不一定相关。事实上,例如,知道长期客户比新客户更不容易流失并不是一个可操作的见解,因为我们无法增加客户的留存时间。我们想知道的是某一(或多种)措施如何影响流失。这通常被称为流失提升。
在本文中,我们将超越流失预测和流失提升,转而考虑流失预防活动的终极目标:增加收入。首先,减少流失的政策可能也会对收入产生影响,这应予以考虑。然而,更重要的是,增加收入只有在客户不流失的情况下才是相关的。相反,减少流失对高收入客户更为相关。这种流失与收入之间的互动在理解任何措施活动的盈利性时至关重要,不应被忽视。
礼品和订阅
在接下来的部分中,我们将使用一个玩具示例来说明主要观点。假设我们是一家希望减少客户流失并最终增加收入的公司。假设我们决定测试一个新想法:向用户发送1 美元的礼品。为了测试这一措施是否有效,我们随机地仅将其发送给我们的客户基础中的一部分样本。
cost = 1
让我们看看我们手头的数据。我从src.dgp
导入数据生成过程dgp_gift()
。我还从src.utils
导入了一些绘图函数和库。
from src.utils import *
from src.dgp import dgp_gift
dgp = dgp_gift(n=100_000)
df = dgp.generate_data()
df.head()
数据快照,图片由作者提供
我们有100_000
个客户的信息,我们观察到他们作为活跃客户的months
数量、上个月他们产生的收入(rev_old
)、上个月与之前一个月的收入变化(rev_change
)、他们是否被随机送了gift
以及两个感兴趣的结果:churn
,即他们是否不再是活跃客户,以及他们在当前月产生的revenue
。我们用字母Y表示结果,用字母W表示处理,用字母X表示其他变量。
Y = ['churn', 'revenue']
W = 'gift'
X = ['months', 'rev_old', 'rev_change']
注意:为了简化起见,我们考虑数据的单期快照,并仅用几个变量总结数据的面板结构。通常,我们会有更长的时间序列,但对于结果(例如,客户终身价值)来说,时间范围也会更长。
我们可以用以下**有向无环图(DAG)**来表示潜在的数据生成过程。节点代表变量,箭头代表潜在的因果关系。我用绿色标出了两个感兴趣的关系:gift
对churn
和revenue
的影响。请注意,churn
与 revenue 相关,因为流失的客户根据定义不会产生收入。
数据生成过程的 DAG,图片由作者提供
重要的是,过去的收入和收入变化是churn
和revenue
的预测因子,但与我们的干预无关。相反,干预根据客户的总活跃months
对churn
和revenue
有不同的影响。
尽管简单,这个数据生成过程旨在捕捉一个重要的洞察:预测churn
或revenue
的变量不一定是预测churn
或revenue
提升的变量。稍后我们将看到这如何影响我们的分析。
首先,让我们开始探索数据。
探索性数据分析
让我们从churn
开始。公司上个月流失了多少客户?
df.churn.mean()
0.19767
公司上个月几乎*损失了 20%*的客户!gift
是否有助于防止流失?
我们想要比较收到礼物的客户的流失频率与未收到礼物的客户的流失频率。由于礼物是随机发放的,因此均值差异估计量是gift
对churn
的平均处理效应(ATE)的无偏估计量。
平均处理效应,图片由作者提供
我们通过线性回归计算均值差异估计。我们还包括其他协变量以提高估计器的效率。
smf.ols("churn ~ " + W + " + " + " + ".join(X), data=df).fit().summary().tables[1]
流失回归表,图片来源:作者
看起来gift
将流失率降低了大约11个百分点,即几乎是基准水平*32%*的三分之一!它是否对revenue
也有影响?
至于流失率,我们将revenue
对gift
,即我们的处理变量,进行回归,以估计平均处理效果。
smf.ols("revenue ~ " + W + " + " + " + ".join(X), data=df).fit().summary().tables[1]
收入回归表,图片来源:作者
看起来gift
平均增加的收入为0.63$,这意味着它并不盈利。这是否意味着我们应该停止向客户赠送礼物?这要看情况。事实上,礼物可能对某些客户群体有效。我们只需要识别这些客户群体。
定向策略
在本节中,我们将尝试通过针对特定客户来了解是否存在数据驱动的盈利发送gift
的方法。特别是,我们将比较不同的目标策略,目的是增加收入。
在本节中,我们将需要一些算法来预测revenue
或churn
,或接收gift
的概率。我们使用来自[lightgbm](https://lightgbm.readthedocs.io/en/latest/index.html)
库的梯度提升树模型。我们对所有策略使用相同的模型,以便我们无法将性能差异归因于预测准确性。
from lightgbm import LGBMClassifier, LGBMRegressor
要评估每项政策,我们希望在单独的验证数据集中,将有政策的隐含利润Π⁽¹⁾与没有政策的隐含利润Π⁽⁰⁾进行比较。我们称这两个量为潜在结果,它们的差异为利润提升,用τ表示。请注意,提升从未被观察到,因为对于每个客户,我们只能观察到两种潜在结果之一,即有或没有gift
。然而,由于我们使用的是合成数据,我们可以进行预言评估。如果你想了解如何用真实数据评估提升模型,我推荐我的入门文章。
行业中因果推断的最广泛应用之一是提升建模,也称为估计…
towardsdatascience.com
首先,让我们定义利润Π为当客户不流失时的净收入R。
利润公式,图片来源:作者
因此,对于处理过的个体,总体利润效应由两个假设量的差异给出:处理时的利润Π⁽¹⁾减去未处理时的利润Π⁽⁰⁾。
盈利提升公式,作者提供的图像
对未处理个体的效果为零。
def evaluate_policy(policy):
data = dgp.generate_data(seed_data=4, seed_assignment=5, keep_po=True)
data['profits'] = (1 - data.churn) * data.revenue
baseline = (1-data.churn_c) * data.revenue_c
effect = policy(data) * (1-data.churn_t) * (data.revenue_t-cost) + (1-policy(data)) * (1-data.churn_c) * data.revenue_c
return np.sum(effect - baseline)
1. 针对流失客户
第一个策略可能是仅针对流失客户。假设我们只将礼物
发送给预测流失率高于平均水平的客户。
model_churn = LGBMClassifier().fit(X=df[X], y=df['churn'])
policy_churn = lambda df : (model_churn.predict_proba(df[X])[:,1] > df.churn.mean())
evaluate_policy(policy_churn)
-5497.46
该策略没有盈利,并且会导致总计亏损超过5000$。
你可能认为问题在于任意阈值,但事实并非如此。下面我绘制了所有可能策略阈值的总体效果。
x = np.linspace(0, 1, 100)
y = [evaluate_policy(lambda df : (model_churn.predict_proba(df[X])[:,1] > p)) for pin x]
fig, ax = plt.subplots(figsize=(10, 3))
sns.lineplot(x=x, y=y).set(xlabel='Churn Policy Threshold', title='Aggregate Effect');
ax.axhline(y=0, c='k', lw=3, ls='--');
按流失阈值的总体效果,作者提供的图像
正如我们所见,无论阈值如何,基本上都不可能获得任何利润。
问题在于,客户可能流失并不意味着礼物
会对他们的流失概率产生任何影响。这两个度量并不是完全无关的(例如,我们不能降低那些流失概率为 0%的客户的流失概率),但它们并不是同一回事。
2. 针对收益客户
现在我们尝试另一种策略:我们只将礼物发送给高收益客户。例如,我们可以只将礼物发送给按收益排名前 10%的客户。这个想法是,如果该策略确实能降低流失,那么这些客户就是那些降低流失最具盈利性的客户。
model_revenue = LGBMRegressor().fit(X=df[X], y=df['revenue'])
policy_revenue = lambda df : (model_revenue.predict(df[X]) > np.quantile(df.revenue, 0.9))
evaluate_policy(policy_revenue)
-4730.82
该策略再次没有盈利,导致了相当大的亏损。如之前所示,这不是选择阈值的问题,正如下图所示。我们能做的最好的是设置一个很高的阈值,这样我们不处理任何人,从而实现零利润。
x = np.linspace(0, 100, 100)
y = [evaluate_policy(lambda df : (model_revenue.predict(df[X]) > c)) for c in x]
fig, ax = plt.subplots(figsize=(10, 3))
sns.lineplot(x=x, y=y).set(xlabel='Revenue Policy Threshold', title='Aggregate Effect');
ax.axhline(y=0, c='k', lw=3, ls='--');
按收益阈值的总体效果,作者提供的图像
问题在于,在我们的设置中,高收益客户的流失概率并没有下降到足以使礼物
有盈利的程度。这部分也是因为现实中常观察到的情况,即高收益客户本身就是流失可能性最小的客户。
现在让我们考虑一组更相关的策略:基于提升的策略。
3. 针对流失提升客户
更合理的方法是针对那些在接受1$ 礼物
时,流失
概率降低最多的客户。我们使用双重稳健估计器来估计流失提升,这是表现最好的提升模型之一。如果你对元学习者不熟悉,我建议从我的介绍文章开始。
在许多情况下,我们不仅仅关心估计因果效应,还关心这个效应是否…
towardsdatascience.com
我们从 econml 导入双重稳健学习器,这是一个微软库。
from econml.dr import DRLearner
DR_learner_churn = DRLearner(model_regression=LGBMRegressor(), model_propensity=LGBMClassifier(), model_final=LGBMRegressor())
DR_learner_churn.fit(df['churn'], df[W], X=df[X]);
既然我们已经估计了客户流失的提升,我们可能会倾向于只针对那些具有高负面提升的客户(负面,因为我们想要减少流失)。例如,我们可能会将礼物
发送给所有估计流失高于平均水平的客户。
policy_churn_lift = lambda df : DR_learner_churn.effect(df[X]) < - np.mean(df.churn)
evaluate_policy(policy_churn_lift)
-3925.24
该政策仍然不盈利,导致了接近*4000$*的损失。
问题在于我们没有考虑政策的成本。实际上,降低流失概率仅对高收入客户有利。以极端情况为例:避免流失一个不产生任何收入的客户是没有价值的干预。
因此,我们只将礼物
发送给那些其流失概率加权收入下降幅度大于礼物成本的客户。
model_revenue_1 = LGBMRegressor().fit(X=df.loc[df[W] == 1, X], y=df.loc[df[W] == 1, 'revenue'])
policy_churn_lift = lambda df : - DR_learner_churn.effect(df[X]) * model_revenue_1.predict(df[X]) > cost
evaluate_policy(policy_churn_lift)
318.03
这一政策最终是盈利的!
然而,我们仍然没有考虑到一个渠道:干预也可能会影响现有客户的收入。
4. 目标收入提升客户
与前一种方法对称的方法是只考虑对收入
的影响,忽略对流失的影响。我们可以估计非流失客户的收入
提升,只处理那些在流失后对收入的增量效果大于礼物
成本的客户。
DR_learner_netrevenue = DRLearner(model_regression=LGBMRegressor(), model_propensity=LGBMClassifier(), model_final=LGBMRegressor())
DR_learner_netrevenue.fit(df.loc[df.churn==0, 'revenue'], df.loc[df.churn==0, W], X=df.loc[df.churn==0, X]);
model_churn_1 = LGBMClassifier().fit(X=df.loc[df[W] == 1, X], y=df.loc[df[W] == 1, 'churn'])
policy_netrevenue_lift = lambda df : DR_learner_netrevenue.effect(df[X]) * (1-model_churn_1.predict(df[X])) > cost
evaluate_policy(policy_netrevenue_lift)
50.80
这一政策也有利可图,但忽略了对流失的影响。我们如何将这一政策与之前的政策结合起来?
5. 目标收入提升客户
高效结合对流失和对净收入影响的最佳方法就是估计总收入提升。隐含的最佳政策是处理那些总收入提升大于礼物
成本的客户。
DR_learner_revenue = DRLearner(model_regression=LGBMRegressor(), model_propensity=LGBMClassifier(), model_final=LGBMRegressor())
DR_learner_revenue.fit(df['revenue'], df[W], X=df[X]);
policy_revenue_lift = lambda df : (DR_learner_revenue.effect(df[X]) > cost)
evaluate_policy(policy_revenue_lift)
2028.21
这一政策迄今为止表现最好,产生了超过*2000$*的总利润!
policies = [policy_churn, policy_revenue, policy_churn_lift, policy_netrevenue_lift, policy_revenue_lift]
df_results = pd.DataFrame()
df_results['policy'] = ['churn', 'revenue', 'churn_L', 'netrevenue_L', 'revenue_L']
df_results['value'] = [evaluate_policy(policy) for policy in policies]
fig, ax = plt.subplots()
sns.barplot(df_results, x='policy', y='value').set(title='Overall Incremental Effect')
plt.axhline(0, c='k');
比较政策,图像由作者提供
直觉与分解
如果我们比较不同的政策,很明显,针对高收入或高流失概率客户是最糟糕的选择。这并不总是如此,但在我们的模拟数据中发生了这种情况,因为有两个在许多实际场景中也很常见的事实:
-
收入与流失概率呈负相关
-
礼物
对流失
(或收入
)的影响与基准值并没有强烈的负相关(或对收入
的正相关)
这些事实中的任何一个都足以使得以收入或流失为目标的策略变得不佳。应该关注的是那些具有高增量效果的客户。而且,最好直接使用感兴趣的变量,即在这种情况下的收入
,只要有数据。
为了更好地理解机制,我们可以分解政策对利润的整体影响为三个部分。
利润提升分解,图像由作者提供
这意味着有三个渠道使得对客户的处理具有盈利性。
-
如果这是一个高收入客户并且处理减少了其流失概率
-
如果这是一个非流失客户并且处理增加了其收入
-
如果处理对其收入和流失概率都有强影响
通过流失提升进行的目标定位仅利用第一个渠道,通过净收入提升进行的目标定位仅利用第二个渠道,而通过总收入提升进行的目标定位利用所有三个渠道,使其成为最有效的方法。
奖金:加权
如Lemmens, Gupta (2020)所强调,有时在估计提升时可能值得对观察值进行加权。特别是,可能值得对接近处理政策阈值的观察值给予更多权重。
观点是,加权通常会降低估计量的效率。然而,我们并不关心对所有观察值获得正确的估计,而是关心正确估计政策阈值。实际上,无论你估计*1
∗
还是
∗
1000
*还是*1000
∗还是∗1000的净利润都无关紧要:隐含的政策是相同的:发送礼物
。然而,估计1
∗
的净利润而不是
∗
−
1
*的净利润而不是*-1
∗的净利润而不是∗−1*会颠倒政策的含义。因此,距离阈值的准确性大幅下降有时值得在阈值处的小幅提高准确性。
让我们尝试使用负指数权重,距离阈值越远权重越低。
DR_learner_revenue_w = DRLearner(model_regression=LGBMRegressor(), model_propensity=LGBMClassifier(), model_final=LGBMRegressor())
w = np.exp(1 + np.abs(DR_learner_revenue.effect(df[X]) - cost))
DR_learner_revenue_w.fit(df['revenue'], df[W], X=df[X], sample_weight=w);
policy_revenue_lift_w = lambda df : (DR_learner_revenue_w.effect(df[X]) > cost)
evaluate_policy(policy_revenue_lift_w)
1398.19
在我们的情况下,加权是不值得的:隐含的政策仍然有利可图,但不如未加权模型所获得的2028$。
结论
在这篇文章中,我们探讨了为什么以及如何应当超越流失预测和流失提升建模。特别是,应集中于提高盈利的最终业务目标。这意味着将重点从预测转移到提升,同时将流失和收入结合成一个结果。
一个重要的警告涉及到数据的维度。我们使用了一个玩具数据集,这在至少两个维度上高度简化了问题。首先,回溯,我们通常有更长时间的时间序列,这些时间序列可以(并且应该)用于预测和建模目的。其次,前瞻,应该将流失与长期的客户盈利估计相结合,通常称为客户生命周期价值。
参考文献
-
Kennedy (2022), “朝着最优双重稳健估计异质因果效应”
-
Bonvini, Kennedy, Keele (2021), “最小最大最优子组识别”
-
Lemmens, Gupta (2020), “管理流失以最大化利润”
相关文章
-
评估提升模型
-
理解元学习者
-
理解 AIPW,即双重稳健估计器
Code
你可以在这里找到原始的 Jupyter Notebook:
## Blog-Posts/notebooks/beyond_churn.ipynb at main · matteocourthoud/Blog-Posts
这是我在 Medium 博客文章中的代码和笔记本。通过创建一个…来贡献于 matteocourthoud/Blog-Posts 的开发。
感谢阅读!
我非常感谢! 🤗 如果你喜欢这篇文章并想看到更多内容,请考虑 关注我。我每周发布一次关于因果推断和数据分析的内容。我尝试保持文章简洁而准确,始终提供代码、示例和模拟。
此外,一个小小的 免责声明:我写作是为了学习,因此错误是常有的事,尽管我尽力而为。如果你发现了错误,请告诉我。我也很感激对新主题的建议!
超越英语:实现多语言 RAG 解决方案
实施非英语检索增强生成(RAG)系统时的注意事项
·
阅读 在 Towards Data Science 发布 · 18 分钟阅读 · 2023 年 12 月 20 日
–
RAG,一个无所不知的同事,全天候提供服务(图片由作者使用 Dall-E 3 生成)
TLDR
这篇文章介绍了在开发非英语 RAG 系统时应考虑的因素,并提供了具体的示例和技术。关键点包括:
-
在数据加载过程中优先保持句法结构,因为这对有意义的文本分割至关重要。
-
使用简单分隔符如\n\n 来格式化文档,以促进高效的文本拆分。
-
选择基于规则的文本分割器,因为在多语言环境中,基于 ML 的语义分割器计算强度大且性能较差。
-
在选择嵌入模型时,考虑其多语言能力和不对称检索性能。
-
对于多语言项目,通过大语言模型 (LLM) 微调嵌入模型可以提高性能,可能需要以实现足够的准确性。
-
强烈推荐实施基于 LLM 的检索评估基准,以有效微调 RAG 系统的超参数,并且可以利用现有框架轻松完成。
RAG 成为 2023 年搜索技术中最流行的术语也就不足为奇了。检索增强生成 (RAG) 正在改变组织利用其大量现有数据来推动智能聊天机器人的方式。这些能够进行自然语言对话的机器人,可以利用组织的集体知识,充当一个始终可用的内部专家,提供基于经验证数据的相关答案。虽然有大量资源可用于构建 RAG 系统,但大多数资源针对的是英语,较小语言的资源仍有缺口。
这份易于遵循的 6 步指南将引导你了解在为非英语语言创建 RAG 系统时的注意事项。
RAG 结构,简要回顾
本文假设读者对嵌入、向量和标记等概念有一定了解。对于需要简要回顾 RAG 系统架构的人来说,它们主要由两个核心组件组成:
-
索引阶段(本文的重点):这一初始阶段涉及处理输入数据。数据首先被加载、适当格式化,然后进行拆分。之后,数据通过嵌入技术进行向量化,最终存储在知识库中以便将来检索。
-
生成阶段:在此阶段,用户的查询被输入到检索系统中。该系统随后从知识库中提取相关信息片段。利用大语言模型 (LLM),系统解释这些数据以制定连贯的自然语言响应,有效地解答用户的询问。
现在让我们开始吧!
免责声明:
本指南并不旨在成为使用任何特定工具的详尽手册。相反,其目的是阐明应指导工具选择的总体决策。实际上,我强烈建议利用已建立的框架来构建系统基础。对于构建 RAG 系统,我个人推荐 LlamaIndex,因为它们提供了详细的指南和专注于索引和检索优化的功能。
此外,本指南假设我们处理的是使用拉丁字母并从左向右阅读的语言。这包括德语、法语、西班牙语、捷克语、土耳其语、越南语、挪威语、波兰语以及其他一些语言。其他语言可能有不同的需求和考虑因素。
1. 数据加载器:关键在于细节
一个外观酷炫的多模态数据加载器(图像由作者使用 Dall-E 3 生成)
RAG 系统的第一步是使用数据加载器处理各种格式,从文本文件到多媒体,提取所有相关内容以供进一步处理。对于基于文本的格式,数据加载器通常在不同语言间表现一致,因为它们不涉及特定语言的处理。然而,随着多模态 RAG 系统的出现,了解语音转文本模型在与其英语对应模型相比性能降低的情况非常重要。像Whisper v3这样的模型展示了令人印象深刻的多语言能力,但最好查看它们在Mozilla Common Voice或Fleurs数据集上的表现,并且最好在自己的基准上进行评估。
本文其余部分将集中讨论基于文本的输入。
保留句法结构为何重要
数据加载的一个关键方面是保持原始数据的句法完整性。丢失诸如标题或段落结构的元素可能会影响后续信息检索的准确性。对于非英语语言,这种担忧尤为突出,因为基于机器学习的分段工具的可用性有限。
句法信息发挥着至关重要的作用,因为 RAG 系统在提供有意义答案的效果部分取决于它们将数据拆分为语义准确的子部分的能力。
为了突出保留结构的数据加载方法与不保留结构的方法之间的区别,举一个使用基础 HTML 数据加载器与 PDF 加载器对medium article的例子。像LangChain和LlamaIndex这样的库都依赖于完全相同的库,但只是将函数封装在各自的文档类中(Web 用 Requests+BS4,PDF 用 PyPDF2)。
HTML 数据加载器:此方法保留了内容的句法结构。
import requests
from bs4 import BeautifulSoup
url = "https://medium.com/llamaindex-blog/boosting-rag-picking-the-best-embedding-reranker-models-42d079022e83"
soup = BeautifulSoup(requests.get(url).text, 'html.parser')
filtered_tags = soup.find_all(['h1', 'h2', 'h3', 'h4', 'p'])
filtered_tags[:14]
<p class="be b dw dx dy dz ea eb ec ed ee ef dt"><span><a class="be b dw dx eg dy dz eh ea eb ei ec ed ej ee ef ek el em eo ep eq er es et eu ev ew ex ey ez fa bl fb fc" data-testid="headerSignUpButton" href="https://medium.com/m/signin?operation=register&redirect=https%3A%2F%2Fblog.llamaindex.ai%2Fboosting-rag-picking-the-best-embedding-reranker-models-42d079022e83&source=post_page---two_column_layout_nav-----------------------global_nav-----------" rel="noopener follow">Sign up</a></span></p>
<p class="be b dw dx dy dz ea eb ec ed ee ef dt"><span><a class="af ag ah ai aj ak al am an ao ap aq ar as at" data-testid="headerSignInButton" href="https://medium.com/m/signin?operation=login&redirect=https%3A%2F%2Fblog.llamaindex.ai%2Fboosting-rag-picking-the-best-embedding-reranker-models-42d079022e83&source=post_page---two_column_layout_nav-----------------------global_nav-----------" rel="noopener follow">Sign in</a></span></p>
<p class="be b dw dx dy dz ea eb ec ed ee ef dt"><span><a class="be b dw dx eg dy dz eh ea eb ei ec ed ej ee ef ek el em eo ep eq er es et eu ev ew ex ey ez fa bl fb fc" data-testid="headerSignUpButton" href="https://medium.com/m/signin?operation=register&redirect=https%3A%2F%2Fblog.llamaindex.ai%2Fboosting-rag-picking-the-best-embedding-reranker-models-42d079022e83&source=post_page---two_column_layout_nav-----------------------global_nav-----------" rel="noopener follow">Sign up</a></span></p>
<p class="be b dw dx dy dz ea eb ec ed ee ef dt"><span><a class="af ag ah ai aj ak al am an ao ap aq ar as at" data-testid="headerSignInButton" href="https://medium.com/m/signin?operation=login&redirect=https%3A%2F%2Fblog.llamaindex.ai%2Fboosting-rag-picking-the-best-embedding-reranker-models-42d079022e83&source=post_page---two_column_layout_nav-----------------------global_nav-----------" rel="noopener follow">Sign in</a></span></p>
<h1 class="pw-post-title gp gq gr be gs gt gu gv gw gx gy gz ha hb hc hd he hf hg hh hi hj hk hl hm hn ho hp hq hr bj" data-testid="storyTitle" id="f2a9">Boosting RAG: Picking the Best Embedding & Reranker models</h1>
<p class="be b iq ir bj"><a class="af ag ah ai aj ak al am an ao ap aq ar is" data-testid="authorName" href="https://ravidesetty.medium.com/?source=post_page-----42d079022e83--------------------------------" rel="noopener follow">Ravi Theja</a></p>
<p class="be b iq ir dt"><span><a class="iv iw ah ai aj ak al am an ao ap aq ar eu ix iy" href="https://medium.com/m/signin?actionUrl=https%3A%2F%2Fmedium.com%2F_%2Fsubscribe%2Fuser%2F60738cbbc7df&operation=register&redirect=https%3A%2F%2Fblog.llamaindex.ai%2Fboosting-rag-picking-the-best-embedding-reranker-models-42d079022e83&user=Ravi+Theja&userId=60738cbbc7df&source=post_page-60738cbbc7df----42d079022e83---------------------post_header-----------" rel="noopener follow">Follow</a></span></p>
<p class="be b bf z jh ji jj jk jl jm jn jo bj">LlamaIndex Blog</p>
<p class="be b du z dt"><span class="lq">--</span></p>
<p class="be b du z dt"><span class="pw-responses-count lr ls">5</span></p>
<p class="be b bf z dt">Listen</p>
<p class="be b bf z dt">Share</p>
<p class="pw-post-body-paragraph nl nm gr nn b no np nq nr ns nt nu nv nw nx ny nz oa ob oc od oe of og oh oi gk bj" id="4130"><strong class="nn gs">UPDATE</strong>: The pooling method for the Jina AI embeddings has been adjusted to use mean pooling, and the results have been updated accordingly. Notably, the <code class="cw oj ok ol om b">JinaAI-v2-base-en</code> with <code class="cw oj ok ol om b">bge-reranker-large</code>now exhibits a Hit Rate of 0.938202 and an MRR (Mean Reciprocal Rank) of 0.868539 and with<code class="cw oj ok ol om b">CohereRerank</code> exhibits a Hit Rate of 0.932584, and an MRR of 0.873689.</p>
<p class="pw-post-body-paragraph nl nm gr nn b no np nq nr ns nt nu nv nw nx ny nz oa ob oc od oe of og oh oi gk bj" id="8267">When building a Retrieval Augmented Generation (RAG) pipeline, one key component is the Retriever. We have a variety of embedding models to choose from, including OpenAI, CohereAI, and open-source sentence transformers. Additionally, there are several rerankers available from CohereAI and sentence transformers.</p>
PDF 数据加载器,句法信息丢失的示例(将文章保存为 PDF 后重新加载)
from PyPDF2 import PdfFileReader
pdf = PdfFileReader(open('data/Boosting_RAG_Picking_the_Best_Embedding_&_Reranker_models.pdf','rb'))
pdf.getPage(0).extractText()
'Boosting RAG: Picking the Best\nEmbedding & Reranker models\n
Ravi Theja·Follow\nPublished inLlamaIndex Blog·7 min read·Nov 3\n
389 5\nUPDATE: The pooling method for the Jina AI embeddings has been adjusted\n
to use mean pooling, and the results have been updated accordingly.\n
Notably, the JinaAI-v2-base-en with bge-reranker-largenow exhibits a Hit\n
Rate of 0.938202 and an MRR (Mean Reciprocal Rank) of 0.868539 and\n
withCohereRerank exhibits a Hit Rate of 0.932584, and an MRR of 0.873689.\n
When building a Retrieval Augmented Generation (RAG) pipeline, one key\n
component is the Retriever. We have a variety of embedding models to\n
choose from, including OpenAI, CohereAI, and open-source sentence\n
Open in app\nSearch Write\n'
初步检查显示,PDF 数据加载器的输出看起来更可读,但仔细检查后发现丢失了结构信息——如何区分标题和节的结束?相比之下,HTML 文件保留了所有相关的结构。
理想情况下,你希望在数据加载器中保留所有原始格式,并且仅在下一步决定过滤和重新格式化。然而,这可能涉及为你的使用案例构建自定义数据加载器,并且在某些情况下可能是不可能的。我建议你从标准数据加载器开始,但花几分钟仔细检查加载的数据示例,并了解丢失了哪些结构。
了解丢失的语法结构是至关重要的,因为它指导了系统下游检索性能需要改进的潜在方向,允许进行有针对性的优化。
2. 数据格式化:无聊……但重要
文档分块(图像由作者使用 Dall-E 3 生成)
第二步,格式化,其主要目的是以统一的方式整理来自数据加载器的数据,以便为下一步的文本拆分做准备。如以下章节所述,将输入文本划分为无数较小的块是必要的。成功的格式化将文本设置成提供最佳条件以将内容划分为语义上有意义的块。简单来说,你的目标是将从 html 或 markdown 文件中检索到的潜在复杂语法结构转换为带有基本分隔符的纯文本文件,如 /n(换行)和 /n/n(节结束),以指导文本拆分器。
一个简单的函数将 BS4 HTML 对象格式化为包含标题和文本的字典,如下所示:
def format_html(tags):
formatted_text = ""
title = ""
for tag in tags:
if 'pw-post-title' in tag.get('class', []):
title = tag.get_text()
elif tag.name == 'p' and 'pw-post-body-paragraph' in tag.get('class', []):
formatted_text += "\n"+ tag.get_text()
elif tag.name in ['h1', 'h2', 'h3', 'h4']:
formatted_text += "\n\n" + tag.get_text()
return {title: formatted_text}
formatted_document = format_html(filtered_tags)
{'Boosting RAG: Picking the Best Embedding & Reranker models': "\n
UPDATE: The pooling method for the Jina AI embeddings has been adjusted to use mean pooling, and the results have been updated accordingly. Notably, the JinaAI-v2-base-en with bge-reranker-largenow exhibits a Hit Rate of 0.938202 and an MRR (Mean Reciprocal Rank) of 0.868539 and withCohereRerank exhibits a Hit Rate of 0.932584, and an MRR of 0.873689.\n
When building a Retrieval Augmented Generation (RAG) pipeline, one key component is the Retriever. We have a variety of embedding models to choose from, including OpenAI, CohereAI, and open-source sentence transformers. Additionally, there are several rerankers available from CohereAI and sentence transformers.\n
But with all these options, how do we determine the best mix for top-notch retrieval performance? How do we know which embedding model fits our data best? Or which reranker boosts our results the most?\n
In this blog post, we’ll use the Retrieval Evaluation module from LlamaIndex to swiftly determine the best combination of embedding and reranker models. Let's dive in!\n
Let’s first start with understanding the metrics available in Retrieval Evaluation\n\n
... }
对于复杂的 RAG 系统,其中相对于上下文可能有多个正确答案,将文档标题或标题等附加信息存储为文本块的元数据是有益的。这些元数据可以在之后用于过滤,如果可用,格式化元素如标题应影响你的分块策略。像 LlamaIndex 这样的库本地处理与元数据和文本一起封装在 Node 对象中的概念,我强烈推荐使用这个或类似的框架。
现在我们已经正确地完成了格式化,让我们深入探讨文本拆分的关键方面吧!
3: 文本拆分:大小重要
拆分文本,简单的方法(图像由作者使用 Dall-E 3 生成)
在为 RAG 系统准备数据以进行嵌入和检索时,将文本拆分为适当大小的块是至关重要的。这个过程受两个主要因素的指导:模型约束和检索有效性。
模型约束
嵌入模型对输入的最大 token 长度有一个限制;超出此限制的内容会被截断。了解所选择模型的限制,并确保每个数据块不超过此最大 token 长度。
多语言模型,特别是,与其英文对应模型相比,通常具有较短的序列限制。例如,广泛使用的 Paraphrase multilingual MiniLM-L12 v2 模型的最大上下文窗口仅为 128 个 token。
此外,还要考虑模型的训练文本长度——一些模型虽然在技术上可以接受更长的输入,但其训练数据却较短,这可能会影响对较长文本的性能。例如,SBERT 的 Multi QA 基础模型 如下所示,
检索效果
虽然将数据拆分到模型的最大长度似乎是合理的,但这可能并不总是能带来最佳的检索结果。较大的块为 LLM 提供了更多的上下文,但可能会掩盖关键细节,使得精确匹配更加困难。相反,较小的块可以提高匹配准确性,但可能缺乏获取完整答案所需的上下文。混合方法使用较小的块进行搜索,但在查询时包括周围的上下文以保持平衡。
尽管关于块大小没有确切的答案,但块大小的考虑在多语言项目和英语项目中是一致的。我建议进一步阅读相关资源,如 使用 Llamaindex 评估 RAG 系统的理想块大小 或 为生产环境构建基于 RAG 的 LLM 应用程序。
文本拆分:文本拆分的方法
文本可以通过各种方法进行拆分,主要分为两类:基于规则的(注重字符分析)和基于机器学习的模型。机器学习方法,从简单的 NLTK 和 Spacy 分词器到先进的 transformer 模型,通常依赖于语言特定的训练,主要是英语。尽管像 NLTK 和 Spacy 这样的简单模型支持多种语言,但它们主要处理句子拆分,而非语义划分。
由于基于机器学习的句子拆分器目前在大多数非英语语言中效果不佳且计算密集,我建议从简单的基于规则的拆分器开始。如果你保留了原始数据的相关句法结构,并正确地格式化了数据,结果将会质量良好。
一种常见而有效的方法是递归字符文本分割器,例如在 LangChain 或 LlamaIndex 中使用的,它通过在优先序列中找到最近的分隔字符(例如 \n\n, \n, ., ?, !)来缩短段落。
使用前一部分格式化文本的示例,使用 LangChain 的递归字符分割器如下所示:
from langchain.text_splitter import RecursiveCharacterTextSplitter
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("intfloat/e5-base-v2")
def token_length_function(text_input):
return len(tokenizer.encode(text_input, add_special_tokens=False))
text_splitter = RecursiveCharacterTextSplitter(
# Set a really small chunk size, just to show.
chunk_size = 128,
chunk_overlap = 0,
length_function = token_length_function,
separators = ["\n\n", "\n", ". ", "? ", "! "]
)
split_texts = text_splitter(formatted_document['Boosting RAG: Picking the Best Embedding & Reranker models'])
在这里,需要注意的是,应该将分词器定义为拟使用的嵌入模型,因为不同模型对词汇的计数方式不同。函数现在将按照优先顺序,首先通过我们在段落末尾引入的 \n\n 拆分任何超过 128 个标记的文本,如果不可能,则通过 \n 分隔的段落末尾,依此类推。前三个块将是:
Token of text: 111
UPDATE: The pooling method for the Jina AI embeddings has been adjusted to use mean pooling, and the results have been updated accordingly. Notably, the JinaAI-v2-base-en with bge-reranker-largenow exhibits a Hit Rate of 0.938202 and an MRR (Mean Reciprocal Rank) of 0.868539 and withCohereRerank exhibits a Hit Rate of 0.932584, and an MRR of 0.873689.
-----------
Token of text: 112
When building a Retrieval Augmented Generation (RAG) pipeline, one key component is the Retriever. We have a variety of embedding models to choose from, including OpenAI, CohereAI, and open-source sentence transformers. Additionally, there are several rerankers available from CohereAI and sentence transformers.
But with all these options, how do we determine the best mix for top-notch retrieval performance? How do we know which embedding model fits our data best? Or which reranker boosts our results the most?
-----------
Token of text: 54
In this blog post, we’ll use the Retrieval Evaluation module from LlamaIndex to swiftly determine the best combination of embedding and reranker models. Let's dive in!
Let’s first start with understanding the metrics available in Retrieval Evaluation
现在我们已经成功地以语义上有意义的方式拆分了文本,可以进入最终阶段,即将这些块嵌入以便存储。
4. 嵌入模型:在丛林中导航
嵌入模型将文本转换为向量(图片由作者使用 Dall-E 3 生成)
选择正确的嵌入模型对于检索增强生成(RAG)系统的成功至关重要,这比英语语言的情况复杂得多。比较模型的一个全面资源是 Massive Text Embedding Benchmark (MTEB),其中包含超过 100 种语言的基准。
你选择的模型必须是多语言的,或专门针对你正在使用的语言(单语言)定制的。请记住,最新的高性能模型通常以英语为中心,可能不适用于其他语言。
如果有相关资源,请参考与你的任务相关的语言特定基准。例如,在分类任务中,有超过 50 个语言特定的基准,帮助选择最有效的模型,适用于从丹麦语到西班牙语的语言。然而,重要的是要注意,这些基准可能不会直接指示模型在 RAG 系统中检索相关信息的效率,因为检索与分类、聚类或其他任务不同。任务是找到训练用于不对称搜索的模型,因为那些没有针对这一特定任务训练的模型可能会不准确地优先考虑较短的段落而非较长且更相关的段落。
该模型应在非对称检索中表现出色,将短查询匹配到较长的文本块。原因在于,在 RAG 系统中,你通常需要将简短的查询匹配到更长的段落中以提取有意义的答案。与非对称检索相关的 MTEB 基准列在检索部分。一个挑战是截至 2023 年 11 月,MTEB 的检索基准仅包括英语、中文和波兰语。
在处理像挪威语这样的语言时,可能没有特定的检索基准,你可能会想知道是否应该选择分类基准中表现最好的模型,还是选择一个在英语检索方面表现出色的通用多语言模型?
对于实际建议,简单的经验法则是选择 MTEB 检索基准中表现最好的多语言模型。注意,检索评分本身仍然基于英语,因此需要在你自己的语言上进行基准测试以验证性能(第 6 步)。截至 2023 年 12 月,E5-多语言系列是开源模型的一个强有力的选择。该模型经过针对非对称检索的微调,通过在嵌入前将文本标记为“查询”或“段落”,它通过考虑输入的性质优化了检索过程。这种方法确保了查询与知识库中相关信息之间的更有效匹配,从而提升了 RAG 系统的整体性能。根据基准测试,cohere-embed-multilingual-v3.0 可能表现更佳,但需付费。
嵌入步骤通常作为存储文档到向量数据库的一部分完成,但使用 E5 系列对所有分割句子进行嵌入的简单示例如下,使用了 Sentence Transformer 库。
from sentence_transformers import SentenceTransformer
model = SentenceTransformer('intfloat/e5-large')
prepended_split_texts = ["passage: " + text for text in split_texts]
embeddings = model.encode(prepended_split_texts, normalize_embeddings=True)
print(f'We now have {len(embeddings)} embeddings, each of size {len(embeddings[0])}')
We now have 12 embeddings, each of size 1024
如果现成的嵌入在你的特定检索领域中表现不够理想,不用担心。随着 LLM 的出现,现在可以从现有语料库中自动生成训练数据,并通过在你自己的数据上微调现有嵌入提高性能,提升幅度可达 5–10%。LlamaIndex 在这里提供了一个指南 或 SBERTs GenQ 方法,其中主要是 Bi-Encoder 训练部分相关。
5. 向量数据库:嵌入的家园
嵌入存储在数据库中以供检索(图像由作者通过 Dall-E 3 生成)
在加载、格式化、拆分数据并选择嵌入模型之后,RAG 系统设置的下一步是嵌入数据并存储这些向量嵌入以供检索。大多数平台,包括 LangChain 和 LlamaIndex,都提供了集成的本地存储解决方案,使用像 Qdrant、Milvus、Chroma DB 这样的向量数据库,或者直接与基于云的存储选项如 Pinecone 或 ActiveLoop 集成。向量存储的选择通常不受数据语言(英语或其他语言)的影响。为了全面了解存储和搜索选项,包括向量数据库,我推荐你探索现有资源,例如这个详细介绍:关于向量数据库及其如何增强你的 LLM 应用程序的全部知识。这个指南将为你提供有效管理 RAG 系统存储方面的必要见解。
到目前为止,你已经成功创建了作为检索系统“大脑”的知识库。
生成响应(图像由作者 w. Dall-E 3 生成)
6. 生成阶段:去其他地方阅读 😉
RAG 系统的第二部分,生成阶段,在确保解决方案成功方面同样重要。严格来说,这是一个搜索优化问题,上面加上了一些 LLM,考虑因素较少依赖语言。这意味着针对英语的检索优化指南通常也适用于其他语言,因此在此未包含。
在最简单的形式中,生成阶段涉及一个直接的过程:获取用户的问题,使用第 4 步中选择的嵌入模型进行嵌入,在新创建的数据库中执行向量相似度搜索,然后将相关的文本块提供给 LLM。这使得系统能够用自然语言回应查询。然而,要实现高性能的 RAG 系统,需要在检索方面进行若干调整,如重新排序、过滤等。有关更多见解,我建议你探索一些文章,例如 提升检索增强生成系统性能的 10 种方法 或 通过混合搜索改进 RAG 管道中的检索性能。
结语:评估你的 RAG 系统
正确的选择是什么?(图像由作者 w. Dall-E 3 生成)
那么你接下来该做什么?针对你的具体问题和语言,正确的配置是什么?
现在可能已经很清楚,决定 RAG 系统的最佳设置可能是一项复杂的任务,因为涉及的变量众多。定制的查询和上下文基准对于评估不同配置至关重要,特别是因为针对你特定的多语言数据集和用例的现有基准非常不可能存在。
幸运的是,凭借大型语言模型(LLMs),创建定制的基准数据集已变得可行。检索系统的基准通常包括搜索查询及其对应的上下文(我们在第 4 步中拆分的文本块)。如果你拥有原始数据,LLMs 可以自动生成与数据集相关的虚构查询。像 LlamaIndex 这样的工具提供了内置功能来实现这一目的。通过生成自定义查询,你可以系统地测试嵌入模型、块大小或数据格式的调整对你特定场景下检索性能的影响。
创建一个具有代表性的评估基准涉及许多注意事项,2024 年初我将跟进一篇关于如何创建一个表现良好的检索基准的单独文章——敬请期待!
感谢你抽出时间阅读这篇文章,希望你觉得这篇文章对你有所帮助。
如果内容对你有帮助,请记得点赞👏👏👏,如有问题或评论,请随时与我联系。
参考文献:
-
如何分块文本数据的比较分析
-
关于向量数据库及其如何增强你的 LLM 应用程序的所有信息
-
提高检索增强生成系统性能的 10 种方法
-
通过混合搜索提高 RAG 管道中的检索性能
人工智能如何用来预测和解释学生表现?
人工智能能否用来减少留级率并改善教育?
·发表于 Towards Data Science ·27 分钟阅读·2023 年 9 月 18 日
–
👋 介绍
机器学习(ML)通过使计算机能够从数据中学习和做出决策,已经显著改变了各个行业。从推荐电商网站上的产品到诊断医疗状况,机器学习和人工智能(AI)的应用范围广泛且深远。
机器学习在音频、图像和视频处理中的应用也非常有用。例如,人脸识别和图像质量改进,这些内容在我之前的两篇文章中有所介绍。
该领域的简介、应用及当前问题
towardsdatascience.com ## 图像超分辨率:当前研究状态概述
对流行技术和剩余挑战的回顾
towardsdatascience.com
最近,大型语言模型(LLMs)如 ChatGPT 展示了 ML 在自然语言处理(NLP)中的强大能力,这是一种使计算机具备理解文本和语音的能力,与人类相当——有时,甚至比某些人类更强的能力……
在解决问题时,不同的 AI 领域也可以结合在一起。例如,在之前的一篇文章中,我使用了机器学习和 ChatGPT 来预测一个移动游戏应用的用户是否会在接下来的几天内停止玩游戏(这被称为玩家流失问题):
使用低代码平台进行玩家流失预测的数据分析和模型训练
towardsdatascience.com
AI 和 ML 在教育领域也有许多应用,这无疑构成了我们日常生活许多方面的基础。毕竟,良好的教育确保了明天的这一代能够引领和推动人类的进步。
尽管其重要性,我常常觉得文章没有充分阐明 AI 和 ML 在教育中的应用,而是更侧重于引人注目和点击诱导的内容。这就是为什么本文将考虑这样一个用例,特别是 ML 如何用于识别可能因学术困难或其他因素如社会经济因素而面临重读风险的学生——并帮助确定其背后的原因。
鉴于全球许多学生即将返回——而一些学生已经返回——他们的学习桌前,这也许是一个相当及时的文章,可以提供关于如何利用 ML 和 AI 改善学生教育的见解。
这种预测能力的意义是巨大的。通过及早识别出有困难的学生,学校可以针对具体需求量身定制干预措施,从而可能避免学生重读一年。这不仅有利于个别学生,还可以使教育机构更有效地分配资源。
从这些模型中获得的见解还可以促进对影响学生表现的各种因素之间复杂相互作用的深入理解。
本文将介绍数据科学家的思考过程以及如何解读任何发现,如下所示:
-
🖥️ 配置
-
🗂️ 数据集
-
🔍 数据探索
-
🔮 一些预测的时间
-
💯 奖励:我们可以预测实际成绩吗?
-
⚠️ 伦理和教学关注
-
🗒️ 结论
-
📚 参考文献
🖥️ 配置
如上所述,之前的文章中使用了一个低代码机器学习平台进行流失预测 a previous article。在这篇文章中,我将使用一个不同的平台 Dataiku,该平台也旨在使数据科学和机器学习变得简单——特别是对那些可能不太熟悉像 Python 这样的编程语言的人。
提供了一个免费层,实际上提供了大多数功能,尽管它需要在你自己的计算机上本地安装(这可能是好事也可能是坏事,取决于你的看法)。
请注意,Dataiku Science Studio (DSS) 仅作为实验版本提供用于测试目的,支持 Windows,在撰写本文时尚未正式支持。完整的警告列表可以在 Dataiku 文档 这里 中找到。
当你启动 Dataiku DSS 时,你应该会看到类似这样的界面(但没有我正在处理的项目和条目):
Dataiku DSS 的主屏幕。图片由作者提供。
🗂️ 数据集
我们将使用的数据集基于 UCI 机器学习库中的一个名为‘学生表现’ [1] 的数据集,拥有 CC BY 4.0 许可证。许多用于演示机器学习工具和应用的经典数据集都来源于 UCI 库,因此这是一个值得与其他平台如 Kaggle 一起考虑的重要数据来源。
发现全球的数据集!
archive.ics.uci.edu](https://archive.ics.uci.edu/dataset/320/student+performance?source=post_page-----23580747e8b0--------------------------------)
这个数据集实际上用于 Dataiku 培训和认证的一个评估中,也可以通过点击 +New Project > DSS Tutorials > ML Practitioner > ML Practitioner Assessment 直接在 Dataiku 中找到。
数据集中包含关于四所学校(GP、LT、MS、RC)的信息,包括:
-
学生特征,如学习时间、空闲时间、缺勤次数。
-
地理和社会经济指标,如通勤时间、家庭互联网接入、父母婚姻状态。
我们的数据集中有许多特征,但主要关注的特征包括以下内容:
-
school
: 学生的学校名称(即 GP、LT、MS、RC 之一) -
failures
: 过去的班级失败次数 -
grade
: 学生的成绩(例如,来自考试或其他评估) -
repeated
:学生是否需要重复学年。这是我们希望预测的目标变量。因此,这是一项二元分类问题,其中“正类”是学生重复学年,“负类”是无需重复学年的学生。
所有特征的详细描述可以在上面链接的 UCI 网站上查看。
🔍 数据探索
在考虑对数据做任何处理之前,你应该首先理解数据——不仅要了解特征表示的内容,还要探索和分析它们的特性。
让我们首先查看对最重要特征的单变量分析:
我们数据集中四个重要变量的单变量分析。图片作者提供。
已经可以观察到很多内容:
-
school
:每所学校的样本数量明显不平衡,GP 和 RC 占样本的大多数(总数的 61%)。属于 LT 和 MS 的样本明显较少。 -
failures
:大多数学生(占总数的 79%)似乎没有过去的失败记录。其余的学生有 1 到 3 次失败记录。 -
grade
:一半的学生(50%)获得了 10.0、12.0 或 13.0 的成绩。 -
repeated
:这是我们希望预测的特征,理想情况下应该是平衡的,以帮助减少机器学习模型对某个类别的偏好。不幸的是,这个特征往往是不平衡的——结果显示有兴趣的类别是代表性不足的:仅 30%的数据代表学生重复学年的情况。这将使得进行高精度预测变得更困难,因此在设计 ML 模型时也需要记住这一点。
让我们超越单独考虑每个变量的范围,查看学校与成绩之间的互动:
关于‘学校’和‘成绩’特征的频率表和马赛克图。图片作者提供。
在上图中,我们可以看到每所学校的前五个最常见成绩的样本数量(其余成绩被归类为“其他”),这两者在下方的频率表和相应的马赛克图中展示。这些图表基本上展示了相同的信息,但方式不同——表格显示原始值,而马赛克图则以视觉形式展示相同的数字。
我们还可以使用直方图来显示每所学校的每个成绩的频率:
按‘学校’划分的‘成绩’直方图。学校 LT 具有最高的均值和中位数成绩,而 RC 的平均成绩最低。图片作者提供。
从这些可视化图中,很明显不同学校之间存在一些差异。例如,似乎 LT 学校的平均和中位数成绩最高,而 RC 学校的平均成绩最低。这让我们质疑学校与学生成绩之间是否存在某种关系。让我们使用卡方独立性检验来测试一下:
卡方独立性检验。图片由作者提供。
测试的假设是学校
和成绩
是独立的。然而,测试结果显示在给定的置信水平 0.05 下,应当拒绝这一假设。换句话说,学校
和成绩
不是独立的。
这也许并不令人惊讶,因为我们期望不同学校的教育水平不同。然而,我们已经客观地证明了这一点(至少对于这个特定的数据集而言)。
鉴于上述情况,发现不同学校之间重修学生数量的差异也并不令人惊讶:
根据学生是否需要重修一年的情况对学校进行的单变量分析。图片由作者提供。
从上面的图像中观察到的一点是,RC 学校有更高比例的学生需要重修一年。然而,这也是一个在查看统计数据时需要小心的例子——除了原始数字和百分比,我们还应该记住每所学校的总学生人数不同。
换句话说,学生人数较多的学校更有可能有更多的学生需要重修一年,仅仅因为有更多的学生。
举个简单的例子。假设学校 A 有 20 名学生需要重修一年,而学校 B 有 200 名学生需要重修一年。人们可能会说学校 B 比学校 A 差得多。然而,我们后来发现学校 A 有 100 名学生,而学校 B 有 1000 名学生。这意味着两所学校的重修率都是 20%。显然,总学生人数也需要考虑在内。
在我们的案例中,考虑到每所学校的总学生人数不会改变观察结果太多,但在其他应用中仍然需要考虑。
到目前为止,我们对一些最重要的特征和需要注意的问题有了相当好的了解。是时候训练一些机器学习模型并做出预测了!
🔮 预测时间
机器学习模型通过进入 Dataiku DSS 的“实验室”部分进行训练,在那里我们可以配置要训练的模型及其超参数和全局设置。
为了简化起见,所有默认设置均保留(主要是将使用默认超参数设置训练逻辑回归模型和随机森林模型),除了优化度量标准——如前所述,我们的目标特征非常不平衡。这意味着度量标准的选择至关重要,因为一个模型可能在准确率等指标上获得很高的值。
正如在 另一篇文章 中讨论的那样,预测玩家流失的原因非常简单——一个模型可能倾向于预测最主导的类别,以至于它在大多数情况下都是正确的——即使它错误地预测了属于另一类别的所有样本,因为这些样本数量不多,不会显著影响性能!你可以在 Baptiste Rocca 和 Jason Brownlee 的两篇精彩文章中阅读更多相关内容。
为了实现这一点,接收器操作特征曲线下面积 (AUC ROC) 是一个更好的选择,可以最小化类别不平衡的影响。其值范围从 0 到 1(后者为完美分数)。
模型训练后,获得了以下结果:
训练模型的结果。图像由作者提供。
天哪! 看到两个模型都获得了完美的 AUC 和 F1 分数 1.0,这可能是你的第一反应!
然而,我担心消息并不那么乐观——实际上 Dataiku DSS 询问“这可能是真的吗?” 不幸的是,任何模型在任何数据集上获得完美分数都是非常困难的,除非它是一个非常简单的、很可能是合成生成的数据集。因此,当你看到如此高的分数时,你应该立即对为什么会发生这种情况产生一些怀疑。
在这种情况下,这是因为一种叫做 数据泄露 的现象,即在训练集之外可能无法获得的信息——即在将模型应用于新未见过的数据时——被用来训练和创建模型。
我们可以通过查看随机森林模型中最重要的变量和逻辑回归模型中的顶级系数来获取一些线索。这两者都指示了在预测结果时哪些变量被赋予了最多的权重——即最重要性。
从上面的图像可以观察到,grade
排在两个模型列表的最上面。这意味着 grade
对于模型做出正确预测非常重要。
结果表明,repeated
列(我们试图预测的那一列)是基于grade
的!难怪模型表现良好——grade
包含了直接得出我们目标变量所需的所有信息。
如果你考虑一下,这确实有意义——要确定一个学生是否需要重复学年,你自然会查看他们的成绩;如果成绩太低,这意味着他们没有学到足够的东西,重复课程会对他们有帮助。因此,在实际操作中,我们会根据学生的成绩来决定是否需要重复学年。
然而,我们显然无法提前知道它们的成绩,这也意味着这个特征不应该用于训练我们的模型——我们只应使用在学年期间可以使用的特征。
让我们丢弃grade
特征,并重新训练我们的模型:
未使用‘grade’特征训练的模型结果。图片来源:作者。
啊,这样好多了!也许很奇怪,我们对模型表现得更差的情况感到满意,正如下面表格所示,总结了不同评估指标下模型的表现:
四个训练模型的结果。前两个模型使用了‘grade’特征,底部两个模型没有使用‘grade’特征。图片来源:作者。
如所观察到的,新模型通常比旧模型表现更差,但它们在现实世界中实际上是可用的。就 AUC 而言,最佳模型是随机森林,得分为 0.922。尽管在召回率方面不如逻辑回归,但在准确率和精度方面更好。
这意味着,考虑到所有应被预测为需要重复学年的学生(‘正类’),随机森林模型错误地预测了更多的学生(比逻辑回归更多)为不需要重复学年。然而,随机森林模型更高的精度也意味着它较少出现假阳性,即错误地预测学生需要重复学年,而实际上他们不需要。
我们也可以在混淆矩阵中清楚地观察到这种行为:
随机森林模型(左)和逻辑回归模型(右)的混淆矩阵。‘1’表示学生需要重复学年,而‘0’表示学生不需要重复学年。图片来源:作者。
除了上面提到的 Baptiste Rocca 的文章外,我还邀请你查看 Koo Ping Shung 的文章,以获取更多关于准确性、精确度、召回率和 F1 的信息。
最好的模型真的取决于学校的需求——谨慎些,标记那些可能不需要重读一年(从而倾向于逻辑回归模型),还是避免过多的错误预测及其可能带来的后果,如资源浪费和烦恼的家长(从而选择随机森林模型)?
我可能会倾向于谨慎一些,选择逻辑回归模型,因为我知道可能有些学生实际上不需要重读一年。然而,这将给我一个机会来关注这些学生,然后手动重新评估是否需要让学生重读一年。毕竟,大多数机器学习模型确实需要某种形式的人工干预,以充分发挥其能力。
但是,请记住,只有两种类型的模型被训练,几乎没有对其超参数进行优化。因此,预计可以获得更好的性能。
希望你能认识到理解模型的重要性,而不是仅仅依据一个性能指标盲目使用它们——通常情况下,每个模型都有其优缺点,选择哪个模型取决于应用场景和你的需求。这是可解释人工智能(XAI)的一部分,它在不断增加的 AI 应用数量下变得越来越相关——也越来越必要。
我们可以做几种其他分析来更好地理解我们的模型,其中一些可能是针对所用模型类型的。然而,我会专注于两个重要的方面,以简化问题(并简洁!)。
第一个是变量重要性,我们在调查数据泄漏时实际上已经提到过。你可能还记得,这告诉我们模型在进行预测时优先考虑哪些特征,并给予哪些特征最多的“权重”。除了理解模型如何工作,它还帮助我们理解什么因素实际上会影响学生是否需要重读一年。
让我们关注在 AUC 方面表现最好的模型,即随机森林模型:
随机森林模型的变量重要性。图片由作者提供。
如所示,最重要的变量(差距相当大)是‘学校是 RC’和‘失败次数’。这很有道理——我们之前观察到,学校 RC 的平均成绩较低,这意味着学生在取得好成绩方面遇到更多困难,并且面临更高的考试和评估失败风险。现在我们有一些更客观和明确的证据来支持我们最初的观察。
说到失败,将其列在首位也是有道理的,因为较高的失败次数表明学生在学习和理解课程内容方面确实遇到困难。
相比之下,其余变量的重要性相对较低,但当然它们仍然是相关的。第三和第四个变量也与学校有关,而第五个变量可能有些令人惊讶 — Medu
。这指的是母亲的教育水平,它似乎起到了一定的作用。
为了更好地理解这些特征如何实际影响预测,我们可以查看部分依赖图(PDP)。这些图是通过基本上冻结样本中所有特征的值,然后仅变化感兴趣的特征来计算的。
特征的变化(在保持其他特征固定的情况下)显示其与结果的关系,例如是否是线性或单调的。更技术性地说,部分依赖性表示特征对机器学习模型预测结果的平均边际效应[2]。更多细节可以在 Christoph Molnar 的指南中找到 这里。
PDP 是通过所有样本计算并取平均的。每个样本的计算结果生成个体条件期望(ICE)图,以便进行更细粒度的分析(按学生层面)。
让我们考虑一个(应该是)相当直接的特征——学习时间。在查看数据或结果之前,我们可以直观地预期,学生重复学年的机会会随着学习时间的增加而减少,对吗?让我们看看studytime
的 PDP 是否符合这一预期:
学习时间的 PDP。图像来源:作者。
显然,随着学习时间的增加,对结果的影响会减少。也就是说,学生重复学年的可能性确实因更高的学习时间而减少。如前所述,这可能是显而易见的,但拥有一些客观证据来进行合理性检查还是很好的。了解发生了什么总是很重要的,而不是盲目相信计算机的输出——记住,“垃圾进,垃圾出”!
然而,有两点需要注意:
-
首先,部分依赖性的值相对较低,这意味着学习时间对结果的影响可能并不大。这是可以预期的,因为
studytime
被赋予了相对较低的重要性。 -
其次,除了部分依赖性外,还展示了训练集中特征值的分布。我们应当对训练集中只有少量样本的特征值保持谨慎,因为我们实际上是在训练数据的区域外得出结论(外推)。
现在让我们看看Medu
:
母亲教育水平的 PDP。图片来源:作者。
尽管其效应的幅度略高于studytime
,但仍然相对较低。然而,仍然可以观察到,母亲的教育水平提高时,学生需要留级的可能性趋于减少。
当教育水平超过 3(中等教育)且等于 4(高等教育)时,会出现特别显著的下降。这背后的解释可能是,教育水平较高的母亲在孩子遇到困难时(例如做作业时)会更有能力提供帮助。
最后,让我们查看学校的 PDP:
学校的 PDP。图片来源:作者。
学校是模型中最重要的变量,因此在不同学校之间会看到一些较大的差异。特别是,学校 RC 往往增加学生需要留级的可能性,而学校 LT 则基本上相反。这些观察结果再次与我们之前的分析一致。
我们将通过检查模型在测试集上的表现来总结这一部分,即模型在训练过程中未见过的一组样本:
最佳分类模型在测试集上的表现。图片来源:作者。
结果与在验证集上计算的结果相似,表明模型在不同样本集之间的表现相当一致。尽管结果不是完美的,但鉴于我们几乎没有对使用的模型或其参数进行优化,结果仍然相当不错。值得注意的是,精确度和召回率比较平衡,准确率为 0.851。然而,考虑到数据集的不平衡,AUC 值为 0.915 更具参考意义。
让我们看看模型在不同学校之间的表现,通过计算一个新的特征叫做prediction_correct
,该特征简单地检查预测值是否与已知(真实)值匹配:
各学校正确(和错误)预测的分布。图片来源:作者。
看起来模型在各学校之间的表现非常相似,不过它似乎在学校 MS 的表现最好(96%的样本预测正确),在学校 RC 的表现最差(75%的样本预测正确)。
但如果我们想知道模型在哪里出现了问题呢?是在预测学生是否需要重复一年时,还是在预测学生不需要重复一年时?也就是说,模型的表现是否在预测类别之间以及不同学校之间有显著差异?
让我们进行子群体分析,从而计算各个学校的指标:
跨学校的子群体分析。图片由作者提供。
此外,让我们计算每个目标类别和每个学校的正确预测数量的直方图:
按学校划分的每个类别(学生是否重复一年)的正确预测数量。蓝色条形表示正确预测,橙色条形表示错误预测。图片由作者提供。
主要观察结果包括:
-
在各个学校中,样本数量较多的类别(多数类)预测得较好。这包括 GP、LT 和 MS 中的类别‘0’(学生不需要重复一年),以及 RC 中的类别‘1’(学生需要重复一年)。MS 的样本几乎被完全预测正确。
-
少数类(即学校 RC 的‘0’,其他学校的‘1’)的差异更为显著。对于 GP 和 LT,正确预测的样本数量几乎等于错误预测的样本数量。对于 MS,大多数样本被正确预测。
-
除了上面的(2),学校 RC 的错误预测样本数量高于其他学校(如前所述)。然而,也可以观察到,少数类(‘0’)的错误样本数量高于正确预测的数量。
学校 RC 的观察结果可能是由于(a)RC 的少数类样本比其他学校的少数类样本多(即 RC 的样本分布不同),以及(b)RC 与其他学校相比是一个‘异常值’,因为它有更多需要重复一年学生。这意味着模型可能已经学习了如何在其他学校(样本数量较多)上表现良好,而在 RC 上的表现则较差。因此,模型可能需要调整以更好地适应这个学校的样本。
💯 奖金:我们能预测实际的分数吗?
在上述所有情况下,我们训练了一个分类模型来预测学生是否需要重复一年。但如果我们想知道实际的分数呢?毕竟,分数也可以用来判断学生是否需要重复一年。
那么问题就变成了——预测分数是否更好,然后通过应用一个阈值将其转化为二元分类问题?
如前所述,学生是否需要重读一年直接基于成绩——如果成绩低于 10,则需要重读一年;否则则不需要。因此,我们也可以训练一个回归模型,然后创建一个新特征,将低于 10 的成绩转换为‘1’(学生需要重读一年),否则转换为‘0’。
所以,让我们尝试一下!我在 Dataiku 中快速训练了几个回归模型,结果如下:
回归模型的表现。图片来源:作者。
随机森林模型的表现远优于普通最小二乘(OLS)模型,即使它仍然不出色。然而,可以训练更多模型,我们也没有调优任何参数来提高性能,因此这些结果对于第一次尝试来说相当不错。
在上面用于分类模型的相同测试集上的结果如下:
最佳回归模型在测试集上的表现。图片来源:作者。
这些结果与上面的结果相似,验证了我们在训练过程中看到的结果与在实际未见数据上部署模型时可能会看到的结果是一致的。虽然结果并不出色,但也不算太差——例如,平均绝对误差表明预测成绩的误差平均为+/- 1.2。
现在,让我们创建一个新的二元特征,其中低于 10 的成绩表示学生需要重读一年,而成绩大于或等于 10 则不需要重读一年。这使得结果与我们之前处理的二元分类情况类似。让我们查看结果:
使用预测成绩来判断学生是否需要重读一年时,回归模型的表现。图片来源:作者。
应注意,有些结果不可用,例如 ROC。然而,这些需要类别概率,而我们显然没有,因为我们对预测的数值成绩使用了简单的阈值。
可以观察到,我们拥有的指标的表现整体上不如分类情况。你可以说这在意料之中——毕竟,直接进行分类不是比先进行回归再进行二值化更好么?
实际上,不应该有很大差异——毕竟,成绩与是否需要重读一年之间有非常简单明了的关系。实际上,如果我们考虑准确率,差异并不大(0.82 对比 0.85),而 F1 分数也相当接近(0.74 对比 0.77)。
由于我们没有训练多种模型,也没有进行任何形式的参数/超参数优化,因此我们预计分类和回归模型的性能会有所提升。
然而,我们可以说的是,可以同时执行这两个任务——即预测学生是否需要重复一年及其相关概率,或者预测成绩(然后可以用来判断学生是否需要重复一年)。
最后,让我们检查一下这种第二种方法是否对任何类型的预测或学校存在失败情况,类似于我们对分类模型所做的:
每个类别的正确预测数量(无论学生是否需要重复一年),按学校分类。图片由作者提供。
对分类模型可以做出类似的观察,尽管误分类样本的数量较高(鉴于上述讨论的较低性能指标值,这是可以预期的)。或许有趣的是,学校 RC 的类别‘0’的误分类样本数量现在略低于正确分类样本的数量。
⚠️ 伦理和教育关注
拥有可以预测学生是否需要重复一年或者预测他们成绩的机器学习模型当然很好,但我们真的可以信任它们吗?如果预测不准确,会有什么影响?我们是否愿意让机器学习做出这样关键的决策?
透明性
主要关注之一是预测模型使用的透明性。教育机构必须向教育工作者和学生提供有关这些模型如何运作、依赖什么数据以及如何做出预测的清晰信息。透明的沟通有助于建立信任,使个人更好地理解决策过程。
这也是为什么我们进行了大量分析,并试图理解模型是如何运作的,以及它如何使用我们的特征(并且这是否有意义)。虽然一些模型(特别是基于深度学习的模型)往往是‘黑箱’,内部运行情况不明,但我们发现仍有一些工具可以帮助揭示其内部运作情况。
偏见缓解与反馈循环
另一个关键方面是缓解数据和模型中固有的偏见。偏见可能会加剧不平等,因此积极识别和纠正这些偏见是至关重要的。定期审计模型的预测以检查偏见是确保公平性的必要措施。换句话说,我们应当持续监测模型的预测,以确保它们保持任何期望的特征和行为。
我们还应采取措施对抗已知的偏见来源。例如,如果目标特征不平衡(如本案例所示),我们可以使用不同的评估指标,如 ROC AUC,并调整决策阈值,以适应我们的要求,例如改善少数类别的性能或减少假阴性(错误预测学生不需要重读一年)。
此外,输入特征的分布可能随时间变化,从而导致预测性能下降,这个问题被称为数据漂移。因此,模型可能需要使用新数据进行重新训练,以确保其保持相关性和令人满意的性能。
拥有这样的反馈循环对于持续改进至关重要。通过收集预测结果的数据,机构不仅可以优化其模型,还能解决任何意外问题,并随着时间的推移提高其准确性和公平性。
人工监督
如前所述,人工监督仍然不可或缺,以防止过度依赖自动预测。这不仅适用于数据分析、模型选择和所选模型的运行情况,也适用于任何决策。
确实,教育决策——如是否重读一年或提供额外支持——应包括能考虑更广泛背景的教育者,包括学生的整体进展、潜力和个人情况。这可以通过使用如 Shapley 值或个体条件期望(ICE)图等工具来促进:
特定学生的三个最强特征的 Shapley 值。失败次数(3)、学校(RC)以及学生不愿意继续接受高等教育,都导致了该学生极有可能会重读学年。
归根结底,教育者与学生在日常互动中建立的关系是机器无法复制的,这要求对模型输出进行人工分析。在如此敏感的应用中,ML 应作为助手或顾问,而不是完全的决策者。
提升 ML 模型性能
还有几种方法可以进一步改进并帮助最大化任何训练过的 ML 模型的性能。例如,收集更全面和准确的数据至关重要。机构必须投资于数据收集系统,这些系统可以捕捉学生表现的全景视图,包括学术和非学术因素。数据越多,ML 模型预测任何期望结果的能力就越强。
对更复杂的算法(及其超参数)和特征工程技术的实验也是提高模型预测能力所必需的。如在之前的文章中观察到,特征工程可以显著提升模型性能。这也是为什么了解数据并进行彻底探索至关重要,因为这将使我们能够确定可以生成哪些新特征。
验证
为了评估模型的泛化能力,进行稳健的验证是至关重要的,它确保模型在不同的人群和教育背景下表现一致。在我们的案例中,对两个不同数据集(从同一数据集中提取)进行评估表明,两个案例中的表现非常相似。另一方面,一个过拟合训练集的模型在其他数据集上的表现会大幅下降,这会促使采取行动(例如,使用更多的数据和更简单的模型)。
长期影响评估
进行长期研究以评估在教育环境中使用预测模型的影响至关重要。这些研究可以帮助我们了解这些工具如何影响学生的成果,以及它们是否真正有助于改善教育成果。它们还可以帮助我们确定什么样的性能水平是可以接受的,以增加对机器学习模型的依赖。例如,可能会确定只有当 AUC ROC 超过 0.95 时,才会考虑机器学习模型。
🗒️结论
这篇文章考虑了数据科学、机器学习、人工智能与教育的融合。具体来说,我们看到如何应用数据科学和机器学习来预测学生表现。这是通过两种方式完成的:
-
使用分类预测学生是否需要重复学年
-
使用回归预测学生的成绩
我们还看到如何将这两种方法结合起来,首先预测成绩,然后基于成绩应用阈值来确定哪些学生需要重复一年。
在性能方面没有明确的胜者,因为我们没有进行任何优化(无论是分类情况还是回归情况)。然而,这确实展示了预测和分析执行方式的灵活性,同时还需要考虑学校的具体要求。
我们还进行了相当多的分析——无论是在训练机器学习模型之前,还是之后。这是为了理解数据和训练后的模型。特别是,我们考虑了多个指标来确定每个模型的优点(和缺点),以帮助我们确定哪个模型最适合我们的需求,并观察哪些特征在预测结果中最为重要。这也帮助教育者理解影响学生表现的最重要因素,并采取措施最大化学生学习效果。
当然,还可以进行更多的分析。然而,这个案例研究中所做的分析已经提供了很多有用的见解。此外,还可以进行其他类型的机器学习方法,例如 因果推断(建模处理变量有无效应的结果差异,其它条件相同)。
希望这能体现数据科学家需要处理的信息的广度。
整篇文章中使用了 Dataiku DSS 软件,具有以下优缺点:
✔️ 优点:
-
功能非常丰富(而且我们甚至没有探索它的所有功能)。
-
界面友好且响应迅速,这意味着那些对 Python 等编程语言不太自信的人也能很容易地执行数据科学和机器学习任务。
-
可以在本地计算机上运行,确保您的数据隐私得到保护。
-
大多数功能是完全免费的(那些不免费的主要是面向企业用户的,例如支持的用户数量和数据库连接类型)。
❌ 缺点:
-
大量的功能也可能使平台使用起来有些压倒性。然而,提供了教程来介绍新用户使用平台及其几个数据科学方面的知识。
-
Dataiku DSS 目前尚未在 Windows 机器上正式支持,这令人遗憾,因为这种操作系统非常普遍。
总的来说,我认为这个平台对数据科学家非常有用,无论他们是否熟悉编程。对于那些了解 Python 或 R 等语言的人来说,Dataiku DSS 可能会使某些任务变得更简单、更快捷。自定义函数也可以直接在平台内编写,增强了平台的灵活性和执行更高级任务的能力。
希望这篇文章让您读得很有趣!请随时留下任何反馈或问题,并记得关注 我,并注册接收电子邮件更新 以确保您在未来文章发布时会收到通知。
📚 参考文献
[1] P. Cortez 和 A. Silva. “利用数据挖掘预测中学学生表现。” 收录于 A. Brito 和 J. Teixeira 编,《第五届未来商业技术会议(FUBUTEC 2008)论文集》第 5–12 页,葡萄牙波尔图,2008 年 4 月,EUROSIS,ISBN 978–9077381–39–7。
[2] J.H. Friedman, “贪婪函数逼近:一种梯度提升机器。”《统计年鉴》(2001):1189–1232。
图片来源:Pixabay
📝 对这篇文章有什么想法吗?请随时在 LinkedIn 上发布留言、评论或直接给我发消息!
📧 请务必 关注 我的 Medium 个人资料,并 注册接收电子邮件更新 以确保您在未来文章发布时会收到通知。
[## Medium
在 Medium 上关注我!
medium.com](https://medium.com/m/signin?actionUrl=https%3A%2F%2Fmedium.com%2F_%2Fsubscribe%2Fuser%2Fa9be78db0c9b&operation=register&redirect=https%3A%2F%2Ftowardsdatascience.com%2Fplayer-churn-rate-prediction-data-analysis-and-visualisation-part-1-12a9fdff9c10&user=Christian+Galea&userId=a9be78db0c9b&source=post_page-----23580747e8b0--------------------------------)
📚 查看我的 其他 Medium 文章:
阅读 Christian Galea 在 Medium 上的文章。博士后研究员,计算机视觉与机器学习(特别是……)
超越 LLaMA:开源 LLMs 的力量
原文:
towardsdatascience.com/beyond-llama-the-power-of-open-llms-cef807a54a4f
LLaMA 如何让开源再次变得酷炫
·发表于 Towards Data Science ·18 min read·2023 年 7 月 18 日
–
(图片由 Paz Arando 提供,来源于 Unsplash)
尽管大型语言模型(LLMs)近期取得了进展,但许多最强大的模型仍然只能通过 付费 API 访问,并且使用大量的 专有数据 进行训练,从而限制了研究社区对这些模型的访问或复制。这一趋势引发了严重的担忧,即 LLMs 是否将主要由少数几个集中化的组织控制,这些组织迫使他人支付费用以与这些模型互动。这种情况严格阻止了大多数研究人员直接访问或自行改进 LLMs。
“[许多] LLMs 需要巨大的计算资源进行训练,而且通常使用大型且专有的数据集。这表明未来,高能力的 LLMs 将主要由少数几个组织控制。” — 摘自 [5]
鉴于训练和托管大型语言模型(LLMs)的计算负担,我们可能会质疑开源这些模型对研究社区是否真的有帮助。如果我们不是拥有大量计算资源的大型组织的一部分,我们甚至能用 LLMs 进行有意义的研究吗? 如果不能,也许我们注定要面对一个中央控制和访问 LLMs 的世界。这些模型似乎具有过强的“引力”(即需要大量的数据和计算资源),让大多数人很难轻松使用它们。
LLaMA 的提议(以及随后向公众泄露)通过开源一套强大的(但较小的)LLM,走向了相反的方向。在 LLaMA 向公众发布之后,我们见证了一波大规模的 LLM 开放研究。这些研究产生了各种不同的模型,其中一些与 ChatGPT 的质量相当。然而,最显著的是,这些模型的生产成本极低(即,大多数情况下低于 $500)且计算资源 modest(即,部分模型可在普通 macbook 上运行!)。在这里,我们将调查一些最近提出的后 LLaMA 模型,并探索开源 LLM 研究如何使这一主题变得更易接触。
(来自 [3, 4, 5])
核心概念
在 之前的一篇文章中,我们了解了 LLaMA,这是一套开源的高性能 LLM,具有多种规模。LLaMA 模型仅在公共数据上进行训练,使其与开源兼容,并且无需访问专有数据即可重复生成。然而,LLaMA 的故事并未止步于此!这些模型最近已成为深度学习的热门话题。在本概述中,我们将探讨 LLaMA 使研究得以进行的原因,并了解这些模型为何以及如何变得流行。首先,我们将提供更多有关 LLaMA 的背景信息,然后概述本概述所需理解的重要思想。
LLaMA 是如何(或未能)开源的……
深度学习社区已经接受了开源一段时间,某些研究领域仍然如此(例如,请参见 Stable Diffusion)。然而,LLM 领域却大相径庭,因为最受欢迎/强大的模型仅通过付费 API 提供(例如,GPT-4 [6]、Claude 和 Cohere)。LLaMA [1] 的开源,即一套质量卓越的较小 LLM 基础模型,打破了这一趋势。然而,LLaMA 并没有 完全 开源……故事要复杂一些。
首先,LLaMA 被 Meta 公布,详细信息包括深入、有用的 出版物、申请访问 LLaMA 的表单以及一个 简单的仓库,在获得模型访问权限后可用于运行推理和标记化。为了获得模型访问权限,必须同意一系列要求,例如不将 LLaMA 用于商业目的,并确保用 LLaMA 创建的任何 衍生模型 遵循相同的许可证。但这些要求都被抛到一边了,因为在发布大约一周后,所有 LLaMA 模型的权重被公开发布到 4chan,任何人都可以下载。
尽管 LLaMA 的共享方式出乎意料(且可以说是有害的),但它引发了数千次下载,并且随后促进了大量的开放研究。鉴于 LLaMA 由更小的模型组成,这些模型对于没有大量计算资源的研究人员来说更为可及,这些模型非常适合这种情况。在几周内,大量令人惊叹的深度学习研究人员投入工作,利用 LLaMA 开展了各种项目,从在 Macbook 上托管多亿参数的 LLM 到用不到 $500 复现 ChatGPT。
指令微调
(来自 [10])
我们在本概述中看到的许多模型都是基于指令微调(或简称为指令调优)的思想。指令微调最初由 FLAN [10] 提出,它是一种训练形式,使语言模型在解决语言相关任务方面表现更好,而不仅仅是单一任务;见上文。实际上,这通过在一组“指令”上对语言模型进行微调来实现,这些指令包括与任务描述结合的微调示例。通过这种方法,我们可以通过使用不同的任务模板进行文本提示,微调语言模型以解决各种不同的任务;见下文。
(来自 [10])
当前,指令微调的一个最受欢迎的变体是通过对话示例对 LLM 进行微调,这些示例可以来自人类或由聊天机器人生成。鉴于许多最近的聊天机器人专门用于遵循指令并执行信息寻求对话,这些模型、它们的输出,甚至用于训练它们的数据都包含丰富的指令跟随示例和行为,可以直接用于指令微调。
(来自 [2])
自我指导。 与本工作相关的一种指令调整形式是自我指导框架[2],它通过生成用于微调的指令来减少对人工编写指令的依赖。特别地,这一过程从一小部分指令数据开始,并迭代地*(i)* 使用 LLM 生成新数据和*(ii)* 过滤低质量数据;见上文。这种技术能以最少的人力注释工作生成高质量的指令调整数据。
知识蒸馏
(来源于 [12])
在[11]中提出,知识蒸馏使用一个(大型)完全训练好的神经网络作为另一个(小型)神经网络的训练信号;见上文。虽然存在许多不同类型的知识蒸馏,但其背后的理念保持不变。即,如果我们使用*(i)* 普通训练数据和*(ii)* 一个更大、更强大的神经网络对这些数据的输出来训练一个神经网络,那么通常会比仅使用数据来训练神经网络得到更好的结果。通过将其输出作为训练目标,我们可以将一些信息从更大的网络“蒸馏”到正在训练的小型“学生”网络中。有关知识蒸馏及其众多变体的更多信息,请查看这里的链接。
其他内容……
除了以上涵盖的信息,我们还需要对大型语言模型(LLMs)及其工作原理有一个基础的理解。要了解这些知识,请查看以下资源。
在概述中,我们还会提到 OpenAI 目录中的一些具体模型的名称(例如,text-davinci-003
)。查看这里可以找到提供的模型列表(及其相关描述),这些模型包含在OpenAI API中。
Alpaca: 一个指令跟随的 LLaMA 模型 [3]
“在学术界研究指令跟随模型一直很困难,因为没有一个容易获得的模型在能力上接近于封闭源模型,如 OpenAI 的 text-davinci-003。” — 来源于 [3]
Alpaca [3] 是 LLaMA-7B [1] LLM 的一个微调版本,其性能类似于 OpenAI 的text-davinci-003
(即,GPT-3.5)。Alpaca 的微调过程基于 self-instruct [2],其中从表现更好的 LLM(即text-davinci-003
)收集指令跟随数据,并用于 SFT。简而言之,Alpaca 表明,在指令跟随背景中,通过高质量数据的微调可以显著提高小型开源 LLM 的质量。此外,整个 Alpaca 的微调过程费用仅为$600(包括数据收集和微调),使得这种指令跟随 LLM 易于且便宜地复制用于研究目的。
创建 Alpaca LLM(来自[3])
方法。 要通过 SFT 创建一个指令跟随的 LLM,我们需要 i) 一个高质量的预训练语言模型和 ii) 用于 SFT 的指令跟随数据。幸运的是,最近发布的 LLaMA 提供了易于访问的预训练语言模型。获得指令跟随数据要复杂一些,但一种方法是使用 self-instruct [2]。从高层次来看,self-instruct bootstraps LLM 生成的输出进行进一步训练。在 Alpaca 的案例中,我们使用text-davinci-003
通过以下方式生成指令跟随数据:
-
从self-instruct 的种子集开始,使用 175 个指令和输出对。
-
提示 LLM 生成更多指令,使用种子集作为上下文示例进行少量学习。
[3]的作者也采用了一些技巧(例如,修改过的提示和更高效的解码/生成过程),使数据生成过程比原始 self-instruct [2]更便宜、更高效。总体而言,通过 OpenAI API 生成指令跟随数据的费用不到$500,用于 52K 个指令跟随示例。
LLaMA-7B 模型随后使用基于 HuggingFace 的训练框架在这些数据上进行微调。通过使用完全分片的数据并行 (FSDP)和混合精度训练技术,微调过程在 8 个 A100 GPU 上缩短至 3 小时,成本低于 $100。用于创建 Alpaca 的代码/数据在线获取。然而,Alpaca 的商业使用被禁止,因为 i) LLaMA(Alpaca 基于的模型)具有非商业许可证,ii) OpenAI 禁止 使用其模型来训练竞争的 LLM。
结果。 Alpaca 在用于 self-instruct 的评估集上的指令(即,大多与电子邮件、社交媒体和生产力相关的任务)和由作者手工编写的开放领域指令上进行评估。在这些任务中,Alpaca 的表现类似于 text-davinci-003
(即,在测试的约 180 个案例中,表现最佳的占 50%)。尽管这种评估显然范围有限,考虑到 Alpaca 比 GPT-3.5 小得多且相对容易复制,其性能仍然非常令人印象深刻。
Alpaca 输出示例(来自 [3])
类似于 text-davinci-003
,Alpaca 的输出通常比 ChatGPT 的要短。换句话说,模型的风格反映了用于生成指令跟随数据的 LLM 的风格。
Vicuna: 一种具有 90% ChatGPT 质量的开源聊天机器人 [4]
(来自 [4])
信息检索对话代理(或聊天机器人)如 ChatGPT 很出色,但此类模型的训练框架和架构未知,这阻碍了开源研究。作为解决方案,[4] 的作者提出了 Vicuna,一种通过微调 LLaMA — 13B [1](即,一个与 GPT-3 性能相当的小型 LLM)创建的开源聊天机器人。Vicuna 的微调数据是与 ChatGPT 进行的用户对话示例,整个微调过程可以以不到 $300 的成本复制,从而使聊天机器人在研究中更加可及。与 Alpaca 相比,Vicuna 更加接近 ChatGPT,生成的答案更具细节和结构。
(来自 [4])
方法。 用于 Vicuna 的 SFT 数据通过ShareGPT的公共 API 下载,该平台允许用户分享与 ChatGPT 的对话。在微调之前,作者会过滤不适当和低质量的数据,并将较长的对话分割成适合 LLaMA-13B 最大上下文长度的较短片段。总共收集了 70K 个对话。类似于 Alpaca,该模型在 8 个 A100 GPU 上使用 FSDP(经过一些修改以降低成本和处理长序列)进行训练,约需一天时间;见上文。作者公开了代码,用于训练和托管 Vicuna。下表提供了 Vicuna 与开源 LLM LLaMA 和 Alpaca 的更全面的比较。我们将接下来讨论 Vicuna 的评估方法。
(摘自 [4])
结果。 准确评估聊天机器人非常困难,随着聊天机器人质量的提高,这种困难会加剧。例如,[4]的作者声称,自我指导评估集(用于评估 Alpaca)已被近期聊天机器人有效解决,这使得模型之间的差异难以分辨。鉴于现有基准的局限性和创建新的全面评估集的难度,[4]的作者选择了另一种策略:使用 LLMs 进行评估。
“随着 GPT-4 的最新进展,我们很好奇其能力是否已经达到类似人类的水平,这种水平是否可以支持一个自动化的评估框架用于基准生成和性能评估。” — 摘自 [4]
此时,我们可能会认为这实际上不可能奏效。聊天自指? 然而,令人惊讶的是,基于最近提出的GPT-4 模型 [6]形成的评估框架效果良好。首先,[4]的作者设计了八类问题(例如,角色扮演场景和数学任务)。然后,GPT-4 被提示在每个类别中生成多样化的问题。有趣的是,GPT-4 被发现能够生成近期聊天机器人难以回答的难题。
特别是,GPT-4 用于在每个类别中生成十个问题,并评估五种不同聊天机器人的输出(即 LLaMA-13B、Alpaca-13B、Vicuna-13B、Bard 和 ChatGPT)。进一步说,每个模型输出的质量通过要求 GPT-4 根据详细程度、帮助性、相关性和准确性对答案质量进行评分来判断。虽然以这种方式进行评估可能看起来有些牵强,但 GPT-4 对模型的排名相当一致,甚至解释了其推理过程。
(摘自 [4])
根据 GPT-4 的判断,Vicuna 的输出质量相对于 ChatGPT 为 92%;见上文。这个比例是通过让 GPT-4 为每个模型的输出分配分数来实现的。然后,通过计算所有问题的总质量分数来评估模型之间的相对表现。尽管这种评估方法并不严格,但它相当有趣、相对一致,并迫使我们思考 LLM 领域未来会如何演变。
(来自 [4])
与其他开源模型相比,我们看到 GPT-4 更倾向于 Vicuna 的输出。此外,Vicuna 在 45% 的问题上产生的输出质量超过或匹配 ChatGPT。这种质量水平对于一个只需 $300 即可微调的模型来说相当令人印象深刻!
Koala: 一个用于学术研究的对话模型 [5]
“足够小到可以在本地运行的模型,如果经过精心挑选的数据训练,可以捕捉到其较大同行的大部分性能。”— 来自 [5]
在这一点上,我们可能开始怀疑是否会用尽用于为 LLM 命名的动物。尽管如此,Koala 与 Vicuna 和 Alpaca 类似,因为它继续致力于缩小专有和开源 LLM 之间的质量差距。更具体地说,Koala 是 LLaMA-13B 的一个版本,经过在各种来源的对话数据上进行微调,从公共数据集到与互联网上其他高质量 LLM 的对话。
Koala 与相关 LLM 的比较(来自 [5])
在真实世界的提示上进行评估时,Koala-13B 被发现相较于 ChatGPT 表现出具有竞争力的性能,甚至超过了相关的 Alpaca 模型。因此,Koala 的结果继续支持我们在所有 LLaMA 后续工作中看到的趋势。即,我们看到较小的模型在获得正确的数据进行微调后可以取得令人印象深刻的质量。这样的发现可能会让我们想知道:我们是否过于关注模型规模,而对数据质量关注不够?
方法。 Koala 使用来自公共数据集和互联网的对话数据进行微调。然而,[5]中的作者强调了为微调策划高质量数据集的重要性。用于微调 Koala 的数据大致可以分为蒸馏基础(即,来自其他 LLM 的对话)或开源数据(即,公开数据集中可用)两类,包括来自ShareGPT、HC3、OIG、Anthropic HH和 OpenAI WebGPT/Summarization的数据。此外,微调集甚至包括用于训练 Alpaca [3]模型的数据。
(见[8])
所有这些数据都是基于对话的。然而,需要注意的是,一些数据集包含多个对话或对每个问题的响应,这些响应被评为好或坏。有趣的是,我们可以借鉴先前的技术[8],将这些信息纳入 LLM 的微调过程中。特别地,这是通过条件训练来完成的,我们可以简单地将数据条件化,通过人类偏好标记来训练 LLM(例如,只需附加有关对话是否好的文本信息);见上文。这种方法可以提高性能,并使我们能够使用即使是低质量的对话进行模型训练。
[5]中的作者使 Koala 的训练和托管框架公开可用。该模型使用八个 V100 GPU 训练两个时期,耗时约 6 小时。总的来说,训练该模型的计算成本低于$100(假设我们可以使用可抢占/临时实例),这意味着 Koala 是迄今为止我们见过的模型中最便宜的再现模型!
结果。 [5]中的作者训练了两种不同类型的 Koala 模型:
-
Koala-distill:仅在蒸馏数据上进行微调(即,来自其他聊天机器人的对话示例)
-
Koala-all:使用上述所有数据进行微调。
根据人类试验和反馈,这些 Koala 模型的质量与 Alpaca 和 ChatGPT 进行比较。评估中使用了来自 Alpaca [3]评估集的问题和来自互联网的真实用户查询集。作者选择增加更多问题到评估集中,因为 Alpaca 的评估集与其训练数据非常相似(即,两者均源自self-instruct [2])。
(见[5])
当人们在质量和正确性方面评估不同 LLM 的输出时,发现 Koala-all 通常超越了 Alpaca 的表现,并在许多情况下达到或超过了 ChatGPT 的质量。此外,我们看到 Koala-distill 实际上表现优于 Koala-all。这有点违反直觉,因为 Koala-distill 的微调数据集较小(即仅包含来自 ChatGPT 的示例对话),但这告诉我们,微调所用数据的类型和质量极为重要。也就是说,使用来自更大、更好的 LLM 生成的对话进行微调是非常有效的。
“构建强大对话模型的关键可能在于策划高质量、多样化的用户查询对话数据”— 来自[5]
进一步探索…
尽管 LLaMA 提出的时间相对较短,但 Alpaca、Vicuna 和 Koala 并不是唯一受到 LLaMA 启发或支持的显著模型。以下是最近发布的其他开源语言模型的列表。
-
Lit-LLaMA: 一个基于 LLaMA 的开源复现项目,遵循Apache-2.0 许可证(允许商业使用)。
-
ChatLLaMA: 使用 LLaMA、你自己的数据以及尽可能少的计算资源来制作个性化版本的 ChatGPT。
-
FreedomGPT: 一个开源的对话型聊天机器人(基于 Alpaca),强调没有审查。
-
ColossalChat: 一个开源的 ChatGPT 复制品,配备了一个完全实现的(且公开的)RLHF 管道,基于 LLaMA(包括数据收集、监督微调、奖励模型训练和强化学习微调;详见下文)。
-
StackLLaMA: 提供了一个基于 RLHF 的微调开源实现和讨论,用于生成强大的聊天机器人(具体使用 LLaMA 作为起点)。
-
GPT4All: 用于训练基于 LLaMA 和 GPT-J 的开源 LLM 的演示、数据和代码(拥有 Apache-2.0 许可证!)。
-
Dolly 2.0: 该模型不基于 LLaMA,但是一款开源聊天机器人,经过指令微调以达到类似 ChatGPT 的质量,并开放商业使用。
-
Open Assistant: 一个开源聊天机器人(与 ChatGPT 相当),能够理解任务、与第三方系统互动并检索信息。
(来自 [9])
除了提出的各种模型,LLM 的研究和使用也因为 LLaMA 而变得更加可及。LLaMA-13B 已经可以仅用一个 GPU 运行,但现在我们甚至可以在本地(例如,在 macbook 上)完成这个操作!
-
Alpaca.cpp: 本地运行 Alpaca 的开源复刻版本。
-
GPTQ-4-LLaMA: 一个 4-bit 量化 的 LLaMA 版本。
-
LLaMA.cpp: 几个开源 LLM 的 4-bit 量化推理,这使得本地托管成为可能(例如,在 macbook 上)。
看起来 LLMs 很快将比以往更多地向公众开放。
要点
我们可以从这项工作中推断出的主要观点是 i) LLaMA 激发了大量开源 LLM 研究和 ii) 围绕 LLM 的研究/使用因为 LLaMA 而变得显著更为可及。如果一个月前你告诉我,我可以在我的 macbook 上运行接近 ChatGPT 性能的 LLM,我是不会相信的。这是令人兴奋的时刻,我很感激能成为这样一个了不起的社区中的一员!以下列出了几个基本要点。
LLMs 适合所有人。 如果之前我们对此有所质疑,现在我们知道研究社区确实可以在 LLMs 上进行有价值的研究。几周前,我们大多数人认为由于极高的数据和计算需求,LLMs 并不容易获得。然而,现在我们可以用几百美元训练出 ChatGPT 级别的模型(或至少接近的模型),甚至可以在我们的笔记本电脑上使用这些模型进行对话!
较小的模型是否足够? 长期以来,模型规模一直是高性能 LLM 的一个重要组成部分(连同大规模的预训练数据集)。然而,像 Koala 和 Vicuna 这样的模型告诉我们,较小的 LLM 实际上可以表现得非常出色(甚至在某些情况下与强大的 LLM 如 ChatGPT 的表现相匹配)。这样的发现突显了数据质量的重要性。在我们看到的工作中,最有效的技术往往使用较大 LLM 的输出作为训练数据,这表明知识蒸馏可能是创建小而强大的 LLM 的重要组成部分。
商业上可行? 尽管这些技术都很酷,但在商业应用中使用它们却很困难。例如,OpenAI 禁止使用 ChatGPT(或任何其他 API 模型)来训练竞争模型,从而阻止了基于 OpenAI API 的知识蒸馏方法。此外,即便是 LLaMA 本身也禁止商业使用。因此,像 Alpaca、Koala 和 Vicuna 这样的模型仅在研究层面上具有兴趣,它们的方法不能用于任何商业用途的模型。然而,随着像 Lit-LLaMA 这样的提案出现,这些模型的商业可行版本可能会逐渐出现。
结语
非常感谢您阅读本文。我是 Cameron R. Wolfe,Rebuy 的 AI 总监。我研究深度学习的经验和理论基础。您还可以查看我在 medium 上的 其他文章!如果您喜欢这篇文章,请关注我的 twitter 或订阅我的 Deep (Learning) Focus 新闻通讯,在其中我通过对流行论文的易懂概述帮助读者深入理解 AI 研究中的主题。
参考文献
[1] Touvron, Hugo 等人。“Llama:开放且高效的基础语言模型。” arXiv 预印本 arXiv:2302.13971 (2023)。
[2] Wang, Yizhong 等人。“Self-Instruct:将语言模型与自生成的指令对齐。” arXiv 预印本 arXiv:2212.10560 (2022)。
[3] Taori, Rohan 等人。“斯坦福 Alpaca:一个遵循指令的 LLaMA 模型。” (2023)。
[4] Chiang, Wei-Lin 等人。“Vicuna:一个开源聊天机器人,令人印象深刻的 GPT-4 质量达到 90%* ChatGPT。” (2023)。
[5] Geng, Xinyang 等人。“Koala:一个用于学术研究的对话模型。” (2023)。
[6] OpenAI (2023)。“GPT-4 技术报告。” ArXiv, abs/2303.08774。
[7] Guo, Biyang 等人。“ChatGPT 与人类专家有多接近?比较语料库、评估和检测。” arXiv 预印本 arXiv:2301.07597 (2023)。
[8] Liu, Hao 等人。“事后链将语言模型与反馈对齐。” arXiv 预印本 arXiv:2302.02676 (2023)
[9] Ouyang, Long 等人。“通过人类反馈训练语言模型以遵循指令。” 神经信息处理系统进展 35 (2022):27730–27744。
[10] Wei, Jason 等人。“微调语言模型是零样本学习者。” arXiv 预印本 arXiv:2109.01652 (2021)。
[11] Hinton, Geoffrey, Oriol Vinyals 和 Jeff Dean。“在神经网络中蒸馏知识。” arXiv 预印本 arXiv:1503.02531 (2015)。
[12] Gou, Jianping 等人。“知识蒸馏:综述。” 国际计算机视觉期刊 129 (2021):1789–1819。
超越 NeRF(第一部分)
原文:
towardsdatascience.com/beyond-nerfs-part-one-7e84eae816d8
提高 NeRF 训练速度 100 倍或更多……
·发表在 Towards Data Science ·15 分钟阅读·2023 年 6 月 7 日
–
(照片由 Mathew Schwartz 提供,来源于 Unsplash)
正如我们在之前的概述中所见,神经辐射场(NeRFs) [4] 的提出在神经场景表示领域是一个突破。给定一些底层场景的图像,我们可以训练一个 NeRF 以高分辨率生成该场景的任意视角。简而言之,NeRF 利用深度学习提供 3D 场景的摄影级渲染。
但,它们有一些显著的问题。在本概述中,我们将特别关注 NeRF 的两个局限性:
-
训练一个可以准确渲染新视角的 NeRF 需要大量的场景图像。
-
使用 NeRF 进行训练(和渲染)是很慢的。
作为解决这些问题的方案,我们将概述 NeRF 方法的两个显著扩展:PixelNeRF [1] 和 InstantNGP [2]。在学习这些方法的过程中,我们会看到,NeRF 所面临的大部分问题可以通过制作更高质量的输入数据以及利用深度神经网络将已学模式推广到新场景的能力来解决。
(来自 [1, 2])
背景
我们最近了解到许多不同的使用深度学习建模 3D 形状和场景的方法。这些概述包含了一些背景概念,这些概念也将有助于理解本概述中的概念:
除了这些概念外,在这个概述中,理解 NeRFs [4]也会非常有用。为了建立这种理解,我建议阅读我对 NeRFs 的概述这里。
特征金字塔
在这篇文章中,我们将看到多个实例,展示如何使用深度神经网络将图像转换为相应的(金字塔)特征表示。但有些人可能对这个概念不太熟悉。因此,我们需要快速了解特征表示,并概述我们在深度学习中可能遇到的一些不同变体。
什么是特征? 在了解特征金字塔之前,我们需要理解“特征”一词的含义。通常,神经网络的输出会是分类、边界框集合、分割掩码或其他类似的东西。例如,在图像分类中,我们将图像作为输入,通过神经网络传递,网络的最后一层是一个分类模块,将隐藏状态转换为类别概率向量。简单明了!
从深度神经网络中提取特征(由作者创建)
然而,有时我们不希望执行最后一步。相反,我们可以直接取网络的最终隐藏状态(在分类模块之前),并将这个向量作为数据的表示;见上文。这个向量,也称为特征(或特征表示),是数据中语义信息的压缩表示,我们可以用它来执行各种任务(例如,相似性搜索)。
什么是特征金字塔? 多尺度(或“金字塔”)策略是计算机视觉中的一个重要基本概念。基本思想很简单。在神经网络的层次中,我们偶尔*(i)* 降采样特征的空间分辨率,并*(ii)* 增加通道维度。例如,见ResNet-18 [6]的示意图。这个 CNN 包含四个“部分”,每个部分的通道维度逐渐增高,空间维度逐渐降低;见下文。
ResNet 架构中“部分”的基本示意图(由作者创建)
从这个网络中提取特征的一种方法是仅使用最终的隐藏表示。但是,与网络早期层相比,这种表示不包含太多的空间信息(即,空间维度在每一层中逐渐降低!)。这对于依赖图像中空间信息的密集预测任务(例如目标检测)是一个问题!为了解决这个问题,我们需要构建一个特征金字塔[3]。
简而言之,特征金字塔从网络中的几个不同层中提取特征,而不是仅使用网络最终层的特征;见下文。
(来自[3])
得到的特征集包含不同量的空间和语义信息,因为每一层的空间和通道维度都不同。因此,特征金字塔通常生成对各种不同任务有用的图像特征。在本概述中,我们将看到特征金字塔用于为 NeRF 的变体提供额外的输入信息!
输入编码
有时,我们有些数据不想直接输入到机器学习模型中,因此我们将这些数据的编码版本作为输入。这是机器学习中的一个基本概念。例如,独热编码的分类变量。一个更复杂的例子是核函数,或我们将数据通过的函数(即,可能使其线性可分)然后再提供给模型。在这些情况下,我们都在编码/转换输入,以便它以更适合模型的格式呈现。
NeRF 架构中的位置编码(由作者创建)
位置编码。 类似地,当我们将 3D 坐标作为输入传递给 NeRF 的前馈网络时,我们不想直接使用这些坐标作为输入。相反,我们使用位置编码方案将它们转换为更高维的向量;见上文。这种位置编码方案是用于在变换器[6]中为标记化输入添加位置数据的完全相同的技术。在 NeRF 中,位置编码已被证明能显著改善场景渲染效果。
可学习的嵌入层。 位置编码方案有一个问题——它们是固定的。如果我们想学习这些编码呢?一种方法是构造一个嵌入矩阵。给定一个将每个空间位置映射到矩阵中的索引的函数,我们可以检索每个空间位置对应的嵌入并将其作为输入。然后,这些嵌入可以像普通模型参数一样进行训练!
出版物
现在,我们将概述一些扩展和改进 NeRF 的出版物。特别是,这些出版物 (i) 通过较少的场景图像生成高质量的场景表示,并 (ii) 使 NeRF 的训练和渲染过程更快。
PixelNeRF:来自一张或几张图片的神经辐射场 [1]
(来源于 [1])
原始 NeRF 公式的主要缺点之一是它必须针对每个场景进行训练和使用。通过 NeRF 获取每个场景的表示在计算上是昂贵的,并且需要许多具有姿态的场景图像。PixelNeRF [1] 旨在通过将 NeRF 的输出条件化为由预训练的深度神经网络创建的图像特征来缓解这一问题。通过使用图像特征作为输入,PixelNeRF 可以利用先前的信息,在仅有少数场景图像的情况下生成高质量的场景渲染。因此,它在数据有限的情况下显著提高了场景表示的质量。
方法。 PixelNeRF 与原始 NeRF 公式非常相似。它使用前馈神经网络通过预测给定空间位置和视角方向(已转换为 位置嵌入)的颜色和不透明度值来建模辐射场。体积渲染和训练过程没有改变。这些方法之间的主要区别在于 pixelNeRF 具有一个额外的输入组件:从底层场景视图中衍生的图像特征。
(来源于 [1])
PixelNeRF 能够将一个或多个场景图像作为输入的一部分。图像首先通过一个预训练的编码器——一个特征金字塔 ResNet 变体[6]——来生成特征金字塔。从这里,我们可以提取这些特征中对应于特定空间位置的区域(这可以通过相机位姿信息比较容易地完成;见 [1] 的第 4.1 节)。然后,我们将这些提取的特征与对应的空间位置和视角方向串联作为 PixelNeRF 前馈网络的输入。
让我们思考一下 PixelNeRF 的前馈神经网络的单次前向传播。我们在这一前向传播中考虑一个单一的空间位置和观察方向。如果我们可以访问到场景的单幅图像,我们可以通过以下方式包含这些信息:
-
通过编码器传递图像以生成特征网格。
-
通过提取与当前空间位置对应的特征金字塔区域来获取特征。
-
连接空间、方向和特征输入。
然后,PixelNeRF 的其余组件与原始的 NeRF 公式相匹配;见下文。
(来自 [1])
如果有多个场景图像可用,我们只需将 PixelNeRF 的前馈网络分为两个组件。第一个组件使用上述过程单独处理每张图像。即,网络通过将每张图像的特征与相同的空间和方向输入信息连接起来来执行单独的前向传播。
PixelNeRF 具有多个输入视角的架构(作者创建)
每次前向传播都会产生一个输出向量。我们可以通过计算这些向量的平均值来聚合它们,然后将这个平均向量通过更多的前馈层以生成最终的 RGB 和不透明度输出;见上文。尽管有这种修改后的架构,PixelNeRF 的训练过程与 NeRF 相似,并且只需要场景图像的数据集。
结果。 PixelNeRF 在 ShapeNet 上进行物体和场景视图合成等任务的评估,以及其表示真实世界场景的能力。首先,PixelNeRF 被训练以代表来自特定 ShapeNet 类(例如,椅子或汽车)的对象,给定一张或两张输入图像(即,一张或两张图像场景)。在这种情况下,我们发现 PixelNeRF 在从少量输入图像重建对象方面优于基线;见下文。
(来自 [1])
此外,PixelNeRF 不进行任何测试时优化,而像 SRNs [5] 这样的基线则不是这样。因此,尽管 PixelNeRF 更快且解决了比基线更困难的问题,但它的表现更为出色。当我们以类别无关的方式训练 PixelNeRF(即,在 ShapeNet 的 13 个对象类别上),我们看到其性能提升更为显著!PixelNeRF 在表示这更广泛的对象集方面超越了基线;见下文。
(来自 [1])
当 PixelNeRF 在更复杂的设置中进行评估时(例如,未见过的类别、多对象场景、真实图像等),我们继续看到性能的提升。最值得注意的是,PixelNeRF 在捕捉多对象场景和在测试时推断未见对象的能力上显著优于基线;见下文。
(来自[1])
将这一点推向极限,PixelNeRF 可以仅凭三张真实场景的输入图像就重建出相当高保真的场景;见下文。这些结果强调了 PixelNeRF 在给定有限且噪声数据下建模场景的能力。在这种情况下,NeRF 无法准确重建场景。
(来自[1])
多分辨率哈希编码的即时神经图形原语 [2]
(来自[2])
PixelNeRF [1] 允许我们从少量图像中恢复场景表示。但请记住,NeRF 的训练过程也很慢(即单个 GPU 上需要2 天)!考虑到这一点,我们可能会问自己:我们能多快训练一个 NeRF? 在[2]中提出的即时神经图形原语(InstantNGP)展示了我们可以大大加快训练 NeRF 的速度。
使用哈希函数对简单的特征嵌入矩阵进行索引(作者创建)
InstantNGP 的方法类似于 NeRF [4]——唯一的区别在于我们如何构建前馈网络的输入。我们没有使用位置编码方案,而是构建了一个多分辨率哈希表,将每个输入坐标映射到一个可训练的特征向量;见上文。该方法*(i)* 向 NeRF 添加了更多可学习的参数,并*(ii)* 为每个输入坐标生成丰富的输入表示,从而使前馈网络变得更小。总体而言,这种方法可以显著加快 NeRF 的训练过程。
方法。 实际上构建和查询输入特征哈希表的方法(不幸的是)比上述简单示意图要复杂得多。让我们更深入地探讨 InstantNGP [2]是如何处理输入特征的。
InstantNGP 采用参数化的方法来编码输入。与使用位置嵌入函数将坐标映射到固定的高维输入的 NeRF 不同,InstantNGP 在训练过程中学习输入特征。从高层次看,这是通过以下方式完成的:
-
在嵌入矩阵中存储输入特征。
-
基于输入坐标对嵌入矩阵进行索引。
-
通过随机梯度下降正常更新特征。
让我们逐一处理这些组件。首先,我们需要创建一个可以索引的可学习输入特征表。在[2]中,输入特征存储在一个包含L
级特征(即L
个不同的嵌入矩阵)的多分辨率表中。每一层表都有T
个维度为F
的特征向量(即一个T x F
大小的矩阵)。通常,这些参数遵循以下所示的设置。
(来源于 [2])
每个级别旨在以不同的分辨率表示 3D 空间,从N-min
(最低分辨率)到N-max
(最高分辨率)。我们可以将其视为将 3D 空间划分为不同粒度的体素网格(例如,N-min
将使用非常大/粗糙的体素)。通过这种方式划分 3D 空间,我们可以确定输入坐标所在的体素——这对于每个分辨率级别都是不同的。坐标所在的体素随后用于将该坐标映射到每个级别的嵌入矩阵中的一个条目。
(来源于 [2])
更具体地说,[2]中的作者使用上面显示的哈希函数将体素位置(即,由体素边缘的坐标给出)映射到每个分辨率级别的嵌入矩阵中的条目索引。值得注意的是,分辨率较粗的级别(即,较大的体素)将具有较少的哈希冲突,这意味着完全不同位置的输入坐标被映射到相同特征向量的可能性较小。
(来源于 [2])
在我们检索到每个分辨率级别的相应特征向量后,我们会得到多个特征向量对应于单个输入坐标。为了合并这些向量,我们进行线性插值,其中插值的权重是通过输入坐标在每个级别的体素中的相对位置得出的。从这里开始,我们将这些向量与其他输入信息(例如,位置编码的视角方向)拼接起来,形成最终输入!InstantNGP 中的完整多分辨率方法如上图所示。
(来源于 [2])
由于使用了更高质量的可学习输入特征,InstantNGP 能够相对于 NeRF 使用更小的前馈网络,同时在质量上取得类似的结果;见上文。当这些修改与更高效的实现(即,完全融合的 cuda 内核,最小化带宽和浪费操作)结合时,NeRF 的训练时间可以显著缩短。事实上,我们可以在几秒钟内使用 InstantNGP 获得高质量的场景表示。
结果。 InstantNGP 使用与[4]中提出的几乎相同的设置来训练 NeRF,除了修改的输入编码方案和更小的前馈神经网络。坐标输入使用多分辨率哈希表进行编码,而视角方向则使用普通的、位置编码的嵌入进行编码。
(来源于 [2])
使用提出的方法和更快的渲染程序,[2]中的作者发现 InstantNGP 可以在几秒钟内训练场景表示,甚至可以以 60 FPS 渲染场景!这相对于 NeRF 是一个巨大的效率提升;详见上文。值得注意的是,InstantNGP 在训练仅 15 秒后就能与 NeRF(需要数小时训练)竞争,表现突出!
为了确定这种加速是否来源于更高效的 cuda 实现或多分辨率哈希表,作者进行了一些分析。他们发现高效的实现确实提供了很大的加速,但仅使用哈希表和较小的前馈网络即可在训练 NeRFs 时获得 20 倍到 60 倍的加速。
“我们用频率编码替代了哈希编码,并扩大了 MLP 以大致匹配[NeRF]的架构……我们算法的这个版本在训练约 ∼5 分钟后接近 NeRF 的质量,但在训练更短时间(5 秒至 15 秒)后被我们的完整方法超越,这得益于哈希编码和较小的 MLP,使得效率提高了 20 倍到 60 倍。” — 摘自[2]
在某些情况下,我们确实看到基准方法在包含复杂的视角依赖反射和非朗伯效应的场景中优于 InstantNGP。作者声称这是由于[2]中使用了较小的前馈网络;详见下文。
(摘自[2])
我们还可以用它做什么? 尽管我们专注于改进 NeRF,但 InstantNGP 的方法相当通用——它可以提高各种计算机图形原语(即,描述外观的函数)的效率。例如,InstantNGP 在[2]中被证明在以下方面有效:
收获
尽管 NeRFs 革新了神经场景表示的质量,但在本综述中我们看到还有很大的改进空间!NeRFs 仍然需要很长时间来训练,并且需要大量的训练数据才能良好工作。下面概述了一些减轻这些问题的基本要点。
提高样本复杂性。 在其原始形式中,NeRF 需要大量的输入观察来进行视图合成。这主要是因为 NeRF 是逐场景训练的,无法利用任何先前的信息来生成新的视图。PixelNeRF [1] 通过将预训练的图像特征作为输入添加到 NeRF 的前馈网络中来缓解这个问题。这种方法允许利用来自其他训练数据的学习到的先验信息。因此,这种方法可以仅凭几张图像就生成场景表示!
更高质量的输入非常重要! 正如 InstantNGP [2] 所示,NeRF 使用的输入编码方案至关重要。使用更丰富、可学习的编码方案可以缩小前馈网络的大小,从而在训练和渲染效率上取得显著提升。在我看来,这种发现可以激发未来大量的工作。我们能找到更好的编码方案吗?是否有其他类型的深度学习模型可以应用这个概念?
局限性。 我们在此概述中看到的方法在解决 NeRF 已知的局限性方面做了很多努力,但它们并不完美。InstantNGP 在 NeRF 训练时间上提供了令人难以置信的加速,但结果场景表示的质量并不总是最佳的。与基线相比,InstantNGP 在捕捉复杂效果如反射方面表现不佳,这表明我们为更快的训练牺牲了表示质量。
“一方面,我们的方法在几何细节丰富的场景中表现最佳… 另一方面,mip-NeRF 和 NSVF 在具有复杂视角依赖反射的场景中超越了我们的方法… 我们将此归因于我们为了获得比这些竞争实现快几个数量级的速度提升而必然使用的更小的 MLP。” — 来自 [2]
此外,由于 PixelNeRF [1] 在其初始前馈组件中分别处理每个输入图像,其运行时间随输入视图数量线性增长。这种线性依赖性可能导致训练和渲染速度相当慢。因此,我们可以解决一些 NeRF 的主要问题,但可能会付出一些代价!
结语
非常感谢你阅读这篇文章。我是 Cameron R. Wolfe,Rebuy 的 AI 总监。我研究深度学习的经验和理论基础。你也可以查看我在 medium 上的 其他著作!如果你喜欢这篇文章,请在 twitter 上关注我或订阅我的 Deep (Learning) Focus 新闻通讯,我通过对流行论文的易懂概述来帮助读者更深入地理解深度学习研究中的主题。
参考文献
[1] Yu, Alex 等人。“pixelnerf: 从一张或几张图片中生成神经辐射场。” IEEE/CVF 计算机视觉与模式识别会议论文集。2021 年。
[2] Müller, Thomas 等人。“具有多分辨率哈希编码的即时神经图形原语。” ACM 图形学汇刊(ToG) 41.4(2022 年):1–15。
[3] Lin, Tsung-Yi 等人。“用于目标检测的特征金字塔网络。” IEEE 计算机视觉与模式识别会议论文集。2017 年。
[4] Mildenhall, Ben 等人。“Nerf:将场景表示为神经辐射场以进行视图合成。” ACM 通讯 65.1(2021 年):99–106。
[5] Sitzmann, Vincent, Michael Zollhöfer 和 Gordon Wetzstein。“场景表示网络:连续的 3D 结构感知神经场景表示。” 神经信息处理系统进展 32(2019 年)。
[6] Vaswani, Ashish 等人。“注意力机制是你所需的一切。” 神经信息处理系统进展 30(2017 年)。
超越 NeRFs(第二部分)
成功使用 NeRFs 的技巧和窍门
·发表于 数据科学前沿 ·阅读时长 16 分钟·2023 年 6 月 13 日
–
(照片由 Ashim D’Silva 提供,来源于 Unsplash)
在 3D 场景的表示和渲染领域,神经辐射场(NeRFs)在准确性上取得了巨大的突破。给定几个基础场景的图像,NeRFs 可以从任意视角重建高分辨率的 2D 渲染图像。与先前的技术如 局部光场融合(LLFF) [5] 和 场景表示网络(SRNs) [6] 相比,NeRFs 更能够捕捉场景外观和几何结构的复杂组件(例如,视角依赖的反射和复杂的材料)。
NeRFs 具有颠覆虚拟现实、计算机图形学等应用的潜力。例如,可以设想使用 NeRFs 来重建正在出售的房子的 3D 渲染图像,前提是提供了该房子的在线图像,甚至可以使用基于现实场景训练的 NeRFs 设计视频游戏环境。然而,在其原始形式中,NeRFs 大多是在简单、受控的环境中进行评估的。当对现实世界场景的图像进行训练时,NeRFs 的表现往往不如预期(见下文),这使得它们在实际应用中的实用性降低。
(来自 [2])
在本概述中,我们将深入研究 NeRF,以更好地理解它们在现实世界中表现不佳的原因以及如何解决这个问题。特别是,我们将探讨一些最近的提议,如 NeRF-W [1] 和 def-NeRF [2],这些提议修改了 NeRF,以更好地处理在不受控的噪声环境中拍摄的图像。这些技术通过使 NeRF 能够应用于与大多数实际应用中遇到的数据更接近的图像,从而使 NeRF 更加有用。
(来自 [1, 2])
背景
本概述是我们关于 3D 形状和场景的深度学习系列的一部分。如果你还没有阅读过,我建议阅读这个系列中的 先前帖子,因为它们包含了大量关于 NeRF 和相关技术的有用背景信息。在这里,我们将简要概述 NeRF 和一些其他相关概念(例如,潜在空间、非刚性变形、位置编码等),这些概念将在我们讨论 NeRF-W [1] 和 def-NeRF [2] 时出现。
NeRF 的简要概述
在之前的概述中,我们已经深入讨论了神经辐射场(NeRFS)[3] 的概念。鉴于本概述探讨了扩展和修改 NeRF 以用于现实世界应用,我建议阅读 NeRF 的概述 这里。
快速概览。 要重述 NeRF 的基本理念,它们只是 前馈神经网络,接受 3D 坐标和视角方向作为输入,并生成体积密度和 RGB 颜色作为输出。通过在 3D 空间的不同点(和视角方向)上评估 NeRF,我们可以积累大量关于场景几何和外观的信息,这些信息可以用来渲染该场景的图像(或视图);见下文。
(来自 [3])
要训练 NeRF,我们只需积累几个场景的图像和每张图像的相关 相机姿态信息。然后,我们可以使用这些图像作为目标来训练我们的 NeRF!特别是,我们反复 i) 使用 NeRF 在已知视点处渲染图像,ii) 使用光度损失函数(即测量 RGB 像素值之间的差异)将 NeRF 的输出与实际图像进行比较;见下文。
(来自 [3])
NeRFs 的问题。NeRFs 在 3D 场景表示领域是一个重大突破,但它们也有一些局限性。在一个前期概述中,我们讨论了训练和渲染 NeRFs 的计算负担,以及它们对场景的多张图片的需求。然而,像 InstantNGP [7]和 PixelNeRF [8]这样的技术大幅提高了 NeRFs 的计算和样本效率。
更进一步,NeRFs 假设场景是静态的。实际上,这一假设往往不成立。图像可能包含移动的物体(例如人),这些物体会遮挡场景中的相关部分,甚至可能是在一天的不同时间拍摄的(例如在晚上或早晨)。这些是场景的瞬时成分,可能在一张图像中存在,而在另一张图像中不存在。
“NeRF 的中心限制……是它假设世界在几何、材料和光度上都是静态的。NeRF 要求任何在相同位置和方向下拍摄的两张照片必须是相同的。这一假设在许多现实世界的数据集中是被违背的。” — 引自 [1]
这一静态假设是 NeRF 在不受控制的环境中表现不佳的一个重要因素。在本概述中,我们将探讨如何减轻这一假设,使 NeRFs 能够在我们在实际应用中遇到的不完美现实世界数据集上进行训练。
形状变形入门
为了成功地在嘈杂的智能手机图像上训练 NeRFs,最近的技术将 NeRFs 与可学习的变形场结合在一起。然而,要理解这意味着什么,我们需要了解一般的变形。我们将简要介绍这一概念。
简单来说,变形描述了初始几何形状到最终几何形状的转变(例如,通过相对于某个参考系的位移、平移或形态变换)。我们通常会遇到两种基本类型的变形:
-
刚性变形
-
非刚性变形
对于刚性变形(例如旋转和平移),变形的对象相对于外部参考系发生变化,但相对于内部参考系则保持不变。下面的图片中提供了相关示例。
刚性变形的示例(来自 sci.sdsu.edu)
非刚性变形略有不同。物体相对于内部和外部参考系都会发生变化。因此,非刚性变形可以捕捉到像膨胀和剪切这样的变换;见下文。
非刚性变形的示例(来自 sci.sdsu.edu)
变形场。 变形场是一种表示变形的方法。它通过对 3D 空间中的点进行映射来定义一个变换(即,每个空间中的点被映射到一个新点)。通过根据该场定义的映射重新定位/变换物体,我们可以任意变换物体的形状,类似于上面显示的变形。
其他资源
除了上述讨论,还有一些概念可能会提供对本文内容的更深刻理解。请查看下面的链接以获取相关资源:
发表物
尽管 NeRF 在受控环境中效果良好,但它们在渲染从现实世界中捕获的图像的 3D 场景时遇到困难。在这里,我们将概述两种最近提出的方法,称为 NeRF-W [1]和 def-NeRF [2],它们试图解决这一问题。这些方法可以从一组捕捉不完美的照片(例如,手机拍摄)中渲染出准确的 3D 场景,甚至包含剧烈的光照变化或遮挡物体!
NeRF in the Wild: Neural Radiance Fields for Unconstrained Photo Collections [1]
(来源 [1])
现实世界中的图像往往具有许多不希望出现的特性,这使得训练 NeRF 变得相当困难。例如,考虑尝试训练一个 NeRF,使用几年前拍摄的多个重要地标图像;见上图。这些场景的图像可能在不同的时间(夜晚或白天)拍摄,并且包含任何数量的移动人员或物体,这些人员或物体实际上并不是场景几何的一部分!
在不受控制的情况下,由于 NeRF 的假设是场景是静态的,这使得其在实际应用中往往失败。NeRF-W [1] —— NeRF 的扩展 —— 通过放宽 NeRF 所做的静态假设来缓解这些问题,从而允许在常见的现实问题(例如瞬态对象和光照变化)下准确建模 3D 场景。
(来源 [1])
分解场景。 NeRF 在实际应用中遇到的主要问题可以大致分类如下:
-
光度变化: 一天中的时间和大气条件影响场景的光照/辐射。
-
瞬态对象: 现实世界的场景很少被孤立地捕捉。通常会有一些人或物体在拍摄过程中遮挡或移动。
上述图中说明了这些问题,这些问题都是对静态假设的违反。
光度变化。 为了解决光度变化问题,每张图片都会被分配一个“外观”向量,NeRF-W 在预测输出 RGB 颜色时会将其视为(额外输入)。然而,外观嵌入对预测的体积密度没有影响,体积密度捕捉了场景的 3D 几何形状。这个改变仅通过将 NeRF 的前馈网络分离为几个接受不同输入的组件来实现;详见下文。
外观嵌入影响颜色但不影响体积密度(来自[1])。
通过将 NeRF-W 的 RGB 输出与此外观嵌入条件化,模型可以基于特定图像改变场景的外观,同时确保场景的基本几何形状对外观是不变的,并且在图像之间共享。分配给每张训练图像的独特外观嵌入在训练过程中与模型参数一起优化。
静态与瞬态组件。 为了处理瞬态对象,我们应注意场景包含两种类型的实体:
-
图像依赖组件(即移动/瞬态对象)
-
共享组件(即实际场景)
NeRF-W 使用独立的前馈网络组件来建模图像依赖(瞬态)和共享(静态)场景组件。网络的瞬态部分和静态部分分别输出它们自己的颜色和密度估计,这使得 NeRF-W 能够分离场景的静态和瞬态组件;详见下文。
NeRF-W 的静态和瞬态组件(来自[1])
NeRF-W 的瞬态部分发出一个不确定性场(使用贝叶斯学习框架[4]),允许在训练过程中忽略被遮挡的场景组件。为了确保瞬态效果依赖于图像,每张训练图像都与一个“瞬态”嵌入向量相关联,该向量作为输入提供给 NeRF-W 的瞬态组件。与外观嵌入类似,瞬态嵌入在训练过程中被学习。见下文以获得 NeRF-W 架构的完整描述。
(来自[1])
NeRF-W 的所有组件都通过类似于 NeRF [3]的程序进行联合优化,具体描述可以在此处找到。NeRF-W 使用真实世界中著名地标的照片集合进行评估,这些照片从Photo Tourism dataset中选取。当 NeRF-W 被训练以表示六个地标时,我们看到 NeRF-W 在大多数情况下在定量上优于基线;详见下文。
(来自[1])
我们应该回顾一下,为了进行评估,我们:
-
在对应于单个场景的图像上训练模型。
-
采样一个保留测试图像(及其对应的相机姿态)。
-
使用来自保留图像的相机姿态信息渲染一个视点(使用训练好的模型)。
-
将渲染图像与真实图像进行比较。
对于 NeRF-W,我们没有测试图像的外观或瞬态嵌入。因此,NeRF-W 基于测试图像的一半优化这些嵌入,并在另一半图像上进行评估;见下文。
(来自[1])
当我们检查不同 NeRF 变体的输出时,我们发现 NeRF 的输出往往包含由于训练图像中的瞬态物体而产生的鬼影伪影。相比之下,NeRF-W 生成的渲染图像清晰准确,表明它更能处理现实世界中场景外观的变化;见下文。
(来自[1])
此外,NeRF-W 可以根据不同光照条件的训练图像生成准确的场景渲染。考虑到 NeRF-W 能够根据不同的外观嵌入生成输出,我们可以调整 NeRF-W 的外观嵌入以修改最终渲染的外观;见下文。
(来自[1])
进一步推进这一理念,我们甚至可以在不同训练图像的外观嵌入之间进行插值,从而实现渲染场景外观的平滑变化;见下文。
(来自[1])
Nerfies: 变形神经辐射场 [2]
(来自[2])
现代应用的大多数计算机视觉数据都是通过智能手机捕捉的。考虑到这一点,人们可能会想是否可以使用这些数据训练 NeRF。在[2]中,作者探讨了沿这些思路的具体应用:将随意捕捉的“自拍”图像/视频转换为能够生成真实感渲染的 NeRF。作者将这些模型称为“NeRFies”(即基于 NeRF 的自拍)!
最初,这个应用可能看起来相当具体且无用。我们真的那么在意调整自拍角度吗?这能让我们的 Instagram 帖子更具美感吗? 然而,[2]中提出的方法在几个方面都非常有洞察力:
-
它让我们了解了使用智能手机图像和视频训练 NeRF 的可行性。
-
它提高了 NeRF 处理场景中具有挑战性、细节丰富或捕捉不完美材料的能力。
-
它不仅适用于捕捉自画像,还可以应用于更一般的场景建模应用。
使用 [2] 中提出的技术,我们可以在给定噪声和不完美的手机拍摄图像的情况下生成高质量的场景表示。举个例子,想象一下仅通过在手机上拍摄一个快速视频来生成你自己的 3D 模型。目前的相关方法需要整个专门的实验室,配备同步的灯光和相机!
NeRFs + 变形场。当我们考虑使用手持相机来构建一个人的 3D 模型时,会想到一些可能的困难:
-
相机将会移动(这违反了静态假设!)。
-
人类包含许多复杂的几何形状和材料,这些是难以建模的(例如,头发、眼镜、珠宝等)。
在 [2] 中,作者通过用一个共同优化的非刚性变形场来增强 NeRF,解决了这些挑战,该变形场学习在 3D 空间中变换场景的底层几何形状。
使用变形场变换坐标(来自 [2])
这个变形场通过一个 前馈网络建模,该网络接受位置编码的 3D 坐标和每张图像的潜在变形代码作为输入,然后生成一个非刚性变形的 3D 坐标作为输出;见上文。
def-NeRF 的工作原理。在 [2] 中的方法论,我们称之为可变形神经辐射场(def-NeRF),包含两个组件:
-
变形场:使用前馈神经网络建模 3D 坐标的非刚性变形。
-
NeRF:使用原始的 NeRF 架构来创建底层场景几何形状和外观的模板。
我们将每张训练图像与可学习的变形和外观向量关联。这些潜在代码模拟了 NeRF-W [1] 中使用的每图像嵌入方法,使得变形和外观依赖于图像,从而使 def-NeRF 能够处理场景图像中的变异(例如,光照变化)。
(来自 [2])
def-NeRF 接受 3D 坐标作为输入。这个坐标经过位置编码,并与潜在的变形代码(通过加法)结合,然后传递给建模 def-NeRF 变形场的前馈网络。这个网络的输出是一个变换后的 3D 坐标;见上文。
(来自 [2])
这个变换后的坐标作为输入传递给 NeRF。类似于 NeRF-W [1],我们用一个每张图像的、可学习的外观向量来增强这个 NeRF。给定变换后的坐标、观察方向和一个外观向量作为输入,NeRF 输出一个体积密度和 RGB 颜色;见上文。
(来自 [2])
上述的完整 def-NeRF 架构与原始 NeRF 的 架构和训练策略 几乎相同。主要区别是:
-
变形场的建模。
-
每图像变形和外观向量的使用。
“在渲染时,我们简单地在观察帧中投射光线并采样点,然后使用变形场将采样点映射到模板。” — 来自 [2]
(来自 [2])
为什么这是必要的? def-NeRF 仅在主要的 NeRF 架构上添加了一个变形场,该变形场以非刚性方式变形输入坐标。因此,这种方法将场景表示分解为两部分:
-
场景的几何模型。
-
将这种几何形状变形为期望的视角。
因此,def-NeRF 解除 NeRF 的静态假设,允许在对位移、平移、视角变化等不变的情况下学习基础场景几何。
添加的正则化提高了重建质量(来自 [2])
正则化。 [2] 中的作者观察到学习到的变形场容易陷入局部最小值和过拟合。作为解决方案,我们可以在 def-NeRF 的优化过程中添加额外的正则化;详见上述。采用了几种不同的正则化方案,如 [2] 第 3.3–3.5 节所述。
效果好吗? def-NeRF 主要基于其生成 “Nerfies” (即从任意视角拍摄的逼真渲染图)的能力进行评估。为了创建 Nerfie,用户使用智能手机拍摄他们的脸大约 20 秒。然后,def-NeRF 方法在这些数据上进行训练,并用于从各种新颖的视角渲染自拍。
(来自 [2])
为了评估从新视角生成的这些场景重建的质量,作者构建了一个相机装置,同时从多个视角捕捉对象。这允许使用捕捉同一确切场景的两种不同视角的图像构建验证集;详见上述。
(来自 [2])
当与各种基线进行定量比较时,def-NeRF 在大多数情况下能够生成更高质量的对象重建。值得注意的是,def-NeRF 似乎在处理 PSNR metric 时遇到困难。然而,作者声称该指标偏向模糊图像,不适合评估场景重建。
(来自 [2])
定性地,我们看到相较于基线,def-NeRF 更能够捕捉场景中的细节(例如头发、衬衫褶皱、眼镜等);见上文。此外,该方法适用于一般场景,超越了在 NeRFie 中重建人类主体的范围。总体而言,def-NeRF 在给定手机图像的情况下似乎能够提供高质量的场景重建!
(摘自 [2])
主要收获
尽管 NeRFs(神经辐射场)展示了令人印象深刻的演示效果,但除非我们能将它们应用于现实世界中的图像,否则它们的实际用途有限。在本综述中,我们强调了在实际应用中使用 NeRFs 的主要困难(即静态假设),并概述了一些旨在解决这一问题的近期研究。以下是一些主要的收获。
静态假设。NeRFs 在其原始形式中假设场景是静态的,这意味着从相同位置/方向拍摄的两个场景图像必须是相同的。在实际操作中,这一假设很少成立!人或物体可能在场景中移动,而变化的光照条件可以显著改变图像的外观。在现实世界中部署 NeRFs 需要显著放宽这一假设。
图像依赖的嵌入。现实世界中的场景可以分为图像独立和图像依赖的组件。如果我们希望学习场景的基础几何而不对图像依赖的组件进行过拟合,我们必须根据每张图像定制 NeRF 的输出。对于 NeRF-W 和 def-NeRF,这主要通过添加每图像嵌入向量(即外观、瞬态和变形向量)来实现。然而,未见过/测试图像的每图像嵌入向量的缺乏可能使得这些模型的部署更加困难。
局限性。允许 NeRFs 超越受控环境的应用是重要的,但这并不是 NeRFs 唯一的局限性!这些模型仍然面临着样本效率低和计算复杂性高的问题,正如 之前的文章中讨论的那样。使 NeRFs 适用于实时应用将需要结合解决 NeRFs 每个个别问题的技术。
结语
非常感谢您阅读本文。我是 Cameron R. Wolfe,Rebuy 的 AI 主管。我研究深度学习的经验和理论基础。您还可以查看我在 medium 上的 其他文章!如果您喜欢这篇文章,请在 twitter 上关注我或订阅我的 Deep (Learning) Focus 新闻通讯,我在其中帮助读者通过对热门论文的通俗概述建立对深度学习研究主题的更深理解。
参考文献
[1] Martin-Brualla, Ricardo 等. “Nerf in the wild:用于无约束照片集合的神经辐射场。” IEEE/CVF 计算机视觉与模式识别会议论文集。2021。
[2] Park, Keunhong 等. “Nerfies:可变形神经辐射场。” IEEE/CVF 国际计算机视觉大会论文集。2021。
[3] Mildenhall, Ben 等. “Nerf:将场景表示为神经辐射场以进行视图合成。” ACM 通讯 65.1 (2021): 99–106。
[4] Kendall, Alex 和 Yarin Gal. “在计算机视觉的贝叶斯深度学习中,我们需要哪些不确定性?” 神经信息处理系统进展 30 (2017)。
[5] Mildenhall, Ben 等. “局部光场融合:具有规范化采样指南的实用视图合成。” ACM 图形学会论文 (TOG) 38.4 (2019): 1–14。
[6] Sitzmann, Vincent, Michael Zollhöfer 和 Gordon Wetzstein. “场景表示网络:连续 3D 结构感知神经场景表示。” 神经信息处理系统进展 32 (2019)。
[7] Müller, Thomas 等. “瞬时神经图形原语与多分辨率哈希编码。” ACM 图形学会论文 (ToG) 41.4 (2022): 1–15。
[8] Yu, Alex 等. “pixelnerf:来自一张或少量图像的神经辐射场。” IEEE/CVF 计算机视觉与模式识别会议论文集。2021。
超越 Numpy 和 Pandas:释放鲜为人知的 Python 库的潜力
3 个数据专业人员应该了解的 Python 科学计算库
·发表于 Towards Data Science ·阅读时间 12 分钟·2023 年 7 月 6 日
–
图片由 OrMaVaredo 提供,来源于 Pixabay
Python 是世界上使用最广泛的编程语言之一,为开发者提供了广泛的库。
无论如何,当涉及数据处理和科学计算时,我们通常会想到 Numpy
、Pandas
或 SciPy
等库。
在这篇文章中,我们介绍了 3 个你可能感兴趣的 Python 库。
1. Dask
介绍 Dask
Dask 是一个灵活的并行计算库,使大规模数据处理的分布式计算和并行处理成为可能。
那么,我们为什么要使用 Dask 呢?正如他们在 他们的网站 上所说:
Python 已经成长为数据分析和通用编程领域的主导语言。这一增长得益于像 NumPy、pandas 和 scikit-learn 这样的计算库。然而,这些包并未设计用于超越单机的规模。Dask 的开发旨在将这些包及其生态系统本地扩展到多核机器和分布式集群,当数据集超出内存时。
因此,Dask 的一个常见用途是 如他们所说:
Dask DataFrame 在 pandas 常用的情况下使用,通常是当 pandas 因数据大小或计算速度而失败时:
操作大数据集,即使这些数据集无法完全加载到内存中
通过使用多个核心加速长时间计算
在大型数据集上进行分布式计算,使用标准 pandas 操作,如 groupby、join 和时间序列计算
所以,当我们需要处理巨大的 Pandas 数据框时,Dask 是一个不错的选择。这是因为 Dask:
允许用户在笔记本电脑上处理 100GB+ 的数据集,或在工作站上处理 1TB+ 的数据集
这是一项相当令人印象深刻的结果。
在幕后发生的情况是:
Dask 数据框协调许多按索引排列的 pandas 数据框/系列。Dask 数据框是按行分区的,通过按索引值分组行以提高效率。这些 pandas 对象可能存储在磁盘上或其他机器上。
所以,我们可以得到类似这样的结果:
Dask 和 Pandas 数据框之间的区别。图像由作者提供,灵感来自于已引用的 Dask 网站上的图像。
Dask 的一些功能展示
首先,我们需要安装 Dask。我们可以通过pip
或conda
来完成,方法如下:
$ pip install dask[complete]
or
$ conda install dask
功能一:打开 CSV 文件
我们可以展示的第一个 Dask 特性是如何打开 CSV 文件。我们可以这样做:
import dask.dataframe as dd
# Load a large CSV file using Dask
df_dask = dd.read_csv('my_very_large_dataset.csv')
# Perform operations on the Dask DataFrame
mean_value_dask = df_dask['column_name'].mean().compute()
所以,正如我们在代码中看到的,使用 Dask 的方式与 Pandas 非常相似。特别是:
-
我们使用
read_csv()
方法,就像在 Pandas 中一样。 -
我们以完全相同的方式截取列,就像在 Pandas 中一样。实际上,如果我们有一个名为
df
的 Pandas 数据框,我们会这样截取列:df['column_name']
。 -
我们对截取的列应用
mean()
方法,类似于 Pandas,但这里我们还需要添加compute()
方法。
此外,即使打开 CSV 文件的方法与 Pandas 相同,在幕后 Dask 也在轻松处理超出单台机器内存容量的大型数据集。
这意味着我们看不到任何实际差异,除了一个事实,那就是大型数据框在 Pandas 中无法打开,但在 Dask 中可以。
功能二:扩展机器学习工作流
我们还可以使用 Dask 创建一个包含大量样本的分类数据集。然后,我们可以将其拆分为训练集和测试集,用 ML 模型拟合训练集,并为测试集计算预测结果。
我们可以这样做:
import dask_ml.datasets as dask_datasets
from dask_ml.linear_model import LogisticRegression
from dask_ml.model_selection import train_test_split
# Load a classification dataset using Dask
X, y = dask_datasets.make_classification(n_samples=100000, chunks=1000)
# Split the data into train and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y)
# Train a logistic regression model in parallel
model = LogisticRegression()
model.fit(X_train, y_train)
# Predict on the test set
y_pred = model.predict(X_test).compute()
这个例子强调了 Dask 处理庞大数据集的能力,即使在机器学习问题的情况下,通过将计算分布到多个核心上。
具体来说,我们可以使用方法dask_datasets.make_classification()
创建一个用于分类的“Dask 数据集”,并且我们可以指定样本数量和块(甚至非常庞大!)。
和之前一样,预测结果是通过compute()
方法获得的。
**NOTE:**
in this case, you may need to intsall the module dask_ml.
You can do it like so:
$ pip install dask_ml
功能三:高效的图像处理
Dask 利用的并行处理能力也可以应用于图像。
具体来说,我们可以打开多个图像,调整其大小,并保存调整后的图像。我们可以这样做:
import dask.array as da
import dask_image.imread
from PIL import Image
# Load a collection of images using Dask
images = dask_image.imread.imread('image*.jpg')
# Resize the images in parallel
resized_images = da.stack([da.resize(image, (300, 300)) for image in images])
# Compute the result
result = resized_images.compute()
# Save the resized images
for i, image in enumerate(result):
resized_image = Image.fromarray(image)
resized_image.save(f'resized_image_{i}.jpg')
所以,这里是处理流程:
-
我们用方法
dask_image.imread.imread("image*.jpg")
打开当前文件夹中的所有“.jpg”图像(或者是一个你可以指定的文件夹中的图像)。 -
我们用
da.stack()
方法中的列表推导式将它们调整到 300x300。 -
我们使用方法
compute()
来计算结果,就像以前一样。 -
我们用 for 循环保存所有调整大小的图像。
2. SymPy
介绍 Sympy
如果你需要进行数学计算和运算,并且希望坚持使用 Python,你可以尝试 Sympy。
确实:为何使用其他工具和软件,当我们可以使用我们喜爱的 Python 呢?
根据他们在网站上的描述,Sympy 是:
一个用于符号数学的 Python 库。它旨在成为一个功能齐全的计算机代数系统(CAS),同时保持代码尽可能简单,以便易于理解和扩展。SymPy 完全用 Python 编写。
但为什么使用 SymPy?他们建议:
SymPy 是…
- 免费: 根据 BSD 许可证,SymPy 既是免费的(在言论上),也是免费的(在酒精上)。
- 基于 Python: SymPy 完全用 Python 编写,并且使用 Python 作为其语言。
- 轻量级: SymPy 仅依赖于 mpmath,一个用于任意浮点运算的纯 Python 库,使其易于使用。
- 一个库: 除了作为交互工具使用外,SymPy 可以嵌入到其他应用程序中,并通过自定义函数进行扩展。
所以,它基本上具备了 Python 爱好者喜欢的所有特点!
现在,让我们看看它的一些功能。
SymPy 的某些功能展示
首先,我们需要安装它:
$ pip install sympy
**PAY ATTENTION:**
if you write *$ pip install* *simpy* you'll install another (completely
different!) library.
So, the second letter is a "y", not an "i".
特性一:解决代数方程
如果我们需要解决代数方程,可以像这样使用 SymPy:
from sympy import symbols, Eq, solve
# Define the symbols
x, y = symbols('x y')
# Define the equation
equation = Eq(x**2 + y**2, 25)
# Solve the equation
solutions = solve(equation, (x, y))
# Print solution
print(solutions)
>>>
[(-sqrt(25 - y**2), y), (sqrt(25 - y**2), y)]
所以,这就是过程:
-
我们使用
symbols()
方法定义方程的符号。 -
我们用
Eq
方法编写代数方程。 -
我们用
solve()
方法解决方程。
当我在大学时,我使用过不同的工具来解决这些问题,我不得不说,正如我们所见,SymPy 非常易读且用户友好。
但确实:它是一个 Python 库,那会有什么不同呢?
特性二:计算导数
计算导数是我们在数据分析中可能需要的另一项任务。通常,我们可能需要进行计算,SymPy 确实简化了这个过程。实际上,我们可以这样做:
from sympy import symbols, diff
# Define the symbol
x = symbols('x')
# Define the function
f = x**3 + 2*x**2 + 3*x + 4
# Calculate the derivative
derivative = diff(f, x)
# Print derivative
print(derivative)
>>>
3*x**2 + 4*x + 3
所以,正如我们所见,过程非常简单且自解释:
-
我们使用
symbols()
定义我们要导出的函数的符号。 -
我们定义函数。
-
我们用
diff()
计算导数,指定函数和我们计算导数的符号(这是绝对导数,但对于具有x
和y
变量的函数,我们也可以进行偏导数计算)。
如果我们测试它,我们会看到结果在 2 到 3 秒内到达。所以,它也相当快。
特性三:计算积分
当然,如果 SymPy 可以计算导数,它也可以计算积分。让我们来做一下:
from sympy import symbols, integrate, sin
# Define the symbol
x = symbols('x')
# Perform symbolic integration
integral = integrate(sin(x), x)
# Print integral
print(integral)
>>>
-cos(x)
所以,在这里我们使用方法 integrate()
,指定要积分的函数和积分变量。
难道不更简单吗?!
3. Xarray
介绍 Xarray
Xarray 是一个 Python 库,它扩展了 NumPy 的功能和特性,使我们能够使用标签数组和数据集。
正如他们的网站上所说的:
Xarray 使得在 Python 中处理带标签的多维数组变得简单、高效且有趣!
以及也:
Xarray 在原始的类似 NumPy 的多维数组之上引入了标签,如维度、坐标和属性,从而允许更直观、更简洁且错误更少的开发体验。
换句话说,它通过为数组维度添加标签或坐标来扩展 NumPy 数组的功能。这些标签提供了元数据,并使得对多维数据进行更高级的分析和操作成为可能。
例如,在 NumPy 中,数组是通过基于整数的索引进行访问的。
在 Xarray 中,每个维度可以有一个关联的标签,使得基于有意义的名称理解和操作数据变得更容易。
例如,在 Xarray 中,我们可以使用 arr.sel(x=0, y=1, z=2)
来代替 arr[0, 1, 2]
,其中 x
、y
和 z
是维度标签。
这使得代码更加可读!
所以,让我们看看 Xarray 的一些功能。
Xarray 的一些功能示例
一如既往,要安装它:
$ pip install xarray
功能一:处理带标签的坐标
假设我们想创建一些与温度相关的数据,并希望用坐标如纬度和经度来标记这些数据。我们可以这样做:
import xarray as xr
import numpy as np
# Create temperature data
temperature = np.random.rand(100, 100) * 20 + 10
# Create coordinate arrays for latitude and longitude
latitudes = np.linspace(-90, 90, 100)
longitudes = np.linspace(-180, 180, 100)
# Create an Xarray data array with labeled coordinates
da = xr.DataArray(
temperature,
dims=['latitude', 'longitude'],
coords={'latitude': latitudes, 'longitude': longitudes}
)
# Access data using labeled coordinates
subset = da.sel(latitude=slice(-45, 45), longitude=slice(-90, 0))
如果我们打印它们,我们得到:
# Print data
print(subset)
>>>
<xarray.DataArray (latitude: 50, longitude: 25)>
array([[13.45064786, 29.15218061, 14.77363206, ..., 12.00262833,
16.42712411, 15.61353963],
[23.47498117, 20.25554247, 14.44056286, ..., 19.04096482,
15.60398491, 24.69535367],
[25.48971105, 20.64944534, 21.2263141 , ..., 25.80933737,
16.72629302, 29.48307134],
...,
[10.19615833, 17.106716 , 10.79594252, ..., 29.6897709 ,
20.68549602, 29.4015482 ],
[26.54253304, 14.21939699, 11.085207 , ..., 15.56702191,
19.64285595, 18.03809074],
[26.50676351, 15.21217526, 23.63645069, ..., 17.22512125,
13.96942377, 13.93766583]])
Coordinates:
* latitude (latitude) float64 -44.55 -42.73 -40.91 ... 40.91 42.73 44.55
* longitude (longitude) float64 -89.09 -85.45 -81.82 ... -9.091 -5.455 -1.818
所以,让我们一步步来看这个过程:
-
我们已经创建了一个温度值的 NumPy 数组。
-
我们已经将纬度和经度值定义为 NumPy 数组。
-
我们已经使用方法
DataArray()
将所有数据存储在一个 Xarray 数组中。 -
我们使用方法
sel()
选择了纬度和经度的一个子集,这个方法选择了我们想要的子集值。
结果也很容易读取,因此标签在很多情况下确实很有帮助。
功能二:处理缺失数据
假设我们正在收集一年中的温度数据。我们想知道数组中是否有一些空值。方法如下:
import xarray as xr
import numpy as np
import pandas as pd
# Create temperature data with missing values
temperature = np.random.rand(365, 50, 50) * 20 + 10
temperature[0:10, :, :] = np.nan # Set the first 10 days as missing values
# Create time, latitude, and longitude coordinate arrays
times = pd.date_range('2023-01-01', periods=365, freq='D')
latitudes = np.linspace(-90, 90, 50)
longitudes = np.linspace(-180, 180, 50)
# Create an Xarray data array with missing values
da = xr.DataArray(
temperature,
dims=['time', 'latitude', 'longitude'],
coords={'time': times, 'latitude': latitudes, 'longitude': longitudes}
)
# Count the number of missing values along the time dimension
missing_count = da.isnull().sum(dim='time')
# Print missing values
print(missing_count)
>>>
<xarray.DataArray (latitude: 50, longitude: 50)>
array([[10, 10, 10, ..., 10, 10, 10],
[10, 10, 10, ..., 10, 10, 10],
[10, 10, 10, ..., 10, 10, 10],
...,
[10, 10, 10, ..., 10, 10, 10],
[10, 10, 10, ..., 10, 10, 10],
[10, 10, 10, ..., 10, 10, 10]])
Coordinates:
* latitude (latitude) float64 -90.0 -86.33 -82.65 ... 82.65 86.33 90.0
* longitude (longitude) float64 -180.0 -172.7 -165.3 ... 165.3 172.7 180.0
于是我们得到有 10 个空值。
同样,如果我们仔细看看代码,我们可以看到可以将 Pandas 的方法应用于 Xarray,如 isnull.sum()
,在这个例子中,这个方法计算了缺失值的总数。
功能一:处理和分析多维数据
当我们有可能给数组加标签时,处理和分析多维数据的诱惑很大。那么,为什么不试试呢?
例如,假设我们仍在收集与特定纬度和经度相关的温度数据。
我们可能希望计算均值、最大值和中位数温度。我们可以这样做:
import xarray as xr
import numpy as np
import pandas as pd
# Create synthetic temperature data
temperature = np.random.rand(365, 50, 50) * 20 + 10
# Create time, latitude, and longitude coordinate arrays
times = pd.date_range('2023-01-01', periods=365, freq='D')
latitudes = np.linspace(-90, 90, 50)
longitudes = np.linspace(-180, 180, 50)
# Create an Xarray dataset
ds = xr.Dataset(
{
'temperature': (['time', 'latitude', 'longitude'], temperature),
},
coords={
'time': times,
'latitude': latitudes,
'longitude': longitudes,
}
)
# Perform statistical analysis on the temperature data
mean_temperature = ds['temperature'].mean(dim='time')
max_temperature = ds['temperature'].max(dim='time')
min_temperature = ds['temperature'].min(dim='time')
# Print values
print(f"mean temperature:\n {mean_temperature}\n")
print(f"max temperature:\n {max_temperature}\n")
print(f"min temperature:\n {min_temperature}\n")
>>>
mean temperature:
<xarray.DataArray 'temperature' (latitude: 50, longitude: 50)>
array([[19.99931701, 20.36395016, 20.04110699, ..., 19.98811842,
20.08895803, 19.86064693],
[19.84016491, 19.87077812, 20.27445405, ..., 19.8071972 ,
19.62665953, 19.58231185],
[19.63911165, 19.62051976, 19.61247548, ..., 19.85043831,
20.13086891, 19.80267099],
...,
[20.18590514, 20.05931149, 20.17133483, ..., 20.52858247,
19.83882433, 20.66808513],
[19.56455575, 19.90091128, 20.32566232, ..., 19.88689221,
19.78811145, 19.91205212],
[19.82268297, 20.14242279, 19.60842148, ..., 19.68290006,
20.00327294, 19.68955107]])
Coordinates:
* latitude (latitude) float64 -90.0 -86.33 -82.65 ... 82.65 86.33 90.0
* longitude (longitude) float64 -180.0 -172.7 -165.3 ... 165.3 172.7 180.0
max temperature:
<xarray.DataArray 'temperature' (latitude: 50, longitude: 50)>
array([[29.98465531, 29.97609171, 29.96821276, ..., 29.86639343,
29.95069558, 29.98807808],
[29.91802049, 29.92870312, 29.87625447, ..., 29.92519055,
29.9964299 , 29.99792388],
[29.96647016, 29.7934891 , 29.89731136, ..., 29.99174546,
29.97267052, 29.96058079],
...,
[29.91699117, 29.98920555, 29.83798369, ..., 29.90271746,
29.93747041, 29.97244906],
[29.99171911, 29.99051943, 29.92706773, ..., 29.90578739,
29.99433847, 29.94506567],
[29.99438621, 29.98798699, 29.97664488, ..., 29.98669576,
29.91296382, 29.93100249]])
Coordinates:
* latitude (latitude) float64 -90.0 -86.33 -82.65 ... 82.65 86.33 90.0
* longitude (longitude) float64 -180.0 -172.7 -165.3 ... 165.3 172.7 180.0
min temperature:
<xarray.DataArray 'temperature' (latitude: 50, longitude: 50)>
array([[10.0326431 , 10.07666029, 10.02795524, ..., 10.17215336,
10.00264909, 10.05387097],
[10.00355858, 10.00610942, 10.02567816, ..., 10.29100316,
10.00861792, 10.16955806],
[10.01636216, 10.02856619, 10.00389027, ..., 10.0929342 ,
10.01504103, 10.06219179],
...,
[10.00477003, 10.0303088 , 10.04494723, ..., 10.05720692,
10.122994 , 10.04947012],
[10.00422182, 10.0211205 , 10.00183528, ..., 10.03818058,
10.02632697, 10.06722953],
[10.10994581, 10.12445222, 10.03002468, ..., 10.06937041,
10.04924046, 10.00645499]])
Coordinates:
* latitude (latitude) float64 -90.0 -86.33 -82.65 ... 82.65 86.33 90.0
* longitude (longitude) float64 -180.0 -172.7 -165.3 ... 165.3 172.7 180.0
我们达到了我们想要的结果,并且方式也很清晰易读。
再次如前所述,我们通过 Pandas 的函数应用于一个数组来计算最大值、最小值和均值。
结论
在这篇文章中,我们展示了三种用于科学计算的库。
虽然 SymPy 可以替代其他工具和软件,让我们有可能使用 Python 代码进行数学计算,但 Dask 和 Xarray 扩展了其他库的功能,帮助我们在遇到其他最知名的 Python 数据分析和处理库的困难时应对。
费德里科·特罗塔
嗨,我是费德里科·特罗塔,我是一名自由职业的技术作家。
想与我合作?联系我。