TowardsDataScience 2023 博客中文翻译(二百七十八)

原文:TowardsDataScience

协议:CC BY-NC-SA 4.0

图形神经网络中的罗马数字分析

原文:towardsdatascience.com/roman-numeral-analysis-with-graph-neural-networks-4d6140cd4c0b?source=collection_archive---------9-----------------------#2023-10-24

入门指南

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Emmanouil Karystinaios

·

关注 发表在 Towards Data Science ·9 分钟阅读·2023 年 10 月 24 日

在这篇文章中,我想解释一下我在开发自动和声分析模型过程中的经历。对我个人而言,我对深入理解音乐感兴趣。像“为什么事情是这样结构的?”和“作曲家或艺术家在创作这部作品时在想什么?”这样的问题对我很重要。自然,我的起点是分析一部作品的内在和声。

在从温室里找回我旧的笔记本时,我偶然发现了我们用来注释和分析小型音乐片段的技术。这被称为罗马数字分析。这个概念可能有点复杂,如果你之前从未听说过它,但请耐心听我讲解。

我的目标是建立一个可以自动分析乐谱的系统。给定一个乐谱,系统将返回同样的乐谱,并在其中添加一个包含罗马数字和弦的附加五线谱。这主要适用于古典调性音乐,但并不限于此。

在本文的其余部分,我将介绍罗马数字、图神经网络的概念,并讨论我开发的模型及其结果。希望你喜欢!

罗马数字简介

罗马数字分析是一种用于理解和分析音乐中和弦及和声进行的方法,特别是在西方古典音乐和流行音乐中。和弦使用罗马数字而非传统音乐记谱法表示。

在罗马数字分析中,每个和弦根据其在给定调性中的位置和功能被分配一个罗马数字。罗马数字表示调性的音阶度数,大写数字表示大调和弦,小写数字表示小调和弦。

例如,在 C 大调中,C 大调和弦用罗马数字“I”表示(大写“I”表示大调和弦)。D 小调和弦用“ii”表示(小写“ii”表示小调和弦)。G 大调和弦用“V”表示(大写“V”表示大调和弦),因为它是 C 大调中的第五和弦。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

在 C 大调中,四声部和声的两个小节的罗马数字分析示例。

罗马数字总是相对于一个调性。因此,如果调性是 C 大调,那么罗马数字“V”将是属和弦或 G 大调和弦。但和弦确实有不同的性质,例如小调或大调。在罗马数字中,大写字母表示大调性质,小写字母表示小调性质。

在音乐分析中,通常最低音是和弦特性的参考点。罗马数字也能传达这一信息。在上面的例子中,第二个和弦的低音(最低和弦音)是 F#,但和弦的根音是 D,因此和弦处于第一转位,用数字 6 表示。

罗马数字的另一个有趣的标记能力与借用和弦有关。这种效果称为副级,隐含地,每个罗马数字(主要)都有一个副级的主音(即 I 或 i),然而,当副级被标注时,我们可以知道哪个音阶度数暂时充当主音。上例中的第三个和弦,其主要度数为属七和弦,副级为 C 大调的属和弦。V65 表示在第二转位中的七和弦。

罗马数字分析帮助音乐家和音乐理论家理解音乐作品中的和弦结构和关系。它使他们能够识别常见和弦进行,分析和声模式,并比较不同的音乐作品。这是作曲家、编曲家和表演者理解潜在和声并根据这些知识做出音乐决策的有用工具。

自动罗马数字分析

现在我们有了关于罗马数字分析在实践中是如何进行的基础,我们可以讨论如何自动化它。在本文中,我们将介绍一种从符号音乐中预测罗马数字的方法,即数字乐谱(MusicXML、MIDI、Mei、Kern、MuseScore 等)。请注意,您可以从任何乐谱编辑软件中获取这些格式,如 Finale、Sibelius、MuseScore 或其他任何软件。通常,这些软件允许导出为 musicxml(未压缩)格式。不过,如果您没有这些编辑器,我建议使用 MuseScore。

现在我们将更深入地讨论这些表示方式。与音频表示方式不同,音乐可以在波形级别上视为数字序列,或在频域上视为二维频谱图,而符号表示法则具有包含起始时间、持续时间和音高拼写(音符名称)等信息的单独音符事件。符号表示法通常被视为伪音频表示,将乐谱分解为量化的时间框架,例如下图所示的钢琴卷轴。然而,最近一些研究提出了一种乐谱的图形表示方法,其中每个音符代表图中的一个顶点,边表示音符之间的关系。对于后一种方法,乐谱可以转换为这种图结构,这在涉及机器学习模型时特别有用。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

分数摘录的不同表示方式显示在中间。顶部:量化时间框架表示,底部:图形表示。

因此,给定一个符号乐谱,图形是通过建模音符之间的三种关系来构建的。

  • 音符同时开始,即相同的起始时间。

  • 一个音符在另一个音符结束时开始,即连续音符。

  • 一个音符在另一个音符发声时开始,即在连接期间。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

乐谱的图形可以作为图神经网络的输入,图神经网络通过沿图的边传播信息来隐式学习。但在解释模型如何在乐谱上工作的之前,让我们首先简要解释图神经网络的工作原理。

图神经网络

那么,图神经网络到底是什么呢?本质上,GNN 是一种深度学习模型,旨在处理表示为图的数据。就像现实世界中的网络一样,图由相互连接的节点或顶点组成,每个节点都有其独特的特征。GNN 利用这种互联性来捕捉丰富的关系和依赖,从而执行分析和预测任务。

那么 GNN 是如何工作的呢?想象一个音乐乐谱,其中每个音符都是一个节点,音符之间的关系表示它们之间的连接。传统模型会将每个音符实例单独处理,忽略音乐背景。然而,GNN 通过同时考虑个体的特征(例如音高拼写、持续时间)和它们的关系(相同起始点、连续)来拥抱这种背景。通过聚合来自邻近节点的信息,GNN 使我们能够理解不仅是单个音符,还有整个网络中的动态和模式。

为了实现这一点,GNN 使用了一系列迭代的消息传递步骤。在每一步中,节点从其邻居那里收集信息,更新自身的表示,并将这些更新后的特征进一步传播通过网络。这一迭代过程使得 GNN 能够捕捉和完善来自附近节点的信息,逐步构建对整个图的全面理解。

迭代地进行的消息传递过程有时被称为图卷积。我们在音乐分析模型中使用的一个流行的图卷积块叫做 SageConv,来自著名的 GraphSAGE 论文。我们在这里不会详细讲解,但有许多资料涵盖了 GraphSAGE 的功能,例如 这个。

GNN 的美妙之处在于它们从图数据中提取有意义的表示的能力。通过从局部上下文中学习并结合全局信息,GNN 能够发现隐藏的模式,做出准确的预测,甚至生成新的见解。这使它们在从社交网络分析到药物发现,从交通预测到欺诈检测,再到音乐分析等广泛领域中都显得非常宝贵。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

用于罗马数字分析的模型叫做 ChordGNN。

正如其名,ChordGNN 是一个基于图神经网络的自动罗马数字分析模型。该模型的一个特点是利用了逐音符的信息,但生成的是逐个起始点的预测,即为乐谱中的每个独特起始事件预测一个罗马数字。这意味着在同一个起始点的多个音符将共享相同的罗马数字,就像为乐谱做标注一样。然而,通过使用图卷积,来自每个音符的信息被传递到邻近的音符和起始点。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

ChordGNN 模型架构示意图。

ChordGNN 基于图卷积递归神经网络架构,由堆叠的 GraphSAGE 卷积块组成,这些块在音符级别上操作。

图卷积操作后是一个 Onset-Pooling 层,它将音符表示收缩到起始级别,从而为乐谱中的每个唯一起始点生成一个向量嵌入。这是一个重要步骤,因为它将表示从图形移动到序列。

由 Onset-Pooling 获得的嵌入(这些嵌入按时间顺序排列)随后被输入到一个顺序模型中,例如 GRU 堆栈。最后,为每个描述罗马数字的属性添加简单的多层感知机分类器。因此,ChordGNN 也是一个多任务模型。

ChordGNN 并不会直接预测乐谱中每个位置的罗马数字,而是预测度数、局部调性、质量、反转和根音。通过分析每个任务的预测,将每个属性任务的预测组合成一个单一的罗马数字预测。让我们看看输出预测的样子。

ChordGNN 预测示例

在这一部分,我们将查看一些 ChordGNN 的预测,甚至与人工分析进行比较。下面是海顿弦乐四重奏 op.20 №3 第 4 乐章的前几个小节的示例。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

人工注释与 ChordGNN 在海顿弦乐四重奏中的对比

弦乐四重奏 op.20 №3 第 4 乐章

在这个示例中,我们可以看到几个方面。在第 2 小节中,人为注释标记了第一个反转中的主和弦;然而,当时的中提琴低于大提琴,因此和弦实际上处于根位置。ChordGNN 能够正确预测这一点。随后,ChordGNN 预测了八分音符的和声节奏,这与注释者的半音符标记不符。通过分析该段落的基本和声,我们可以为我们的 ChordGNN 的选择提供合理解释。

人工注释建议第 2 小节的整个后半部分表示一个 viio 和弦。然而,它不应处于第一个反转,因为大提琴演奏的 F# 是最低音(这是 viio 的根音)。然而,对该段落有两种相互冲突的解释。首先,第三拍的 viio 被视为围绕主和弦的经过和弦,导致下一小节的属和弦。或者,viio 可能已经是一个延续的属和声的一部分(在弱拍上有经过和弦),并导致 V7。ChordGNN 的解决方案兼顾了这两种解释,因为它不试图在更高层次上对和弦进行分组,而是将每个八分音符视为独立和弦,而非经过事件。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

人工注释与 ChordGNN 在莫扎特钢琴奏鸣曲 K279 第 1 乐章中的对比。图片由作者提供

上面是另一个例子,将ChordGNN的预测与莫扎特钢琴奏鸣曲的原始分析进行比较。在这种情况下,ChordGNN的分析略显简单,选择省略了一些和弦。这在两个不同的场合发生,主要和弦七和弦在第 4 转位(V2)中。这对于ChordGNN来说是一个合理的假设,因为缺少了低音。另一个不一致之处发生在接近结尾的半终止。ChordGNN将旋律中的 C#视为过渡音,而注释者则选择指定#11 的扩展。

结论

在本文中,我们讨论了一种使用图神经网络自动化罗马数字分析的新方法。我们讨论了 ChordGNN 模型的工作原理,并展示了它的一些预测结果。

参考文献

E. Karystinaios, G. Widmer. 罗马数字分析与图神经网络:基于音符特征的起始预测。国际音乐信息检索会议(ISMIR),2023 年会议录。

资源

## GitHub - manoskary/ChordGNN: 这是论文的代码库:罗马数字分析与图…

这是论文的代码库:罗马数字分析与图神经网络 - GitHub - manoskary/ChordGNN…

github.com

本文中的所有图像和图形均由作者创建。

轮换值班以进行操作和支持:数据团队的必需品

原文:towardsdatascience.com/rotating-on-call-for-operational-and-support-a-must-for-data-teams-74b9af592253

一个轮换的值班安排用于操作、支持和技术部门,使团队的其他成员能够专注于优秀的开发工作

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 本杰明·图雷尔

·发表于 数据科学前沿 ·阅读时间 7 分钟·2023 年 6 月 20 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

每个数据科学或产品团队面临的一个共同挑战是将新的(产品开发)任务与旧的(操作、支持)任务对齐。当整个团队被要求同时处理两者时,这意味着一方面团队需要满足产品截止日期并推出新产品功能,同时另一方面,团队还需要处理操作工作、修复现有产品以及支持商业问题和电话。这种情况导致意外的上下文切换,并最终导致效率降低、截止日期失败和压力增加。

实际上,这通常会导致某些团队成员承担额外任务或专门负责这些任务。但这很危险,因为一旦这些专门的团队成员休假,公司可能会感到影响并面临问题。

因此,一个高效且可扩展的数据团队需要同时支持操作和新开发工作,并创建一个包括以下内容的系统:

  • 团队成员之间良好的知识分享,了解如何进行操作工作和支持产品/客户

  • 不间断的开发工作,减少上下文切换

  • 明确且估计的维护工作,以避免意外截止日期

轮换值班系统

我们过去发现非常有效的一个系统是轮换值班系统,它处理的不仅仅是生产中的警报。简单来说,这是一个轮换系统,其中一个(或多个)团队成员在特定时间内被指定为值班人员,完全负责操作工作。

值班人员不仅仅是在做一份工作,他们是在保护整个团队免受开发工作之外的所有混乱

为了完成这一点,该系统允许只有值班人员(指定的幸存者)处理所有不属于“新开发”的工作。在这段时间里,值班人员不仅仅是在完成工作,而是保护整个团队免受开发工作之外的所有混乱,包括:

  • 修复生产管道问题

  • 回答商业/客户问题

  • 支持客户电话

  • 减少技术债务(积压)

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

概述值班例行工作中的具体任务。

如上图所示,处理“经典”值班系统并确保生产环境正常运行仍然是最重要的。然而,如果生产环境没有问题,这就可以腾出时间处理其他任务,如支持商业请求、客户电话或减少积压。

有什么好处?

切换到该系统最初可能不容易。不是每个团队成员都能负责生产管道、商业支持和技术债务。但这不应该成为障碍。重要的是要妥善沟通,表明值班人员拥有这些项目,是第一道防线可以随时寻求帮助。

从长远来看,这将为团队和整个组织带来很多好处。最直观的好处是更容易估算开发工作,团队将变得更加高效(减少上下文切换)。这同样适用于运营方面,其中参与值班系统的人数决定了可能的运营工作量。这使得与公司和利益相关者的沟通变得更加容易,因为一个有 5 个人的团队中有 1 个人在轮班,这意味着 1/5 的全职员工维护所有系统和现有产品相关的工作(20%运营,80%开发)。这很容易计算和估算。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

使用轮班制的团队中 20%-80%的运营-开发分配示意图。

然而,随着时间的推移,几乎作为副作用出现了更多好处。所有团队成员将成为全栈数据科学家。原因是每个团队成员需要了解涉及的产品、客户、系统、模型/逻辑和代码基础设施的最低限度。他们不需要成为专家,但最终会变得足够好,能够独立处理这些工作至少 1 周。这也确保了当有价值的团队成员度假时没有问题,因为值班人员将始终支持团队。

此外,尽管值班时间有时可能会更具压力,但它给数据科学家提供了观察团队外部情况的机会,并与商业方面和客户进行合作。这可以是非常有价值和有回报的经历。

如何设置这样的系统?

这里会稍微有点技术性(对于喜欢代码的人,可以直接滚动到最底部)。设置这样的系统相当简单,但可能需要一些编码。最重要的是与团队和相关人员沟通,并告知他们如何进行。

由于系统的核心目的是支持团队,而不是增加更多的管理负担,我强烈推荐完全自动化它。为此,你需要至少部署 3 个系统:

  • 一个与生产系统连接的呼叫系统,当生产失败时会发出警报(例如,OpsgeniePagerduty)。

  • 一个可以检测谁在值班并将该信息传达给另一个系统的调度系统(例如,Apache AirflowKeboola)。

  • 一个用于联系团队并创建票务的通信平台(例如,SlackTeams)。

如果你已经部署了这些系统,并且拥有对呼叫系统和通信平台的 API 访问权限,那么你几乎完成了。剩下的唯一工作就是在调度系统中设置一个作业,该作业首先运行 API 调用,以从呼叫系统中获取当前值班人员的信息,然后再进行 API 推送,以在通信平台中进行通讯或覆盖渠道/组/标签。

以下是一个简单的 API 调用示例,它将提供 Opsgenie 中的值班人员:

curl -X GET \
'https://api.opsgenie.com/v2/schedules/{schedule_name}/on-calls?scheduleIdentifierType=name&flat=true' \
--header 'Authorization: GenieKey {token}'

之后,你需要运行一个在通信系统中执行某些操作的命令。例如,在 Slack 中,覆盖一个用户组,以便只包含值班的用户:

curl -X POST \
-F usergroup={usergroup} \
-F users={user} \
'https://slack.com/api/usergroups.users.update' \
-H 'Authorization: Bearer {token}'

在这个故事的结尾,你会发现一个完整的代码版本,展示如何自动调度这些代码。这将确保每当有人在 Slack 上标记你的组(例如 @ team)时,只有值班人员会被标记,并可以决定是否需要通知更多团队成员。它还允许你快速向 DAG 添加新任务。例如,当你想通知公司或团队谁现在正在值班时,或调整你的票务系统时。

总结

为团队的运营、商业和技术部门工作设置轮换计划,可以提高你的数据团队的效率。这将减少上下文切换,并允许更好的时间估算。此外,它还将培养能够处理各种问题的全栈数据科学家,以保护其余的团队。

所有图片,除非另有说明,均由作者提供。

代码附录:

这是一个 Airflow dag 的示例,它从 Opsgenie 中获取当前值班人员,并覆盖 Slack 中的用户组,使其仅包含该人员。代码确实不完美(数据科学家在工作中),但我相信你明白了:

# Import
from airflow import DAG, XComArg
from typing import Dict, List
from airflow.operators.bash import BashOperator
from airflow.operators.python import PythonOperator
from airflow.models import Variable
import json

# Fetch secret tokens
slack_token = Variable.get("slack_token")
opsgenie_token = Variable.get("opsgenie_token")

# Setup DAG
dag = DAG(
    dag_id,
    schedule_interval=schedule_interval,
    default_args=default_args,
    catchup=catchup,
    max_active_runs=max_active_runs,
)
with dag:

    # Run BashOperator fetching from Opsgenie who is on call
    def fetch_who_is_on_call(**kwargs):
        fetch_who_is_on_call_bash = BashOperator(
            task_id="fetch_who_is_on_call_bash",
            bash_command="""
            curl -X GET \
            'https://api.opsgenie.com/v2/schedules/{schedule_name}/on-calls?scheduleIdentifierType=name&flat=true' \
             --header 'Authorization: GenieKey {token}'
            """.format(
                schedule_name="schedule_name", 
                token=opsgenie_token
             ),
            dag=dag,
        )
        return_value = fetch_who_is_on_call_bash.execute(context=kwargs)
        fetch_who_is_on_call_bash
        return return_value

    # run BashOperator in PythonOperator and provide context
    opsgenie_pull = PythonOperator(
        task_id="opsgenie_pull",
        python_callable=fetch_who_is_on_call,
        provide_context=True,
        dag=dag,
    )

    # Overwrite slack group with the person on call
    def overwrite_slack_group(**kwargs):

        # First: get who is on call from PythonOperator
        ti = kwargs.get("ti")
        xcom_return = json.loads(ti.xcom_pull(task_ids="opsgenie_pull"))
        user_email = xcom_return["data"]["onCallRecipients"][0]

        user_dict = {
            "data_scientist_a": "A03BU00KGK4",
            "data_scientist_b": "B03BU00KGK4",
        }
        user_id = [
            user_dict[k] for k in user_dict.keys() if k == user_email.split(".")[0]
        ]

        # Second: Run BashOperator to overwrite slack group
        overwrite_slack_group_bash = BashOperator(
            task_id="overwrite_slack_group_bash",
            bash_command="""
            curl -X POST \
            -F usergroup={usergroup} \
            -F users={user} \
            https://slack.com/api/usergroups.users.update \
            -H 'Authorization: Bearer {token}'
            """.format(
                usergroup="usergroup_id",
                user=user_id,
                token=slack_token,
            ),
            dag=dag,
        )
        overwrite_slack_group_bash.execute(context=kwargs)
        overwrite_slack_group_bash

    # Run BashOperator for slack overwrite in PythonOperator
    overwrite_slack = PythonOperator(
        task_id="overwrite_slack",
        python_callable=overwrite_slack_group,
        provide_context=True,
        dag=dag,
    )

    opsgenie_pull >> overwrite_slack
    return dag

使用 Rasterio 旋转栅格

原文:towardsdatascience.com/rotating-rasters-with-rasterio-dc36e42b01dd

使用 Python 旋转卫星图像,同时保持地理位置准确性

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Conor O’Sullivan

·发布于Towards Data Science ·6 分钟阅读·2023 年 8 月 7 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

(来源:作者)

栅格数据类似于普通图像数据。不同之处在于每个像素都与地球表面上的位置相关联。这使得问题复杂化。如果我们想要旋转数据,还必须考虑基础的坐标参考系统(CRS)。在不调整地理位置的情况下扭曲栅格会导致空间分析不准确。

调整地理位置并不简单。幸运的是,Rasterio可以提供帮助。这是一个流行的用于地理空间数据分析的 Python 库。我们将使用该包来:

  • 旋转栅格

  • 重新投影图像到正确的坐标参考系统(CRS)。

在此过程中,我们将讨论 Python 代码,你可以在GitHub上找到完整的项目。

本文假设读者具有基本的栅格数据知识和处理其 CRS 的经验。如果你想复习,可以查看下面的文章。它详细介绍了栅格数据的重新投影。

## 如何在 Landsat 卫星图像上绘制坐标

使用 Landsat 元数据和 Rasterio 将像素位置映射到地理坐标

towardsdatascience.com

下载 Landsat 场景

对于我们的栅格数据,我们将处理卫星图像。具体来说,是 Landsat 场景。你可以通过EarthExplorer门户下载其中之一。或者,如果你想使用 Python,下面的文章将带你完成这个过程:

## 使用 Python 下载 Landsat 卫星图像

使用 landsatxplore Python 包简化 Landsat 场景下载

[towardsdatascience.com

最终,你应该会有一个包含所有文件的文件夹,这些文件是Landsat 2 级科学产品。我们将使用红色可见光波段。对于 Landsat 8 或 9 场景,这由波段 B4 表示。

打开栅格文件

我们使用下面的代码来打开和显示这个波段。ID 给出了这个特定场景的 Landsat 场景 ID(第 8 行)。所有可用的波段都将存放在一个以此 ID 命名的文件夹中。我们使用 rasterio 打开红色波段(第 11 行),并使用 matplotlib 显示它(第 14-15 行)。如图 1所示,Landsat 场景通常在其边界框内被旋转。

import matplotlib.pyplot as plt
import rasterio as rio

# Path to our raster file
data_file = "./data/"

# ID of the raster we want to open
ID = "LC08_L2SP_175083_20131218_20200912_02_T1"

# Open the red band (B4):
B4 = rio.open(data_file + '{}/{}_SR_B4.TIF'.format(ID, ID))

# Display the band
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(B4.read(1), cmap='pink')

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 1:红色可见光波段的可视化(来源:作者)

旋转栅格

旋转栅格的关键在于其变换函数。Rasterio 使用仿射变换将数组位置转换为地理位置。对于我们的 Landsat 场景,地理位置是以 UTM 坐标给出的。如果我们旋转栅格中的像素,我们还必须调整此变换。

我们的卫星图像的仿射变换矩阵如下所示(第 2 行)。xy()函数使用此矩阵将数组位置转换为 UTM 坐标。图 2 中的输出显示了数组位置(1000,2000)与地球表面上的地理位置(222900,-3617400)相关联。

print(B4.crs) # Gives coordinate reference system
print(B4.transform) # Affine transformation matrix

# convert array positions to UTM coordinates
x,y = (1000,2000)
utmx,utmy = B4.xy(y,x)
print("\n"+ str((utmx,utmy))) 

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 2:上述代码片段的输出(来源:作者)

我们在rotate_raster函数中调整矩阵。最重要的参数是旋转角度。我们首先通过创建一个旋转仿射矩阵(第 22 行)来使用此参数。然后将其与原始变换矩阵相乘(第 27 行)。现在,当点转换为 UTM 时,它们也将被旋转,反之亦然。新的栅格随后使用此 CRS 进行重投影(第 52-58 行)。我们将在接下来的部分中讨论其他参数。

from rasterio.warp import reproject, Resampling
from affine import Affine
import numpy as np

def rotate_raster(in_file,out_file, angle, shift_x=0, shift_y=0,adj_width=0, adj_height=0):
    """Rotate a raster image and save it to disk.
            in_file: path to input raster file
            out_file: path to output raster file
            angle: angle of rotation in degrees
            shift_x: shift in x direction
            shift_y: shift in y direction
            adj_width: adjust width of output raster
            adj_height: adjust height of output raster"""

    with rio.open(in_file) as src:

        # Get the old transform and crs
        src_transform = src.transform 
        crs = src.crs

        # Affine transformations for rotation and translation
        rotate = Affine.rotation(angle)
        trans_x = Affine.translation(shift_x,0)
        trans_y = Affine.translation(0, -shift_y)

        # Combine affine transformations
        dst_transform = src_transform * rotate * trans_x * trans_y

        # Get band data
        band = np.array(src.read(1))

        # Get the new shape
        y,x = band.shape
        dst_height = y + adj_height
        dst_width = x + adj_width

        # set properties for output
        dst_kwargs = src.meta.copy()
        dst_kwargs.update(
            {
                "transform": dst_transform,
                "height": dst_height,
                "width": dst_width,
                "nodata": 0,  
            }
        )

        # write to disk
        with rio.open(out_file, "w", **dst_kwargs) as dst:
            # reproject to new CRS

            reproject(source=band,
                        destination=rio.band(dst, 1),
                        src_transform=src_transform,
                        src_crs=crs,
                        dst_transform=dst_transform,
                        dst_crs=crs,
                        resampling=Resampling.nearest)

现在,让我们看看该函数的工作原理。我们的输入文件与图 1 中显示的红色波段相同(第 1 行)。我们在此位置定义一个新的文件路径(第 2 行)。我们将这些输入到我们的rotate_raster函数中,并设置 12 度的旋转角度(第 4 行)。你可以在图 3 中看到结果栅格。它不再在其边界框内旋转。重要的是,你仍然可以在新的栅格上绘制坐标。

file = data_file + '{}/{}_SR_B4.TIF'.format(ID, ID)
out_file = data_file + '{}/{}_SR_B4_rotated.TIF'.format(ID, ID)

rotate_raster(file,out_file, 12, shift_x=600, shift_y=700)

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 3:旋转后的 Landsat 图像(来源:作者)

移动栅格

你可能已经注意到上面的代码中有shift_xshift_y参数。这些参数在其边界框内沿 x 和 y 方向移动栅格。我们需要这些参数,因为栅格是围绕其左上角旋转的。同时,使用了原始的高度和宽度。结果是卫星图像的一部分被旋转到其边界框之外。你可以在图 4中看到我们所指的内容。

out_file = data_file + '{}/{}_SR_B4_noshift.TIF'.format(ID, ID)

rotate_raster(file,out_file, 12)

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 4:移动前后栅格的对比(来源:作者)

调整尺寸

在某些情况下,我们可能还需要调整栅格的高度和宽度。如果我们以不适合原始尺寸的方式旋转它,就会发生这种情况。你可以在图 5 中看到这一点,其中栅格已旋转了 30 度。除非我们将宽度和高度增加 800 像素,否则它将无法适应其边界框。

#With dimensions adjustment
out_file_1 = data_file + '{}/{}_SR_B4_adjust.TIF'.format(ID, ID)
rotate_raster(file,out_file_1, 30,800,2800,adj_width=800, adj_height=800)

#Without dimensions adjustment
out_file_2 = data_file + '{}/{}_SR_B4_noadjust.TIF'.format(ID, ID)
rotate_raster(file,out_file_2, 30,800,2800)

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 5:调整前后栅格的对比(来源:作者)

在 rotate_raster 函数中,我们通过更新其元数据(dst_kwargs)来改变栅格的尺寸。除了宽度和高度,你会看到我们还更改了变换函数(dst_transform)。这将是原始变换乘以旋转trans_xtrans_y仿射矩阵。最终的变化是nodata元素。将其设置为 0 确保任何新的边界框像素将是黑色的。

可以调整旋转、偏移和尺寸调整参数,以便去掉整个边界框。然而,请记住,任何变换都会“扭曲”像素。像素值使用最近邻方法进行重采样(即Resampling.nearest)。除非用于可视化,否则最好使用原始栅格进行空间分析。

希望你喜欢这篇文章!你可以在 Mastodon | Twitter | YouTube | Newsletter 上找到我——免费注册以获取 Python SHAP 课程

## 加入 Medium 使用我的推荐链接 — Conor O’Sullivan

作为 Medium 会员,你的一部分会员费用会分配给你阅读的作者,你还可以全面访问所有故事……

conorosullyds.medium.com

参考文献

Rasterio 文档 Reprojection rasterio.readthedocs.io/en/stable/topics/reproject.html

仅在另一个 DAG 成功时运行 Airflow DAG

原文:towardsdatascience.com/run-airflow-dag-if-another-dag-succeeds-233aaa4118c1

使用 Airflow 传感器来控制不同计划下 DAG 的执行

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Giorgos Myrianthous

·发布在 Towards Data Science ·阅读时间 11 分钟·2023 年 12 月 19 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由 DALL-E2 生成

最近,我一直在尝试协调两个 Airflow DAG,使得其中一个只会在另一个 DAG(每日运行)成功的情况下按自己的小时计划运行。

在今天的教程中,我将引导你了解这个用例,并演示如何通过三种不同的方法实现所需的行为;两种使用ExternalTaskSensor,另一种使用PythonOperator的自定义方法。

使用案例:仅在每日 DAG 成功时运行每小时 DAG

现在,让我们开始处理涉及两个 Airflow DAG 的用例。

第一个 DAG,my_daily_dag,每天在 UTC 时间早上 5 点运行。

from datetime import datetime, timedelta
from pathlib import Path

from airflow.models import DAG
from airflow.operators.dummy import DummyOperator

with DAG(
    catchup=False,
    dag_id='my_daily_dag'
    start_date=datetime(2023, 7, 26),
    default_args={
        'owner': 'airflow',
        'retries': 1,
        'retry_delay': timedelta(minutes=2),
    },
    schedule_interval='0 5 * * *',
    max_active_runs=1,
) as dag:
   DummyOperator(task_id='dummy_task')

第二个 DAG,my_hourly_dag,每小时运行一次,时间在 UTC 的早上 6 点到晚上 8 点之间。

from datetime import datetime, timedelta
from pathlib import Path

from airflow.models import DAG
from airflow.operators.dummy import DummyOperator

with DAG(
    catchup=False,
    dag_id='my_daily_dag'
    start_date=datetime(2023, 7, 26),
    default_args={
        'owner': 'airflow',
        'retries': 1,
        'retry_delay': timedelta(minutes=2),
    },
    schedule_interval='0 6-20 * * *',  # At :00 every hour between 6AM-8PM
    max_active_runs=1,
) as dag:
   DummyOperator(task_id='dummy_task')

在我们的使用案例中,我们希望my_hourly_dag仅在my_daily_dag在当天成功运行的情况下执行。如果没有,则my_hourly_dag应该被跳过。这里需要提到的是,我们不想在my_daily_dag成功后立刻触发my_hourly_dag。那可以通过TriggerDagRun操作符实现。相反,我们希望两个 DAG 各自按照自己的计划运行,但在my_hourly_dag上添加一个条件。

## 如何在 Airflow DAG 中跳过任务

基于特定条件跳过 Airflow DAG 中的任务

towardsdatascience.com

在接下来的两个部分中,我们将讨论并演示如何通过几种不同的方法实现这一点。

确定两个 DAG 的执行日期

在深入实现细节之前,首先了解两个 DAG 在各自 execution_date 方面的区别非常重要。这一点至关重要,因为我们将利用这一知识来确定所需行为的实现方式。

假设今天是 12 月 13 日。每日 DAG my_daily_dagexecution_date2023–12–12 00:00,因为它涵盖了 2023–12–122023–12–13 之间的数据时间段。请记住,Airflow DAG 运行从时间段结束时开始。

与此同时,我们的每小时 my_hourly_dag DAG 具有 execution_date2023–12–13(除了午夜运行,其 execution_date 将为 2023–12–12,因为该时间段的开始为 2023–12–12 23:002023–12–13 00:00)。

使用 ExternalTaskSensor

我们的第一个选择是内置的 [ExternalTaskSensor](https://airflow.apache.org/docs/apache-airflow/stable/_api/airflow/sensors/external_task/index.html#airflow.sensors.external_task.ExternalTaskSensor) 操作符。

等待不同的 DAG、任务组或任务完成特定的逻辑日期。

默认情况下,ExternalTaskSensor 将等待外部任务成功,此时它也会成功。然而,默认情况下,如果外部任务失败,它将 不会 失败,而是继续检查状态,直到传感器超时(从而给你时间重试外部任务而无需同时清除传感器)。

Airflow 文档

我们可以在 my_hourly_dag 中使用此传感器,该传感器将基本上检查 my_daily_dag 在指定时间段内是否成功。

ExternalTaskSensor 接受 execution_deltaexecution_date_fn 之一。前者可用于指示与前一次执行的时间差。默认情况下,此值设置为与当前任务/DAG 相同的逻辑日期。后者接收一个可调用对象(即一个函数),该函数接受当前执行的逻辑日期作为第一个位置参数,并返回要查询的逻辑日期。

- execution_delta ([*datetime.timedelta*](https://docs.python.org/3/library/datetime.html#datetime.timedelta) | *None*) — 与前一次执行的时间差,默认值是与当前任务或 DAG 相同的逻辑日期。对于昨天,请使用 [positive!] datetime.timedelta(days=1)ExternalTaskSensor 可以传递 execution_deltaexecution_date_fn,但不能同时传递两者。

- execution_date_fn (*Callable* | *None*) — 接收当前执行的逻辑日期作为第一个位置参数的函数,并可选地接收上下文字典中的任何数量的关键字参数,并返回要查询的逻辑日期。ExternalTaskSensor 可以传递 execution_deltaexecution_date_fn,但不能同时传递两者。

由于两个 DAG 的运行时间表不同,传感器的默认行为对我们不起作用。在前面的部分中,我们澄清了两个 DAG 将具有不同的执行日期的原因。

因此,我们需要弄清楚如何使用 execution_deltaexecution_date_fn 来使两个执行日期对齐。

使用 ExternalTaskSensor 和 execution_delta

在我看来,最简单的方法是使用 execution_delta。我们每日 DAG 的数据间隔开始日期是“昨天的 UTC 时间上午 5 点”。由于我们知道 my_hourly_dag 每小时运行一次,因此我们可以提出一个公式来计算小时级 DAG 的间隔开始时间与每日 DAG 的间隔开始时间之间的差异。

以下将创建一个累加的差异:

  • 24 对应于两个 DAG 之间的 24 小时差异,前提是它们的运行时间表不同,如前所述。

  • 小时级 DAG 的间隔开始时间与每天 DAG 运行的时间 5 之间的差异,即每天 DAG 每天运行的小时。

24 + (hourly_dag_interval_start_hour - 5)

作为示例,考虑以下场景,当小时级 DAG 从早上 6 点开始运行(直到晚上 8 点):

在早上 6 点:

  • 小时数据间隔从上午 5 点开始(并于上午 6 点结束)

  • 每日数据间隔从昨天的上午 5 点开始。

  • execution_delta=24 + (5-5) = 24

  • 传感器将检查每日 DAG 的成功情况,其数据间隔开始日期设置为 24 小时之前。

在早上 7 点:

  • 小时数据间隔从早上 6 点开始(并于早上 7 点结束)

  • 每日数据间隔从昨天的上午 5 点开始。

  • execution_delta=24 + (6-5) = 25

  • 传感器将检查每日 DAG 的成功情况,其数据间隔开始日期设置为 25 小时之前。

等等。

那么我们该如何实现呢?我们需要面对一个问题是(在本文撰写时),execution_delta 不是一个模板字段,这意味着我们不能使用提供有用信息的 模板变量,包括 data_interval_start

因此,我们将必须手动构造小时级 DAG 的 data_interval_start。鉴于 DAG 每小时运行一次,数据间隔开始小时对应于当前小时减去一小时。

from datetime import datetime, timezone

datetime.now(timezone.utc).hour - 1

因此,execution_delta 作为参数提供给 ExternalTaskSensor 现在可以定义为:

execution_delta=timedelta(hours=24 + datetime.now(timezone.utc).hour - 1 - 5)

这是我们小时级 DAG 的完整代码,该 DAG 将在 UTC 时间早上 6 点到晚上 8 点之间每小时运行一次,前提是每日 DAG 今天已经成功。

from datetime import datetime, timedelta, timezone
from pathlib import Path

from airflow.models import DAG
from airflow.operators.dummy import DummyOperator
from airflow.sensors.external_task import ExternalTaskSensor

with DAG(
    catchup=False,
    dag_id='my_daily_dag'
    start_date=datetime(2023, 7, 26),
    default_args={
        'owner': 'airflow',
        'retries': 1,
        'retry_delay': timedelta(minutes=2),
    },
    schedule_interval='0 6-20 * * *',  # At :00 every hour between 6AM-8PM
    max_active_runs=1,
) as dag:
    sensor_task = ExternalTaskSensor(
        task_id='daily_dag_completed_successfully',
        external_dag_id='my_daily_dag',
        soft_fail=True,
        check_existence=True,
        execution_delta=timedelta(hours=24 + datetime.now(timezone.utc).hour - 1 - 5),
        poke_interval=30,
        timeout=120,
    )

    dummy_task = DummyOperator(task_id='dummy_task')

    sensor_task >> dummy_task

使用 ExternalTaskSensor 和 execution_date_fn

现在,除了 execution_delta 之外,传感器还可以配置为与 execution_date_fn 一起使用,该函数接受一个可调用对象,返回要查询的逻辑日期。

换句话说,我们需要创建一个函数,并获取每日 DAG 所需的逻辑日期,以便与传感器的条件相匹配,该条件默认会检查指定间隔的 DagRun 状态是否成功。

以下函数将获取日常 DAG 的 DagRuns,并仅在它发生在与每小时 DAG 相同的日期时返回 DagRun 的执行日期。如果未找到 DagRun(这意味着日常 DAG 在过去未执行),将引发 AirflowSkipException,以便跳过传感器任务(以及任何下游任务)。同样,如果没有找到与每小时 DAG 相同日期的日常 DAG 的 DagRun,将返回 current_logical_dt,这本质上是由 ExternalTaskSensor 检查的默认值(也是使用 execution_date_fn 参数时提供的函数定义中必须存在的参数)。

请记住,这两个 DAG 的调度不同,这意味着它们的 execution_date 不同。为了进行适当的比较并确定日常 DAG 是否在每小时 DAG 运行的同一天成功执行,我们需要从每小时 DAG 的执行日期中减去一天。请注意,我们只关心两个 DAG 之间的年份、月份和日期是否相同(在此上下文中我们不太关心时间信息)。

import logging 

from airflow.exceptions import AirflowSkipException
from airflow.models import DagRun

def get_most_recent_dag_run(current_logical_dt):
    dag_id = 'my_daily_dag'
    # Get the historical DagRuns of the daily DAG
    dag_runs = DagRun.find(dag_id=dag_id)

    # Sort DagRuns on descending order such that the first element
    # in the list, corresponds to the latest DagRun of the daily DAG
    dag_runs.sort(key=lambda x: x.execution_date, reverse=True)

    # If the daily DAG was not executed ever before, simply raise an 
    # exception to skip. 
    if not dag_runs:
        logging.info(f'No DAG runs found for {dag_id}. Skipping..')
        raise AirflowSkipException

    # Get the latest DagRun of the daily DAG
    latest_daily_dag_run = dag_runs[0]

    # Subtract one day from hourly's DAG current execution_date in order to 
    # align with the daily DAG's scedule
    current_logical_dt_yesterday = current_logical_dt.subtract(hours=24)

    # if year/month/day of daily's DAG execution_date and hourly's DAG execution_date
    # (minus one day) are the same, it means the daily DAG was executed today. 
    # We therefore return the execution_date of the latest daily DagRun. 
    # It's state (i.e. if successful) will be handled by the sensor and the configuration 
    # we provide to it. 
    if (
        current_logical_dt_yesterday.day == latest_daily_dag_run.execution_date.day
        and current_logical_dt_yesterday.month == latest_daily_dag_run.execution_date.month
        and current_logical_dt_yesterday.year == latest_daily_dag_run.execution_date.year
    ):
        logging.info(f'DAG run was found for {dag_id} today.')
        return latest_daily_dag_run.execution_date

    # Alternatively, return the current execution_date of the hourly DAG
    # This is the default value the sensor would otherwise use, and essentially
    # it means that the sensor won't be triggered given that the intervals between 
    # the daily DAG and the sensor won't align. 
    return current_logical_dt

以下是我们使用 execution_function_fnExternalTaskSensor 的每小时 DAG 的完整代码。

import logging
from datetime import datetime, timedelta
from pathlib import Path

from airflow.exceptions import AirflowSkipException
from airflow.models import DAG, DagRun
from airflow.operators.dummy import DummyOperator
from airflow.sensors.external_task import ExternalTaskSensor

def get_most_recent_dag_run(current_logical_dt):
    dag_id = 'my_daily_dag'
    # Get the historical DagRuns of the daily DAG
    dag_runs = DagRun.find(dag_id=dag_id)

    # Sort DagRuns on descending order such that the first element
    # in the list, corresponds to the latest DagRun of the daily DAG
    dag_runs.sort(key=lambda x: x.execution_date, reverse=True)

    # If the daily DAG was not executed ever before, simply raise an 
    # exception to skip. 
    if not dag_runs:
        logging.info(f'No DAG runs found for {dag_id}. Skipping..')
        raise AirflowSkipException

    # Get the latest DagRun of the daily DAG
    latest_daily_dag_run = dag_runs[0]

    # Subtract one day from hourly DAG's current execution_date in order to 
    # align with the daily DAG's scedule
    current_logical_dt_yesterday = current_logical_dt.subtract(hours=24)

    # if year/month/day of daily DAG's execution_date and hourly DAG's execution_date
    # (minus one day) are the same, it means the daily DAG was executed today. 
    # We therefore return the execution_date of the latest daily DagRun. 
    # It's state (i.e. if successful) will be handled by the sensor and the configuration 
    # we provide to it. 
    if (
        current_logical_dt_yesterday.day == latest_daily_dag_run.execution_date.day
        and current_logical_dt_yesterday.month == latest_daily_dag_run.execution_date.month
        and current_logical_dt_yesterday.year == latest_daily_dag_run.execution_date.year
    ):
        logging.info(f'DAG run was found for {dag_id} today.')
        return latest_daily_dag_run.execution_date

    # Alternatively, return the current execution_date of the hourly DAG
    # This is the default value the sensor would otherwise use, and essentially
    # it means that the sensor won't be triggered given that the intervals between 
    # the daily DAG and the sensor won't align. 
    return current_logical_dt

with DAG(
    catchup=False,
    dag_id='my_daily_dag'
    start_date=datetime(2023, 7, 26),
    default_args={
        'owner': 'airflow',
        'retries': 1,
        'retry_delay': timedelta(minutes=2),
    },
    schedule_interval='0 6-20 * * *',  # At :00 every hour between 6AM-8PM
    max_active_runs=1,
) as dag:
    sensor_task = ExternalTaskSensor(
        task_id='daily_dag_completed_successfully',
        external_dag_id='my_daily_dag',
        soft_fail=True,
        check_existence=True,
        execution_function_fn=get_most_recent_dag_run,
        poke_interval=30,
        timeout=120,
    )

    dummy_task = DummyOperator(task_id='dummy_task')

    sensor_task >> dummy_task

使用 PythonOperator

第二种方法涉及一个更为定制的解决方案。更具体地说,我们可以以编程方式找到我们日常 DAG 的最新成功 DagRun,并相应地处理操作符的行为。换句话说,如果日常 DAG 的最新成功 DagRun 与我们每小时 DAG 的执行日期不一致,则该任务将被跳过(以及下游任务)。

因此,我们可以编写一个函数——类似于我们在前一节中编写的,并作为 ExternalTaskSensorexecution_date_fn 参数使用。

更具体地说,我们需要获取日常 DAG 的 DagRuns,确定今天是否有人成功完成(即每小时 DAG 运行的同一天)。如果没有找到,我们将引发 AirflowSkipException,以便跳过每小时 DAG 的执行。在这种情况下,PythonOperator 支持模板变量,因此我们将充分利用这一点。

这就是我们的函数的样子:

from airflow.exceptions import AirflowSkipException
from airflow.models import DagRun
from airflow.utils.state import DagRunState

def check_daily_dag_success_today(**kwargs):
    dag_id = 'my_daily_dag'
    # Get the historical DagRuns of the daily DAG
    dag_runs = DagRun.find(dag_id=dag_id)

    # Sort DagRuns on descending order such that the first element
    # in the list, corresponds to the latest DagRun of the daily DAG
    dag_runs.sort(key=lambda x: x.execution_date, reverse=True)

    # If the daily DAG was not executed ever before, simply raise an
    # exception to skip.
    if not dag_runs:
        logging.info(f'No DAG runs found for {dag_id}. Skipping..')
        raise AirflowSkipException

    # Get the latest DagRun of the daily DAG
    latest_daily_dag_run = dag_runs[0]

    # Subtract one day from hourly DAG's current execution_date in order to
    # align with the daily DAG's schedule
    data_interval_start = kwargs['data_interval_start']
    data_interval_start_yesterday = data_interval_start.subtract(hours=24)

    # Check the intervals and the success of the daily DAg's DagRun. If conditions are not met,
    # DAG run should be skipped.
    if not (
        latest_daily_dag_run.state == DagRunState.SUCCESS
        and data_interval_start_yesterday.day == latest_daily_dag_run.execution_date.day
        and data_interval_start_yesterday.month == latest_daily_dag_run.execution_date.month
        and data_interval_start_yesterday.year == latest_daily_dag_run.execution_date.year
    ):
        logging.info(f'No successful DAG run was found for {dag_id} today. Skipping..')
        raise AirflowSkipException

    logging.info(f'Successful DAG run was found for {dag_id} today.')

以下是 my_hourly_dag DAG 的完整代码,使用 PythonOperator 来检查 my_daily_dag 的状态:

from datetime import datetime, timedelta
from pathlib import Path

from airflow.exceptions import AirflowSkipException
from airflow.models import DAG, DagRun
from airflow.operators.dummy import DummyOperator
from airflow.operators.python import PythonOperator

def check_daily_dag_success_today(**kwargs):
    dag_id = 'my_daily_dag'
    # Get the historical DagRuns of the daily DAG
    dag_runs = DagRun.find(dag_id=dag_id)

    # Sort DagRuns on descending order such that the first element
    # in the list, corresponds to the latest DagRun of the daily DAG
    dag_runs.sort(key=lambda x: x.execution_date, reverse=True)

    # If the daily DAG was not executed ever before, simply raise an
    # exception to skip.
    if not dag_runs:
        logging.info(f'No DAG runs found for {dag_id}. Skipping..')
        raise AirflowSkipException

    # Get the latest DagRun of the daily DAG
    latest_daily_dag_run = dag_runs[0]

    # Subtract one day from hourly DAG's current execution_date in order to
    # align with the daily DAG's schedule
    data_interval_start = kwargs['data_interval_start']
    data_interval_start_yesterday = data_interval_start.subtract(hours=24)

    # Check the intervals and the success of the daily DAg's DagRun. If conditions are not met,
    # DAG run should be skipped.
    if not (
        latest_daily_dag_run.state == DagRunState.SUCCESS
        and data_interval_start_yesterday.day == latest_daily_dag_run.execution_date.day
        and data_interval_start_yesterday.month == latest_daily_dag_run.execution_date.month
        and data_interval_start_yesterday.year == latest_daily_dag_run.execution_date.year
    ):
        logging.info(f'No successful DAG run was found for {dag_id} today. Skipping..')
        raise AirflowSkipException

    logging.info(f'Successful DAG run was found for {dag_id} today.')

with DAG(
    catchup=False,
    dag_id='my_daily_dag'
    start_date=datetime(2023, 7, 26),
    default_args={
        'owner': 'airflow',
        'retries': 1,
        'retry_delay': timedelta(minutes=2),
    },
    schedule_interval='0 6-20 * * *',  # At :00 every hour between 6AM-8PM
    max_active_runs=1,
) as dag:
   check_task = PythonOperator(
       task_id='check_daily_dag', 
       python_callable=check_daily_dag_success_today,
   )
   dummy_task = DummyOperator(task_id='dummy_task')

   check_task >> dummy_task

最后的想法…

在今天的教程中,我们讨论了如何处理使用 Airflow 时不同 DAG 之间的依赖关系。更具体地说,我们讨论了如何在一个 DAG 以每小时执行的情况下,仅在另一个按日计划的 DAG 在当天成功执行后运行它。

演示了三种不同的方法。根据你的用例的复杂性,你应该选择最合适且代码更优雅的方法。

订阅数据管道,这是一个专注于数据工程的新闻通讯

使用 PHP 在你的网站上运行 ChatGPT 和 GPT 模型

原文:towardsdatascience.com/run-chatgpt-and-gpt-models-on-your-website-with-php-517ea20266d7

一种非常简单的解决方案,将 GPT 模型的 AI 交付给你的用户

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Benjamin Marie

·发表于Towards Data Science ·12 分钟阅读·2023 年 5 月 2 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片来自Pixabay

GPT 模型可以提升网站和网络应用的用户体验。它们可以翻译、总结、回答问题,还能完成许多其他任务。

将所有这些功能集成到你的在线服务中,通过 OpenAI API 相当简单。目前,OpenAI 仅提供对 Python 和 NodeJS 绑定的官方支持。

许多第三方绑定已经由社区开发,以便在其他编程语言中进行部署。

在这篇文章中,我将展示如何将你的网站连接到 OpenAI 的 API。我还会解释如何解析和解读 API 返回的结果。

我将只涵盖 GPT 模型,但你可以使用相同的流程来处理 DALL-E 和 Whisper 模型。

先决条件

GPT 模型

你不需要熟悉 GPT 模型就可以理解和实现这篇文章,但我仍建议你阅读我关于 GPT 模型的简单介绍:

## GPT 模型的简单介绍

欢迎来到令牌生成器的新世界

towardsdatascience.com

PHP

你只需了解 PHP 的基础知识。

我将使用一个可以通过 Composer 安装的 PHP 库(所以你需要 Composer),并且要求至少 PHP 8.1。注意:你无法在旧版本的 PHP 上安装该库。

OpenAI 账户

你需要一个 OpenAI 账户。如果你没有,请参考我的指南,了解如何创建和管理 OpenAI 账户:

## OpenAI 账户:文档、游乐场和模型的超参数

使用 OpenAI API 所需了解的全部内容

medium.com

如果你想运行示例,你需要在账户中创建一个 API 密钥并保留几分钱的积分。

OpenAI PHP

我们将使用由 OpenAI PHP(MIT 许可证)维护的客户端与 OpenAI API 进行通信。

其他 PHP 库也能做到这一点,但我选择这个库是因为以下原因:

  • 它由 OpenAI 列出,合理保证了这个库可以信任。

  • 在所有 PHP 绑定 OpenAI API 的库中,它在 GitHub 上拥有最多的 stars。

  • 它易于安装和使用。

  • 它会定期更新,以考虑 API 的变化和新的 OpenAI 模型。

要安装它,打开终端,进入你的网站/应用程序父目录,并按如下方式运行 composer:

composer require openai-php/client

如果没有任何错误,你可以开始使用 PHP 的 OpenAI API。

在 PHP 中设置你的 API 密钥

你必须在你的 OpenAI 账户中创建一个 API 密钥。

出于安全原因,我建议为每个你希望连接到 API 的 Web 应用程序创建一个新的 API 密钥。

如果你的某个产品发生了安全漏洞,你可以只销毁 OpenAI 账户中的密钥,而不会影响其他应用。

你不应该直接在 PHP 文件中写入这个密钥,而是使用操作系统环境变量来存储它。例如,在 Ubuntu/Debian 上,运行:

export MY_OPENAI_KEY={your key}
#Replace {your key} by your OpenAI API key

在你的 PHP 脚本中,你可以使用以下方式获取这个环境变量的值:

<?php
$yourApiKey = getenv('MY_OPENAI_KEY');
....//remainder of your script
?>

如果你无法访问操作系统环境变量,最简单的替代方案是在一个单独的文件中定义一个 PHP 常量,并在所有使用 API 的 PHP 脚本中引入该文件。

例如,创建一个文件“key.php”,最好不要放在你网站的主目录中,并写入:

<?php
define('MY_OPENAI_KEY'. '{your key}');
?>

然后在所有将使用 API 的文件顶部写入以下内容:

<?php
require_once("path/to/key.php"); //the path to your key.php file
$yourApiKey = MY_OPENAI_KEY;
....//remainder of your script
?>

使用 GPT 模型的 PHP 补全任务

OpenAI PHP 客户端支持通过 OpenAI API 访问的所有任务。在这篇文章中,我将重点讨论使用 GPT 模型的“补全任务”。

补全任务是指我们 提示 模型一个文本,API 通过在此提示后添加文本来作出回应。

API 提供了两种不同类型的补全任务:

  • standard: 提示 GPT-3 或 GPT-4 模型并生成跟随该提示的 tokens

  • chat: 给定一个描述对话历史的消息列表,模型将返回一个响应。因此,这里的提示是一组包含关于是模型还是用户写的信息的消息。

我将演示如何使用 OpenAI PHP 客户端来完成这两种类型的任务。

使用 GPT-3 完成任务

首先,我们需要一个目标。我们希望 GPT 模型完成什么?

对于这篇文章,我们可以设定目标是“将”文本翻译成表情符号。

使用 GPT 模型时最关键的步骤之一是找到适合我们任务的良好提示。如果你的提示不好,模型的回答也不会很出色。

什么是好的提示?

提示工程是一个非常活跃的研究领域。我不会在这里讨论这个话题,但我计划在我的下一篇文章中进行探讨。

对于我们的任务,受到之前使用大型语言模型的机器翻译工作的启发,我提出了以下提示,取得了相当不错的结果:

将以下文本翻译成表情符号:

[TXT]

其中 [TXT] 将被替换为要翻译成表情符号的文本。

这个提示的优点是简短。使用它不会花费太多。

例如,我们将尝试将以下文本翻译成表情符号:

我想要一个不加洋葱的汉堡。

所以我们的提示变成了:

将以下文本翻译成表情符号:

我想要一个不加洋葱的汉堡。

使用 OpenAI PHP 客户端,我们可以通过以下代码实现:

<?php
//This line is necessary to load the PHP client installed by Composer
require_once('../vendor/autoload.php');

//Change the next line to $yourApiKey = MY_OPENAI_KEY; if you didn't use an environment variable and set your key in a separate file
$yourApiKey = getenv('MY_OPENAI_KEY');

//Create a client object
$client = OpenAI::client($yourApiKey);

//The $prompt variable stores our entire prompt
$prompt = "Translate the following text into emoji:

I would like an hamburger without onions.
";

//We send our prompt along with parameters to the API
//It creates a completion task
$result = $client->completions()->create([
    'model' => 'text-davinci-003',
    'prompt' => $prompt
]);

//After a few seconds the response will be stored in $results
//We can print the text answered by GPT
echo $result['choices'][0]['text']; 

?>

在这个代码中,我假设你在你的网站的根目录下。

它应该打印一系列表情符号。我得到了这个:

🍔🚫🧅

你可能会得到不同的序列,因为 GPT 模型是“非确定性的”。

我使用了“text-davinci-003” GPT 模型,这是最强大的 GPT-3 模型。

如果你的任务非常简单,你可以使用更便宜的 GPT 模型。例如,我们可以尝试用“ada”替换“text-davinci-003”模型。

'model' => 'ada',

我得到了以下回答:

例如,输入 这是文本 “Looking For a hamburger”

是的,这相当糟糕。这个回应中没有任何表情符号。选择正确的模型是你在将 OpenAI API 集成到产品中时必须做出的最关键的选择。

  • 如果你选择一个旧的或小型的模型,结果会很低质量,并且可能无法完成请求的任务。

  • 如果你选择一个更大的模型,你可能会得到最好的结果,但成本会更高。

你需要尝试多个模型,以确定哪个是最适合你目标的选项。作为起点,OpenAI 提供了一些使用建议和可用模型列表

除了模型名称和提示,完成任务还可以接受更多参数。它们都在API 文档中描述。

我们可以指定例如响应中的最大标记数,如下所示:

$result = $client->completions()->create([
    'model' => 'text-davinci-003',
    'prompt' => $prompt,
    'max_tokens' => 2
]);

这不应该生成任何内容,只有 1 行空白。为什么?

1 个表情符号由 text-davinci-003 中的 3 个标记组成。所以如果我们将‘max_tokens’设置为 2,模型甚至无法生成 1 个表情符号。

我怎么知道一个表情符号由 3 个标记组成?

我在我的 OpenAI 用户账户的 playground 中简单检查了一下。例如,如果你在那里输入“🍔🚫🧅”,模型会计算出 9 个 tokens。

此外,GPT 模型在 emoji 序列前生成一个换行符。它算作一个额外的 token。总的来说,GPT 给了我 10 个 tokens 的回答。

请注意,“$result”变量包含所有这些信息。我们将在下面的下一部分中查看它。

但在此之前,让我们看看聊天完成任务。

聊天完成任务

聊天完成任务与我们使用 GPT-3 时略有不同。聊天任务由 gpt-3.5-turbo 提供支持,它也为 ChatGPT 提供支持。

在 gpt-3.5-turbo 中,“prompt” 参数被“messages” 替代。

从技术上讲,“messages” 是包含两个必需键和一个可选键的关联数组,如下所示:

  • role (required): 可以是“system”,“assistant”或“user”。在我撰写本文时,OpenAI 文档中几乎忽略了“system”。剩下的是“assistant”即模型,以及“user”即人类。

  • content (required): 这是我们放置提示或提示的上下文的地方,例如聊天历史。

  • name (optional): 如果你想给消息的作者指定一个特定的名字。

消息的长度和数量几乎是无限的。这样,gpt-3.5-turbo 可以接受非常长的聊天历史作为输入。

聊天完成可以执行与标准 GPT-3 相似的任务。在文档中,OpenAI 写了如下内容:

因为 gpt-3.5-turbo 的能力与 text-davinci-003 相似,但每个 token 的价格仅为 10%,所以我们推荐在大多数用例中使用 gpt-3.5-turbo

让我们用翻译文本为 emoji 的任务来检查它。

我们只需进行少量修改:

<?php
//This line is necessary to load the PHP client installed by Composer
require_once('../vendor/autoload.php');

//Change the next line to $yourApiKey = MY_OPENAI_KEY; if you set your key in a separate file
$yourApiKey = getenv('MY_OPENAI_KEY');

//Create a client object
$client = OpenAI::client($yourApiKey);

//The $prompt variable stores our entire prompt
$prompt = "Translate the following text into emoji:

I would like an hamburger without onions.
";

//We send our prompt along with parameters to the API
//It creates a chat completion task
$result = $client->chat()->create([
    'model' => 'gpt-3.5-turbo',
    'messages' => [
        ['role' => 'user', 'content' => $prompt],
    ],
]);

//After a few seconds the respone will be store in results
//We can print the text answer by GPT
echo $result['choices'][0]['message']['content']; 

?>

我获得了与 text-davinci-003 相同的答案,“🍔🚫🧅”,但价格仅为 text-davinci-003 的 10%。

现在你知道如何在 PHP 中与 OpenAI API 通信,我们可以更仔细地查看 API 返回的内容。正如我们将看到的,响应中包含有用的数据,我们可以用来监控 API 成本、跟踪用户活动(例如标记禁止的行为)等。

使用 PHP 解读 OpenAI API 响应

我们可以这样制作“$result”变量的可打印版本:

print_r($result->toArray());

对于聊天完成任务,它将打印出如下内容:

Array
(
    [id] => chatcmpl-7AJFw****
    [object] => chat.completion
    [created] => 1682691656
    [model] => gpt-3.5-turbo-0301
    [choices] => Array
        (
            [0] => Array
                (
                    [index] => 0
                    [message] => Array
                        (
                            [role] => assistant
                            [content] => 🍔🚫🧅
                        )

                    [finish_reason] => stop
                )

        )

    [usage] => Array
        (
            [prompt_tokens] => 23
            [completion_tokens] => 9
            [total_tokens] => 32
        )

注意:我手动遮蔽了部分“id”。

我们有以下条目:

  • id: OpenAI 为响应分配的唯一 ID。这些信息可以帮助跟踪 API 和用户之间的交互。

  • object: 执行的任务类型。

  • created: 响应创建的时间戳。

  • model: 用于生成响应的模型。

  • choices: 默认情况下,你将仅获得一个聊天完成任务的消息,除非你在调用 API 时更改“n”选项。

  • index: 从 0 开始的消息索引。

  • message: 关于生成的消息的信息。

  • role: 消息发送者的角色。

  • content: 消息本身。

  • finish_reason:API 停止生成消息的原因。默认情况下,它将是“stop”,即模型在没有任何约束的情况下停止生成。如果你在调用 API 时指定了“stop”参数,则可能会发生变化。然后,模型会在生成了你在“stop”中提到的一个标记后停止生成。

  • usage:有关令牌长度的信息。它可以用于监控 API 成本。

  • prompt_tokens:你提示中的令牌数量。

  • completion_tokens:API 生成的消息中的令牌数量。

  • total_tokens: “prompt_tokens”和“completion_tokens”的总和。

最重要的字段是“choices”,因为这是你将要交付给用户的内容,以及“usage”,因为这是唯一能够告诉你生成这个答案花费了多少的指标。

要知道 API 调用的确切成本,你必须将“total_tokens”的值乘以每个令牌的模型成本。注意 OpenAI 显示的是 1,000 个令牌的价格,因此你需要将这个数字除以 1,000 来获得每个令牌的价格。

例如,如果我们使用每 1,000 个令牌花费 $0.002 的模型,而“total_tokens”为 32,我们可以按如下方式计算总成本:

0.002 / 1000 * 32 = 0.000064

这个 API 调用将花费 $0.000064。

标准 GPT-3 完成的响应字段与聊天完成任务的字段几乎相同。

唯一显著的区别是,“text.completion”任务还可以返回 t 个最可能的令牌的日志概率。你可以在调用 API 时使用“logprobs”参数来指示“t”。t 的最大值是 5。注意:OpenAI 的 API 参考文档表示,如果你的应用需要更大的值,你可以手动请求 OpenAI。

在网页应用程序/网站中集成的下一步是什么?

我们已经学会了如何用 PHP 与 OpenAI API 通信。你的在线服务现在可以利用 GPT 模型的全部功能。

下一步将是实现前端。你不需要为此做过于复杂的事情。一个简单的 AJAX 脚本,例如使用 jQuery,就足够异步地从执行 API 调用的 PHP 脚本中获取响应。

它可以简单到这样:

$.ajax({  
            type:"POST",  
            url:"call.php",  
            data:{ prompt: my_prompt //my_prompt stores the prompt
            },
            success:function(data){  
            data = $.parseJSON(data);
            $('#my_GPT_response').html(data["choices"][0]["message"]["content"]);
            }  
        }); 

这将把聊天完成的内容打印在一个 HTML 对象中,该对象的 id 属性设置为“my_GPT_response”。

你的 PHP 脚本必须接收“prompt”作为 $_POST 变量,并且 API 回答应该编码为 JSON 对象,如下所示:

<?php
//This line is necessary to load the PHP client installed by Composer
require_once('../vendor/autoload.php');

//At least check that the prompt is sent
//Of course you should also check the content of the variable according to what you want to do with it
if (isset($_POST['prompt'])){
  //Change the next line to $yourApiKey = MY_OPENAI_KEY; if you set your key in a separate file
  $yourApiKey = getenv('MY_OPENAI_KEY');

  //Create a client object
  $client = OpenAI::client($yourApiKey);

  //The $prompt variable stores our entire prompt
  $prompt = "Translate the following text into emoji:

  ".$_POST['prompt']."
  ";

  //We send our prompt along with parameters to the API
  //It creates a chat completion task
  $result = $client->chat()->create([
      'model' => 'gpt-3.5-turbo',
      'messages' => [
          ['role' => 'user', 'content' => $prompt],
      ],
  ]);
  $result = $response->toArray();
  echo json_encode($result);
}

?>

总结这篇文章,我应该再次提到,你必须始终检查你发送给 API 的内容,以确保你没有违反 OpenAI 的政策和使用条款。

你可以利用 审查模型,这是 OpenAI 提供的免费服务,可以在将内容发送到 GPT 模型之前标记不安全的内容。

重要的是检查用户的年龄。OpenAI 的使用条款禁止 13 岁以下的儿童使用其服务,而 18 岁以下的儿童只能在成人监督下使用这些服务。

如果你喜欢这篇文章并且对接下来的文章感兴趣,支持我工作的最佳方式是通过这个链接成为 Medium 会员:

[## 通过我的推荐链接加入 Medium - 本杰明·玛丽

加入我们的 AI 社区,获取前沿研究成果。本博客旨在揭示最近在 AI 领域的进展……

medium.com](https://medium.com/@bnjmn_marie/membership?source=post_page-----517ea20266d7--------------------------------)

如果你已经是会员并希望支持这项工作, 请在 Medium 上关注我

在 Jupyter Notebook 中与 ChatGPT 运行交互式会话

原文:towardsdatascience.com/run-interactive-sessions-with-chatgpt-in-jupyter-notebook-87e00f2ee461

使用 LangChain 和 IPyWidgets 在 Jupyter Notebook 中与 ChatGPT 进行关于自定义文档的对话

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Konstantin Rink

·发布于 Towards Data Science ·阅读时间 6 分钟·2023 年 5 月 4 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

原始照片由 Charles Etoroma 提供,来自 Unsplash

2023 年 3 月,OpenAI 发布了其 API,供开发者访问 ChatGPT 和 Whisper 模型。从那时起,开发者可以通过 API 将这些服务和模型集成到他们的应用程序和产品中。许多精彩的教程随后发布了如何使用其 API 结合 Streamlit 或 Streamlit Chat 创建自己的 ChatGPT 聊天 web 应用程序。

本文提出了一种 更轻量级的方法。无需运行或托管 Streamlit 服务器或使用 Docker 容器,所有工作都在 Jupyter Notebook 中完成

在本文中,你将学习 如何使用 OpenAI 的大型语言模型(LLM)ChatGPT 在 Jupyter Notebook 中运行关于自定义文档的交互式会话,方法是使用 LangChainIPyWidgets

最终结果将如下所示:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 1. 最终结果的演示(图片来自作者)。

以下章节将分别解释代码的每一部分。

💾 这里 你可以找到完整的代码 Notebook。

先决条件

在我们开始与 ChatGPT 的对话之前,需要先完成一些准备工作。

OpenAI API 密钥 🔑

由于我们想要使用 ChatGPT,首先需要一个有效的 OpenAI API 密钥。所需的密钥可以在 此链接 中创建,然后点击

+ 创建新的密钥 按钮。

OpenAI 提供了一个免费的试用期,之后才收费。在我看来,价格非常公平,考虑到在许多情况下,托管你自己的 LLM 更为昂贵。

安装 OpenAI 包 📦

一旦我们拥有密钥,我们还需要通过运行以下命令来安装官方 OpenAI 包:

pip install openai

安装 LangChain 包 🦜️🔗

LangChain 是一个相对较新的框架,建立在 ChatGPT 等 LLMs 之上。它的目标是将不同的组件链式连接在一起,以创建更高级的使用案例,例如特定文档上的问答或聊天机器人。

要安装它,请运行以下命令:

pip install langchain

安装或更新 Jupyter Widgets 🪐

如果你使用 Jupyter Notebook 或 Jupyter Lab,ipywidgets 应该已经安装。然而,可能你正在使用旧版本的包。本文使用的是(最新)版本 8.0.5

要安装或更新 ipywidgets,请运行下面的命令:

pip install -U ipywidgets

安装成功后,重启你的 Jupyter Notebook/Lab。

题外话:…我知道… 木星没有光环 — 这是土星 🪐 😛

数据 📑

正如上面承诺的,我们不仅会创建一个与 ChatGPT 的互动会话,还会将自己的文档发送给 ChatGPT,然后询问有关这些文档的问题。我们用作示例的文档是关于 硅谷银行倒闭 的 Wikipedia 文章(Wikipedia 贡献者,CC BY-SA 3.0

注意:我使用 wikipedia package 下载了提到的文章作为文本文件。当然,你也可以使用任何你喜欢的文本文件或 DocumentLoaders。

项目结构

我们的项目有以下文件夹结构

documents/
|- Collapse_of_Silicon_Valley_Bank.txt
images/
|- bear_avatar.png
|- cat_avatar.png
|- loading.gif
InteractiveSession.ipynb

将 ChatGPT 与 LangChain 集成

现在我们已经拥有了所需的所有包和工具,我们可以着手将 ChatGPT 与 LangChain 结合使用。正如上面提到的,LangChain 具有许多有用的功能来加载文档和开始与 ChatGPT 的对话会话。

下面的代码展示了我们稍后将与 Jupyter Widgets 结合的逻辑。

line 10 中,我们必须设置我们的 OpenAI API 密钥。Lines 12-13 从指定的 /documents 路径加载所有文本文件(在本例中只有一个)。

Chromalines 15-16)是一个内存中的嵌入数据库,包含我们以 OpenAIEmbeddings 形式的文本文档。

在我们初始化 ChatGPT 的 line 18 后,我们创建了一个 ConversationalRetrievalChainline 21)。要开始与 ChatGPT 对话,我们需要指定一个问题和聊天记录(lines 24–29 和 line 31),以便它记住之前的对话,例如,当我们在后续问题中引用之前的答案时。

注意:如果你在选择 Jupyter Notebook 和 Jupyter Lab 之间犹豫,请选择后者。使用 Jupyter Lab,你有更多的选项来调试代码(即,日志控制台)。

将其与交互设计结合起来。

根据上述逻辑,我们已经可以开始对话了。然而,这将不会是互动式的。对于每个新问题,我们都需要创建一个新单元格,包含lines 24–29 and line 31中的代码。

为了使我们的对话具有互动性,我们将使用 Jupyter 小部件、CSS,并可以选择使用两个头像图像和一个加载动画 gif。

下面的代码片段展示了我们需要首先导入的库。

from datetime import datetime
from IPython.display import HTML, display
from ipywidgets import widgets

导入后,我们将以下代码添加到一个新单元格中。这是必要的,因为我们使用了%%html单元格魔法。

以下代码展示了如何将上述逻辑(将 ChatGPT 与 LangChain 集成)与 HTML 代码结合起来。

为了使我们的会话具有互动性,我们需要创建一个由 Jupyter Text Widget(我们的输入字段)更改触发的方法。

我们必须将text.continuous_update设置为False。否则,我们定义的方法将在每个字符输入时被触发。

最后但同样重要的是,我们定义了输入字段、输出和加载条的外观或布局。

我们在这里使用flex_flow="column-reverse",以便始终将滚动条置于底部,这样我们就不必为每条新消息向下滚动。

就这些!现在我们可以开始一个互动会话了!

完整代码可以在这里找到。

演示会话

如上所述,我为这个演示使用了一篇关于硅谷银行倒闭的文章。由于此事件发生在 ChatGPT 最近的更新或刷新之后,它无法了解这次倒闭。

让我们通过使用官方控制台(图 2)来找出答案:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 2. ChatGPT 的(ChatGPT Mar 23 版本)回答(图片由作者提供)。

我们可以看到当前版本(3 月 23 日)尚未了解倒闭情况。让我们开始我们的互动环节:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 3. 互动输出(作者提供的图片)。

BÄM!通过我们提供的信息,ChatGPT 能够提供关于倒闭的更详细回答。

结论

如果你在寻找一个轻量级的替代方案来创建你自己的 ChatGPT 聊天网页应用,那么这个方法可能是一个相当不错的选择。所有需要的工作都可以在 Jupyter Notebook 中通过 LangChain、IPyWidgets 和 HTML/CSS 完成。由于 LangChain 相对较新,我预计不久会有许多更新和可能的代码更改。

来源

维基百科贡献者。(2023 年 5 月 1 日)。硅谷银行倒闭。在维基百科,自由百科全书中。检索于 2023 年 5 月 1 日 18:56,来自en.wikipedia.org/w/index.php?title=Collapse_of_Silicon_Valley_Bank&oldid=1152681730

在你的 GPU 上运行 Llama 2 70B 使用 ExLlamaV2

原文:towardsdatascience.com/run-llama-2-70b-on-your-gpu-with-exllamav2-588141a88598?source=collection_archive---------0-----------------------#2023-09-29

找到适合你硬件的最佳混合精度量化方法

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Benjamin Marie

·

关注 发表在 Towards Data Science ·7 min read·2023 年 9 月 29 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由作者提供 — 制作自 Pixabay

Llama 2 系列中最大且最优秀的模型拥有 700 亿个参数。一个 fp16 参数占用 2 字节。加载 Llama 2 70B 需要 140 GB 的内存(70 亿 * 2 字节)。

在之前的一篇文章中,我展示了如何通过量化技术在 100 GB 的 CPU RAM 上运行一个 1800 亿参数的模型 Falcon 180B。

[## Falcon 180B: 能在你的电脑上运行吗?

是的,如果你有足够的 CPU RAM

newsletter.kaitchup.com

Llama 2 70B 比 Falcon 180B 小得多。

它能完全适配到单个消费级 GPU 中吗?

这具有挑战性。一款高端消费级 GPU,如 NVIDIA RTX 3090 或 4090,具有 24 GB 的 VRAM。如果我们将 Llama 2 70B 量化到 4 位精度,我们仍然需要 35 GB 的内存(70 亿 * 0.5 字节)。该模型可以适配到 2 个消费级 GPU 中。

通过 GPTQ 量化,我们可以进一步将精度降低到 3 位,而不会在模型性能上损失太多。一个 3 位参数在内存中占用 0.375 字节。Llama 2 70B 量化到 3 位仍将占用 26.25 GB。这无法适配到一个消费级 GPU 中。

[## 运行 Llama 3.1 的最佳量化方法

基准测试 AQLM、bitsandbytes、AWQ、GPTQ 和 AutoRound 的推理吞吐量、准确性和内存消耗

newsletter.kaitchup.com

我们可以将精度降低到 2 位。这将适配 24 GB 的 VRAM,但根据之前对 2 位量化的研究,模型的性能也会显著下降。

为了避免在模型性能上损失过多,我们可以将模型的重要层或部分量化为更高的精度,将不太重要的部分量化为更低的精度。模型将以混合精度进行量化。

ExLlamaV2(MIT 许可证)实现了混合精度量化。

在这篇文章中,我展示了如何使用 ExLlamaV2 进行混合精度量化模型。更具体地,我们将看到如何将 Llama 2 70B 量化到低于 3 位的平均精度

我实现了一个演示并基准测试 Llama 2 混合精度量化的笔记本。可以在这里获取:

获取笔记本(#18)

Llama 2 的混合精度量化

要求

要进行混合精度量化并运行模型,我们需要安装 ExLlamaV2。

从源代码安装:

git clone https://github.com/turboderp/exllamav2
cd exllamav2
pip install -r requirements.txt

我们的目标是让模型在消费级 GPU 上运行。

  • Llama 2 70B:我们目标是 24 GB 的 VRAM。NVIDIA RTX3090/4090 GPU 将适用。如果你使用 Google Colab,则不能在免费的 Google Colab 上运行。只有 Google Colab PRO 的 A100 具有足够的 VRAM。

  • Llama 2 13B:我们目标是 12 GB 的 VRAM。许多至少有 12 GB VRAM 的 GPU 可用,如 RTX3060/3080/4060/4080 等。它可以在免费的 Google Colab 上使用 T4 GPU 运行。

如何使用 ExLlamaV2 进行混合精度量化

ExLlamaV2 使用的量化算法类似于 GPTQ。但不同于选择一种精度类型,ExLlamaV2 尝试了每一层的不同精度类型,同时测量量化误差。所有尝试和相关的误差率都被保存。然后,给定用户提供的目标精度,ExLlamaV2 算法会为每层模块选择能平均达到目标精度的量化精度,且误差率最低。

在量化过程中,ExLlamaV2 输出所有尝试:

Llama 2 13B 第 10 层 up_proj 模块的量化尝试

-- Linear: model.layers.10.mlp.up_proj
 -- 0.05:3b/0.95:2b 32g s4         2.18 bpw    rfn_error: 0.21867
 -- 0.25:3b/0.75:2b 32g s4         2.38 bpw    rfn_error: 0.20617
 -- 0.25:4b/0.75:2b 32g s4         2.63 bpw    rfn_error: 0.20230
 -- 0.1:4b/0.4:3b/0.5:2b 32g s4    2.73 bpw    rfn_error: 0.18449
 -- 0.1:4b/0.9:3b 32g s4           3.23 bpw    rfn_error: 0.10229
 -- 0.2:6b/0.8:3b 32g s4           3.73 bpw    rfn_error: 0.09791
 -- 1.0:3b 128g s4                 3.03 bpw    rfn_error: 0.11354
 -- 1.0:3b 32g s4                  3.13 bpw    rfn_error: 0.10491
 -- 0.05:4b/0.95:3b 32g s4         3.18 bpw    rfn_error: 0.10363
 -- 0.4:4b/0.6:3b 32g s4           3.53 bpw    rfn_error: 0.09272
 -- 0.6:4b/0.4:3b 64g s4           3.66 bpw    rfn_error: 0.08835
 -- 1.0:4b 128g s4                 4.03 bpw    rfn_error: 0.05756
 -- 1.0:4b 32g s4                  4.13 bpw    rfn_error: 0.05007
 -- 0.1:5b/0.9:4b 32g s4           4.23 bpw    rfn_error: 0.04889
 -- 0.1:6b/0.9:4b 32g s4           4.33 bpw    rfn_error: 0.04861
 -- 1.0:5b 128g s4                 5.03 bpw    rfn_error: 0.02879
 -- 0.1:6b/0.9:5b 32g s4           5.23 bpw    rfn_error: 0.02494
 -- 0.05:8b/0.05:6b/0.9:5b 32g s4  5.33 bpw    rfn_error: 0.02486
 -- 0.4:6b/0.6:5b 32g s4           5.53 bpw    rfn_error: 0.02297
 -- 0.1:8b/0.3:6b/0.6:5b 32g s4    5.73 bpw    rfn_error: 0.02280
 -- 1.0:6b 128g s4                 6.03 bpw    rfn_error: 0.01503
 -- 1.0:6b 32g s4                  6.13 bpw    rfn_error: 0.01471
 -- 0.1:8b/0.9:6b 128g s4          6.23 bpw    rfn_error: 0.01463
 -- 1.0:8b 32g s4                  8.13 bpw    rfn_error: 0.00934
 -- Time: 19.57 seconds

我们可以看到,随着量化精度(bpw,即每重量的位)增加,误差率如预期那样降低。

使用 ExLlamaV2 进行量化就像运行 convert.py 脚本一样简单:

注意:convert.py 在 ExLlamaV2 的根目录下

python convert.py \
    -i ./Llama-2-13b-hf/ \
    -o ./Llama-2-13b-hf/temp/ \
    -c test.parquet \
    -cf ./Llama-2-13b-hf/3.0bpw/ \
    -b 3.0

ExLlamaV2 不支持 Hugging Face 库。它期望模型和校准数据集存储在本地。

脚本的主要参数如下:

  • 输入模型 (-i):一个包含以“safetensors”格式存储模型的本地目录。

  • 用于校准的 dataset (-c):我们需要一个数据集来进行量化校准。它必须以“parquet”格式存储在本地。

  • 输出目录 (-cf):量化模型将被保存的本地目录。

  • 量化的目标精度 (-b):模型将以混合精度进行量化,平均精度将是目标精度。在这里,我选择了 3 位精度。

此量化过程花费了 2 小时 5 分钟。我使用了 Google Colab PRO 的 T4 GPU 和高 CPU RAM。在整个过程中,它的 VRAM 消耗没有超过 5 GB,但 CPU RAM 有高达 20 GB 的峰值消耗。

T4 的速度相当慢。使用 Google Colab V100 或 RTX GPU 可以减少量化时间。注意:我不清楚量化过程中 GPU 的使用情况。可能 CPU 的速度对量化时间的影响大于 GPU。

要量化 Llama 2 70B,你可以做同样的操作。

我们应该针对什么精度,以便量化后的 Llama 2 70B 能适应 24 GB 的 VRAM?

这是你可以应用的方法,以决定根据你的硬件选择模型的精度。

假设我们有 24 GB 的 VRAM。我们还应该总是预期一些推理的内存开销。因此,我们的目标量化模型大小为 22 GB。

首先,我们需要将 22 GB 转换为位:

  • 22 GB = 2.2e+10 bytes = 1.76e+11 bits(因为 1 字节 = 8 位)

我们有 1.76e+11 位(b)可用。Llama 2 70B 有 7e+10 个参数(p)需要量化。我们目标的精度是我称之为 bpw 的精度。

  • bpw = b/p

  • bpw = 176 000 000 000 / 70 000 000 000 = 2.51

所以我们可以承受每个参数 2.51 位的平均精度。

我将其四舍五入到 2.5 位。

要将 Llama 2 70B 量化为平均 2.5 位精度,我们运行:

python convert.py \
    -i ./Llama-2-70b-hf/ \
    -o ./Llama-2-70b-hf/temp/ \
    -c test.parquet \
    -cf ./Llama-2-70b-hf/2.5bpw/ \
    -b 2.5

这种量化在配备 24 GB GPU 的消费级硬件上也是可行的。可能需要长达 15 小时。如果你打算使用 Google Colab 进行此操作,请注意,由于 A100 GPU 的存储空间过小,你必须将原始模型存储在 Google Colab 硬盘之外。

在你的 GPU 上运行 Llama 2 70B,使用 ExLlamaV2

ExLlamaV2 提供了运行混合精度量化模型所需的一切。

有一个 chat.py 脚本,可以将模型作为聊天机器人进行交互使用。你也可以简单地使用 test_inference.py 测试模型。这是我们将用来检查模型速度和内存消耗的方法。

为测试量化为 2.5 bpw 的 Llama 2 70B,我们运行:

python test_inference.py -m ./Llama-2-70b-2.5bpw/ -p "Once upon a time,"

注意:“-p”是测试提示。

这应该需要几分钟(在 A100 GPU 上约 8 分钟)。ExLlamaV2 使用“torch.compile”。根据 PyTorch 文档:

torch.compile 通过将 PyTorch 代码即时编译成优化的内核来加速 PyTorch 代码的运行,同时需要最少的代码更改。

这个编译过程比较耗时,但会被缓存。

如果你运行 test_inference.py,通常应该只需 30 秒。

模型本身的重量正好是 22.15 GB。在我的推理实验中,它正好占用了 24 GB。它几乎适用于我们的消费级 GPU。

为什么它不仅仅消耗 22.15 GB?

内存中的模型实际占用 22.15 GB,但推理本身也会消耗额外的内存。例如,我们必须对提示进行编码并将其存储在内存中。此外,如果你设置了更高的最大序列长度或进行批量解码,推理将消耗更多内存。

我在这个实验中使用了 Google Colab 的 A100。如果你使用 24 GB 的 GPU,你可能会在推理过程中遇到 CUDA 内存不足错误,尤其是当你还使用 GPU 运行操作系统图形用户界面(例如,Ubuntu 桌面大约消耗 1.5 GB 的显存)时。

为了给你一些余地,目标设置较低的 bpw。2.4 甚至 2.3 会留下几 GB 的显存供推理使用。

ExLlamaV2 模型也非常快速。我观察到生成速度在 15 到 30 个 token/秒之间。为了给你一个比较点,当我用 GPTQ 将 Llama 2 7B 量化为 4-bit,一个小 10 倍的模型时,使用 Hugging Face transformers 进行生成时的速度约为 28 tokens/sec。

[## GPTQ 还是 bitsandbytes:LLM 使用哪种量化方法 - 以 Llama 2 为例]

适合在你的计算机上进行经济实惠的微调和推理的大型语言模型量化

newsletter.kaitchup.com](https://newsletter.kaitchup.com/p/gptq-or-bitsandbytes-which-quantization?source=post_page-----588141a88598--------------------------------)

结论

混合精度量化是直观的。我们在模型的影响较小的地方大幅降低精度。

在单个消费级 GPU 上运行大型模型如 Llama 2 70B 是可能的。

一定要评估你在不同目标精度下量化的模型。虽然较大的模型在量化时性能损失较少,但总有一种精度下,量化模型的表现会比未量化但参数较少的模型差,例如,Llama 2 70B 2-bit 可能会显著比 Llama 2 7B 4-bit 表现更差,即使前者更大。

在 Julia 中后台运行任务

原文:towardsdatascience.com/run-things-in-the-background-with-julia-c9e72e59fc48

停止等待,开始多线程

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 本斯·科马尔尼茨基

·发布于 Towards Data Science ·阅读时长 4 分钟·2023 年 5 月 26 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

照片由 Max Wolfs 提供,来自 Unsplash

即使 Julia 是现有的最快语言之一,有时执行任务也需要时间。如果你是一个使用 Julia 的数据科学家或分析师,也许你希望将计算任务发送到服务器,等待完成后再处理结果。

但等待是无聊的。

当你在工作中充满创意和热情,渴望交付有趣的内容时,你希望 不断敲击键盘寻找其他内容

让我向你展示 Julia 中一个简单的技巧,如何 将计算任务分配到另一个线程,然后继续你的工作。

设置工作环境

正如我之前所说,Julia 很快。作为一种现代语言,它也 考虑了多线程处理。所以,如果你知道如何操作,使用计算机上的额外核心非常容易。

首先,我们必须确保以多个线程启动 Julia 实例:

julia -t 4

这将使用 4 个线程启动 Julia。我们可以通过查询线程数来确认这一点:

julia> using Base.Threads

julia> Threads.nthreads()
4

制作一个慢速函数

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

照片由 Frederick Yang 提供,来自 Unsplash

现在我们有了更多的线程,是时候看看这些魔法的实际效果了。但我们需要一些东西运行一段时间才有意义。我假设你在阅读这篇文章时已经有了一些想法,但因为我喜欢在文章中提供完整的示例,我会在这里写一个小函数来娱乐一下自己。

这个“慢”的函数可能是构建 ML 模型的调用,运行一些类似 SQL 的数据库查询,或从云存储中获取一些数据。发挥你的想象力,尽情尝试吧!

julia> function collatz(n, i=0)
           if n == 1
               i
           elseif iseven(n)
               collatz(n / 2, i + 1)
           else
               collatz(3n + 1, i + 1)
           end
       end
collatz (generic function with 2 methods)

julia> collatz(989345275647)
1348

julia> averageSteps(n) = sum(i -> collatz(i) / n, 1:n)
averageSteps (generic function with 1 method

如果你对上述内容感到好奇,以及为什么我选择了 989,345,275,647,那么阅读这个维基页面

获取一些魔法

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由 K. Mitch HodgeUnsplash 上拍摄

由于我们在命名空间中有 Threads,我们可以使用 **@spawn** 宏将计算发送到另一个线程。这意味着我们可以立即返回 REPL,并继续像以前一样工作。

julia> res = @spawn averageSteps(1e7)
Task (runnable) @0x000000015d061f90

julia> 2+ 12
44

julia> fetch(res)
155.2724831

忽略我缺乏想象力,生成后我只是懒得想出更复杂的东西。

基本上,这里发生的事情是 @spawn 返回一个 Task这个任务会自动分派到一个空闲的线程,该线程可以在后台处理它,允许你在此期间编写更多代码和提出更多问题。一旦你需要结果,你可以用 **fetch** 收集任务的结果,它会等待 Task 完成并返回结果。

证明这有效

一种展示这确实有效的方法是展示一些时间记录。

首先,我们将在当前线程上运行我们的函数并测量所需时间。然后我们将生成一个 Task,最后我们将生成并立即等待结果。

julia> @time averageSteps(1e7)
 16.040698 seconds
155.2724831

julia> @time res = @spawn averageSteps(1e7)
  0.009290 seconds (31.72 k allocations: 1.988 MiB)
Task (runnable) @0x000000015d179f90

julia> @time fetch(@spawn averageSteps(1e7))
 16.358641 seconds (24.31 k allocations: 1.553 MiB, 0.06% compilation time)
155.2724831

正如你所见,我们的函数运行大约需要 16 秒。但如果我们调度任务,那么我们 会立即返回一个任务。这带来了一些开销,如你在最后一行所见,因为这比在主线程上运行计算稍微慢了 0.3 秒。

感谢阅读!

希望这个小技巧能让新手对 Julia 的现代多线程语言的强大功能有更多了解。如果你喜欢我关于这个话题的啰嗦,请给我一个 👏 或 👏 👏。

在 GCP 上运行稳定扩散集群并使用 tensorflow-serving(第一部分)

原文:towardsdatascience.com/running-a-stable-diffusion-cluster-on-gcp-with-tensorflow-serving-part-1-4f7a8e2f66df?source=collection_archive---------10-----------------------#2023-03-07

第一部分:使用 Terraform 设置基础设施

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Thushan Ganegedara

·

关注 发布于 Towards Data Science ·11 分钟阅读·2023 年 3 月 7 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由 Kier in Sight 提供,来源于 Unsplash

在这个两部分教程的第一部分中,我们将学习在 GCP 上创建部署稳定扩散模型的 Kubernetes 集群。Stable Diffusion(一种生成式 AI)是新晋中的潮流。稳定扩散允许我们从给定的文本提示生成逼真的图像。由于稳定扩散模型带来的新颖性和计算负载,它提供了解决一些独特挑战的宝贵机会。

注意:即使您是免费用户(只要您还有免费套餐余额),您也可以全程跟随本教程。

Github:github.com/thushv89/tf-serving-gke/tree/master/infrastrcture

但是,要创建完美的风暴(或完美的产品),仅有最新版本的模型权重不足以应对。需要努力构建一个可靠的生产系统,以支持用户请求并以合理的延迟可靠地提供服务。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

我从部署模型中获取的一些图像示例。你能猜到提示是什么吗?(作者提供的图像)

为此,我们将学习如何在 GKE 集群上运行一个稳定扩散模型。这个 2 部分教程将包括 4 部分:

  • 设置帐户和角色(第一部分)

  • 设置集群(第一部分)

  • 在配置的集群中部署预测服务(第二部分)

  • 部署终端节点生成新图片(第二部分)

开始之前,请确保您已创建了一个 GCP 项目并通过gcloud auth login登录到您的用户帐户。您可以使用gcloud config set project <project_id>gcloud config set region <region>来确保您在正确的项目和区域中。

注意:我在这里讨论的大多数 IAM(身份和访问管理)都基于我(有限)个人的经验。如果您发现任何可以改进的地方,请告诉我!

terraform:时尚地管理基础架构

如果您已经熟悉terraform,请直接跳到“定义帐户和角色(IAM)”部分。

概述

对于 GCP 上所有基础架构的设置,我们将使用terraform;一种 IaaS(基础设施即服务)工具,允许我们将所有基础架构需求编码化。为什么要通过代码管理云资源,而不是容易出错和痛苦的手动操作呢?还有许多其他原因:

  • 人类可读的代码使架构更容易理解,提高了可重用性等。

  • terraform自动管理依赖关系并按正确顺序执行操作

  • 版本控制代码使您能够在某个特定时间点获取系统状态的快照(用于故障排除)

terraform 提供了一个全面的开箱即用 API,可以快速构建所有常见提供者的基础设施,如 GCP、AWS、Azure 等。

terraform 概念

terraform 术语将代码组织成配置。terraform 配置在一个工作目录中操作,该目录下的配置文件以 .tf.tf.json 扩展名结尾;

  • variables.tf — 包含配置使用的所有变量定义

  • outputs.tf — 任何需要写出的输出

  • 除此之外,你可以包含任意数量的 .tf 文件,包含资源定义、提供者等。在我们简单的场景中,我们只需要一个文件,称之为 main.tf

接下来,让我们看看 terraform 如何实现代码的模块化。

terraform 是一种声明性语言,这意味着你告诉 terraform 要做什么(像 SQL),而不是怎么做(像 Python)。由 terraform 来构建计划(例如图形形式)并执行。

然后我们可以使用 模块 来组成我们的 terraform 配置。模块化是可选的,但它将复杂的基础设施拆分为逻辑组件/子系统,并大大增强了可重用性。在我们的案例中,我们将定义三个模块;

  1. 管理账户和角色 (modules/iam)

  2. 管理 GKE 集群 (modules/gke_cluster)

  3. 管理存储 — 设置 GCS 存储桶 (modules/storage)

当你进入代码中的这些模块时,你会看到以下基本构建块和谐地用于达到所需的基础设施状态(有关具体示例,请参见附录)。

  • 资源块 — 描述基础设施对象(例如 VM、集群、VPC)

  • 数据源 / 数据块 — 代表数据源(例如文件)及其相关数据

  • 提供者插件 — 提供对某个提供者相关的资源类型和数据源的访问。

  • 模块的输入和输出变量

一旦定义了配置,你可以运行 terraform plan 来查看 terraform 将执行什么。接下来可以使用 terraform apply 来应用这些更改。应用后,terraform 会在 terraform.tfstate 文件中记录所做的更改。因此,如果你想进行更改(或销毁),terraform 会了解基础设施的当前状态,从而为所需的更改创建计划。

如果你需要进一步巩固 terraform 概念,你可以阅读文档 这里 或查看这个 GCP 教程。现在我们了解了 terraform 的基础知识,接下来让我们理解逻辑。

定义账户和角色(IAM)

对于我们设置 GKE 集群的操作,我们将创建一个服务账号。顾名思义,服务账号通常由应用程序和工作负载使用,而不是实际的人。例如,GKE 节点可以使用服务账号来执行应用程序。服务账号可以被分配权限和角色(即以有意义的方式汇总的权限集合),就像用户账号一样。服务账号的几个优势包括,

  • 我们可以快速绑定/移除用户与服务账号的绑定,允许我们为用户提供必要的权限,而无需重复分配角色/权限给各个用户。

  • 服务账号可以通过 设置短期凭据 来提高安全性。

我们将设置两个具有以下 ID 的服务账号:

  • gke-admin — 具有创建 GKE 集群和配置节点所需的权限

  • gke-node — 具有成功执行工作负载所需的权限(例如,从 GCS 存储桶中读取)

虽然服务账号不直接由人使用或附加,但可以 模拟服务账号,允许用户像服务账号一样执行命令。这是我们将用于设置集群的方法。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

身份和资源的高层次视图(图像由作者提供)

这是我们将在 terraform 代码中概述的过程,

  • 创建服务账号:gke-admingke-node

  • 为创建的账号分配所需的角色

    gke-admincontainer.admin(例如创建集群)、compute.viewer(例如创建节点池)、iam.serviceAccountUser

    gke-nodecontainer.nodeServiceAccount(适用于典型的 Kubernetes 工作负载的权限)

    — 你可以在 GCP 控制台 → IAM → 角色 中查看每个角色提供了哪些权限。

  • 分配所需的角色给用户账号以创建短期访问令牌(iam.serviceAccountTokenCreator

  • 从用户账号创建一个绑定到服务账号,以便用户可以模拟服务账号

最后,我们将在 outputs.tf 中声明我们创建的两个服务账号的名称作为输出,以便配置和其他子模块可以引用。

为了提供基础设施,我们将使用两种形式的身份验证,

  • 通过运行 gcloud auth login 获取的典型身份验证将用于创建服务账号和绑定。

  • 之后,我们将使用模拟服务账号来设置集群

注意 1:我在用户账号上附加了 owner 角色(即项目所有者),如果你没有,你需要获得创建服务账号等所需的权限。

注意 2:即使你拥有所有者角色,进行所有这些服务账户创建和绑定可能看起来有些冗余,但是在一个项目中,与团队协作(或在组织中)时,你需要以最小权限用户的心态来思考和设置权限,以避免安全漏洞。

我们暂时不会运行terraform apply,因为我们将一次性创建服务账户和 GKE 集群。

定义 GKE 集群

我们将创建一个即使是免费用户也可以设置的 GKE 集群。一个集群由一个控制平面和一个或多个工作节点组成。控制平面提供对集群的访问,使你能够检查节点、Pods、服务等。每个节点可以运行一个或多个 Pods(具有特定资源要求——例如 CPU/内存)。一个 Pod(可能运行一个或多个容器)将运行指定的工作负载(例如tensorflow-serving镜像以提供模型)。你可以参考这里了解 GKE 架构。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

GKE 集群的高级架构(作者提供的图片)

我们将以标准模式创建一个集群,配置如下:

  • machine_type:n2-standard-4

  • max_node_count(最大节点数):2

  • preemptible:true(你也可以使用比preemptible实例更便宜的spot实例。了解这些差异请点击这里

注意,我们没有使用 GPU,因为作为免费用户,你可能没有 GPU 配额的资格。但是如果有,请随意按照这里的过程设置带有 GPU 的节点池。

注意 1:如果你是免费用户,你将受到两个重要配额的限制:

all_regions_cpus:默认为12

all_n2_cpus:默认为8

all_regions_gpus:默认为0

由于我们使用的是 N2 类型实例,每个实例具有 4 个 vCPUs,因此在配额范围内我们只能启动 2 个这样的实例。如果你想要在集群中拥有更多节点,可以尝试其他实例类型,如n2-standard-2n1实例。

注意 2:这些是全球配额,这意味着,例如,如果你有一个启动了其他n2类型实例的 Vertex AI 笔记本,它也会计入该配额。

如果不遵守这些,你在应用这些基础设施时可能会遇到Quota exceeded类型的错误。

你可以在这里查看完整的配置。我在这里不会详细说明,因为它很简单。然而,我想提出一个警告,即区域集群和区域集群的概念。忽略这种区别可能会导致一些神秘的错误,比如这个 Stackoverflow 问题

在 GCP 上创建基础设施和资源

在应用讨论过的terraform更改之前,我们需要进行一些整理。首先,运行

./setup.sh -u <user name> -p <project id> -r <region>

这将创建一个配置文件,其中包含定义的参数,以便将它们导入到terraform代码中。接下来,运行,

terraform init 

这将安装提供的插件以及我们定义的本地模块。接下来,我们可以运行以下命令来了解terraform将要执行的操作。

terraform plan [-var="include_module_storage=<true or false>"]

计划将是这样的。

Terraform used the selected providers to generate the following execution plan. Resource actions are indicated with the
following symbols:
  + create
 <= read (data resources)

Terraform will perform the following actions:

  # data.google_service_account_access_token.default will be read during apply
  # (config refers to values not yet known)
 <= data "google_service_account_access_token" "default" {
      + access_token           = (sensitive value)
      + id                     = (known after apply)
      + scopes                 = [
          + "cloud-platform",
          + "userinfo-email",
        ]
      + target_service_account = (known after apply)
    }

  ...

  # module.iam.google_service_account_iam_binding.admin_account_iam will be created
  + resource "google_service_account_iam_binding" "admin_account_iam" {
      + etag               = (known after apply)
      + id                 = (known after apply)
      + members            = [
          + "user:thushv@gmail.com",
        ]
      + role               = "roles/iam.serviceAccountTokenCreator"
      + service_account_id = (known after apply)
    }

Plan: 9 to add, 0 to change, 0 to destroy.

如果我们对计划满意,我们可以运行以下命令来应用更改。

terraform apply [-var="include_module_storage=<true or false>"]

如果一切成功,你应该会在工作目录中看到一个terraform.tfstate文件,列出所有应用的更改。访问README以获取详细说明。你可以前往GCP 控制台 → IAM → 服务帐户,确保服务帐户已正确创建。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

应用 terraform 转换后创建的服务帐户(图像由作者提供)

你还会在GCP 控制台 → Kubernetes 引擎 → 集群中看到一个名为sd-cluster的集群。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

集群已经用单个节点初始化(图像由作者提供)

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

一旦进入集群,你可以看到有关节点池和节点的更多信息(图像由作者提供)

很好,现在我们拥有了部署 ML 模型作为服务所需的一切。我们将在教程的下一部分中查看如何完成这项工作。

到目前为止,你,

  • 了解了什么是terraform以及它如何简化基础设施管理

  • 创建了身份(服务帐户)并将其设置为正确的角色

  • 理解了什么是 GKE 集群并通过模拟所需的服务帐户创建了一个

故障排除与注意事项

  • 错误无法连接到服务器:x509: 证书已过期或尚未生效

  • 解决方案 1:这可能是由于gcloud会话过期。只需运行gcloud auth login并完成登录过程。

  • 解决方案 2:WSL 中存在一个 bug,其中 WSL 内的时钟与 Windows 时钟不同步。你可以运行sudo hwclock -s来触发同步。

警告:如果你在 Powershell 中使用 bash(由WSL支持),可能无法导出环境变量(供terraform使用)。因此,如果你依赖环境变量,建议不要使用它们。

附录

资源块

描述一个或多个基础设施对象(例如,虚拟机、集群、VPC)。每个资源通过资源类型唯一名称来标识。

resource "google_service_account" "sa_gke_admin" {
  account_id   = "gke-admin"
  display_name = "GKE Service Account (Admin)"
}

数据源 / 数据块

表示数据源及其关联的数据

data "google_service_account_access_token" "default" {
 provider                = google.impersonation_helper
 target_service_account  = module.iam.service_account_gke_admin
 scopes                  = ["userinfo-email", "cloud-platform"]
 depends_on = [module.iam]
}

提供者插件

提供对某个提供者关联的资源类型和数据源的访问。

terraform {
  required_providers {
    google = {
      source  = "hashicorp/google"
      version = "3.5.0"
    }
  }
}

输入和输出变量

作为模块的参数和返回类型。

variable "gcp_user" {
  type = string
  description = "Your username for GCP"
}
output "service_account_gke_node" {
  description = "GKE node service account"
  value       = google_service_account.sa_gke_node.email
}
variable "gcp_user" {
  type = string
  description = "Your username for GCP"
}
output "service_account_gke_node" {
  description = "GKE node service account"
  value       = google_service_account.sa_gke_node.email
}

致谢

我想感谢ML 开发者计划以及团队提供的 GCP 积分,使这个教程取得了成功。

在 GCP 上使用 tensorflow-serving 运行稳定扩散集群(第二部分)

原文:towardsdatascience.com/running-a-stable-diffusion-cluster-on-gcp-with-tensorflow-serving-part-2-c421ecb7472a?source=collection_archive---------9-----------------------#2023-03-14

创建工件并在集群上部署模型

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Thushan Ganegedara

·

关注 发表于 Towards Data Science ·14 分钟阅读·2023 年 3 月 14 日

在 第一部分 中,我们学习了如何使用terraform方便地设置和管理基础设施。在这一部分中,我们将继续我们的旅程,将运行中的稳定扩散模型部署到提供的集群上。

注意:即使你是免费用户,也可以完整地跟随本教程(只要你还有一些免费层积分)。

除非另有说明,所有图片均由作者提供

Github: github.com/thushv89/tf-serving-gke

让我们看看最终结果会是什么。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

部署的稳定扩散模型生成的一些图像。

准备模型工件

稳定扩散到底是什么?

构建稳定扩散模型有五个主要组件:

  • 分词器——将给定字符串分词为令牌列表(数值 ID)。

  • 文本编码器——接受分词后的文本并生成文本嵌入。

  • 扩散模型——接受文本嵌入和潜在图像(最初是噪声)作为输入,并逐步优化潜在图像以编码越来越多有用的信息(视觉上令人愉悦)。

  • 解码器——接受最终的潜在图像并生成实际图像。

  • 图像编码器(用于修复功能——在本练习中我们将忽略这一点)。

稳定扩散(扩散模型)的核心突破性理念是,

如果你在多次步骤中逐渐向图像添加一点噪声,最后你会得到一个包含噪声的图像。通过反转这个过程,你可以得到一个输入(噪声)和一个目标(原始图像)。然后训练一个模型从噪声中预测原始图像。

上述所有组件协同工作以实现这一理念。

存储稳定扩散模型

代码:github.com/thushv89/tf-serving-gke/blob/master/notebooks/savedmodel_stable_diffusion.ipynb

为了构建稳定扩散模型,我们将使用keras_cv库,该库包括用于图像分类、分割、生成 AI 等的流行深度学习视觉模型集合。你可以在这里找到一个教程,讲解如何在keras_cv中使用StableDiffusion。你可以打开一个笔记本并与模型一起玩以熟悉它。

我们的目标是将StableDiffusion模型保存为SavedModel格式;这是序列化 TensorFlow 模型的标准方法。做到这一点的一个关键要求是确保所有使用的操作都是 TensorFlow 图兼容的。不幸的是,情况并非如此。

  • 当前版本的模型使用与 TensorFlow 图不兼容的分词器,因此需要将其从打包模型中提取出来,并在单独的步骤中使用。

  • 当前版本使用predict_on_batch来生成图像,但 TensorFlow 图构建不支持此功能。

修正模型

为了修补急切模式的 StableDiffusion 模型,我们将创建一个名为 StableDiffusionNoTokenizer 的新模型。通过这个新模型,我们将用图形兼容的 __call__() 替换所有 predict_on_batch() 调用。正如名字所示,我们还将把标记化过程与模型解耦。此外,在 generate_image() 函数中,我们将替换,

timesteps = tf.range(1, 1000, 1000 // num_steps)
alphas, alphas_prev = self._get_initial_alphas(timesteps)
progbar = keras.utils.Progbar(len(timesteps))
iteration = 0
for index, timestep in list(enumerate(timesteps))[::-1]:
    latent_prev = latent  # Set aside the previous latent vector
    t_emb = self._get_timestep_embedding(timestep, batch_size)
    unconditional_latent = self.diffusion_model.predict_on_batch(
        [latent, t_emb, unconditional_context]
    )
    latent = self.diffusion_model.predict_on_batch(
        [latent, t_emb, context]
    )
    latent = unconditional_latent + unconditional_guidance_scale * (
        latent - unconditional_latent
    )
    a_t, a_prev = alphas[index], alphas_prev[index]
    pred_x0 = (latent_prev - math.sqrt(1 - a_t) * latent) / math.sqrt(
        a_t
    )
    latent = (
        latent * math.sqrt(1.0 - a_prev) + math.sqrt(a_prev) * pred_x0
    )
    iteration += 1
    progbar.update(iteration)

与,

latent = self.diffusion_reverse_loop(
    latent,
    context=context, 
    unconditional_context=unconditional_context, 
    batch_size=batch_size, 
    unconditional_guidance_scale=unconditional_guidance_scale,
    num_steps=num_steps,
)

其中,

@tf.function
def diffusion_reverse_loop(self, latent, context, unconditional_context, batch_size, unconditional_guidance_scale, num_steps):

  index = num_steps -1
  cond = tf.math.greater(index, -1)
  timesteps = tf.range(1, 1000, 1000 // num_steps)
  alphas, alphas_prev = self._get_initial_alphas(timesteps)
  iter_partial_fn = functools.partial(
      self._diffusion_reverse_iter, 
      timesteps=timesteps, 
      alphas=alphas, 
      alphas_prev=alphas_prev, 
      context=context, 
      unconditional_context=unconditional_context, 
      batch_size=batch_size, 
      unconditional_guidance_scale=unconditional_guidance_scale, 
      num_steps=num_steps
  )

  latent, index = tf.while_loop(cond=lambda _, i: tf.math.greater(i, -1), body=iter_partial_fn, loop_vars=[latent, index])

  return latent 

@tf.function
def _diffusion_reverse_iter(self, latent_prev, index, timesteps,  alphas, alphas_prev, context, unconditional_context, batch_size, unconditional_guidance_scale, num_steps):

  t_emb = self._get_timestep_embedding(timesteps[index], batch_size)

  combined_latent = self.diffusion_model(
            [
                tf.concat([latent_prev, latent_prev],axis=0), 
                tf.concat([t_emb, t_emb], axis=0), 
                tf.concat([context, unconditional_context], axis=0)
            ], training=False
        )
  latent, unconditional_latent = tf.split(combined_latent, 2, axis=0)
  latent = unconditional_latent + unconditional_guidance_scale * (
        latent - unconditional_latent
  )
  a_t, a_prev = alphas[index], alphas_prev[index]
  pred_x0 = (latent_prev - tf.math.sqrt(1 - a_t) * latent) / tf.math.sqrt(a_t)
  latent = latent * tf.math.sqrt(1.0 - a_prev) + tf.math.sqrt(a_prev) * pred_x0
  index -= 1

  return latent, index

我做的两个主要更改是:

  • 我使用了 tf.while_loop 代替 Python for 循环,因为在 TensorFlow 中它的性能更优。

  • 将两个独立的 diffusion_model 调用合并为一个调用,然后再拆分输出。

还有其他更改,如用 TensorFlow 等效函数替换各种操作(例如 np.clip() -> tf.clip_by_value()),你可以对比 原始模型此版本 来进行比较。

在 TensorFlow 的图执行模式下,你可以使用 tf.print() 语句以确保代码在执行过程中的有效性。有关 tf.print() 的更多信息,请参考附录。

一旦底层模型修复完成,我们可以创建以下模型,该模型可以在图模式下无缝执行。

class StableDiffusionTFModel(tf.keras.models.Model):

  def __init__(self):
    super().__init__()
    self.image_width = self.image_height = 384
    self.model = StableDiffusionNoTokenizer(img_width=self.image_width, img_height=self.image_height, encoded_text_length=None, jit_compile=True)
    # This forces the model download its components
    # self.image_encoder is only required for in-painting - we will ignore this functionality in this excercise
    self.text_encoder = self.model.text_encoder
    self.diffusion_model = self.model.diffusion_model
    self.decoder = self.model.decoder

    self.default_num_steps = tf.constant(40) 
    self.default_batch_size = tf.constant(2)

    # These negative prompt tokens are borrowed from the original stable diffusion model
    self.default_negative_prompt_tokens = tf.constant(
        [
            49406, 8159, 267, 83, 3299, 267, 21101, 8893, 3500, 267, 21101, 
            8893, 4804, 267, 21101, 8893, 1710, 267, 620, 539, 6481, 267, 
            38626, 267, 12598, 943, 267, 4231, 34886, 267, 4231, 7072, 267, 
            4231, 5706, 267, 1518, 15630, 267, 561, 6528, 267, 3417, 268, 
            3272, 267, 1774, 620, 539, 6481, 267, 21977, 267, 2103, 794, 
            267, 2103, 15376, 267, 38013, 267, 4160, 267, 2505, 2110, 267, 
            782, 23257, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407
        ], dtype=tf.int32
    )

  def call(self, inputs):

    encoded_text = self.text_encoder([inputs["tokens"], self.model._get_pos_ids()], training=False)

    images = self.model.generate_image(
        encoded_text, 
        negative_prompt_tokens=inputs.get("negative_prompt_tokens", self.default_negative_prompt_tokens),
        num_steps=inputs.get("num_steps", self.default_num_steps), 
        batch_size=inputs.get("batch_size", self.default_batch_size)
    )
    return images

model = StableDiffusionTFModel()

这个模型接受以下输入:

  • input_tokens: 输入字符串的标记化表示

  • negative_prompt_tokens: 负面提示的标记化表示(有关负面提示的更多信息:这里)

  • num_steps: 执行扩散过程的步骤数

  • batch_size: 每张图片生成的图片数量

这是这个模型的一个使用示例:

# Tokenizing the prompts
tokenizer = SimpleTokenizer()

def generate_tokens(tokenizer, prompt, MAX_PROMPT_LENGTH):

  inputs = tokenizer.encode(prompt)
  if len(inputs) > MAX_PROMPT_LENGTH:
      raise ValueError(
          f"Prompt is too long (should be <= {MAX_PROMPT_LENGTH} tokens)"
      )
  phrase = tf.concat([inputs, ([49407] * (MAX_PROMPT_LENGTH - len(inputs)))], axis=0)
  return phrase

tokens = generate_tokens(tokenizer, "a ferrari car with wings", MAX_PROMPT_LENGTH)

# Invoking the model
all_images = []
num_steps = 30
tokens = generate_tokens(tokenizer, "a castle in Norway overlooking a glacier, landscape, surrounded by fairies fighting trolls, sunset, high quality", MAX_PROMPT_LENGTH)
neg_tokens = generate_tokens(tokenizer, "ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, mutation, mutated, extra limbs, extra legs, extra arms, disfigured, deformed, cross-eye, body out of frame, blurry, bad art, bad anatomy, blurred, text, watermark, grainy", MAX_PROMPT_LENGTH)
images = model({
    "tokens": tokens, 
    "negative_prompt_tokens": neg_tokens,
    "num_steps": tf.constant(num_steps), 
    "batch_size": tf.constant(1)
})

记住,我(即免费用户等级)在这个项目中受到配额的严重限制。

  • 完全没有 GPU 配额

  • 最多 8 个 N2 CPU(如果选择 N1 CPU,可以达到 12 个)

因此,我不能使用任何 GPU 实例或超过 2 个 n2-standard-4 实例。由于稳定扩散模型较慢,因此使用 CPU 实例时我们将面临延迟问题。

下面是不同参数下所需时间的详细信息。测试是在 n2-standard-8 机器上,在 Vertex AI workbench 上进行的。

  • 图像大小(num_steps = 40

    — 512x512 图像:474s

    — 384x384 图像:233s

  • batch_sizenum_steps

    batch size = 1:21.6s(num_steps=5),67.7s(num_steps=20)和 99.5s(num_steps=30

    batch size = 2,55.6s(num_steps=5),121.1s(num_steps=20)和 180.2s(num_steps=30

    batch size=4,21.6s(num_steps=5),67.7s(num_steps=20)和 99.5s(num_steps=30

如你所见,增加 image_sizebatch_sizenum_steps 会导致时间消耗增加。因此,在平衡计算成本和图像质量后,我们为部署的模型选择了以下参数。

  • image_size: 384x384

  • num_steps: 30

  • batch_size: 1

模型创建后,将模型上传到创建的 GCS 存储桶中。

!gsutil -m cp -r  ./stable_diffusion_model gs://<project>-bucket/

这将是我们用来将模型部署为预测服务的数据源。

让我们再次欣赏一些模型生成的图像,然后继续下一个部分。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

部署模型生成的图像

部署和提供模型

代码: github.com/thushv89/tf-serving-gke/tree/master/infrastrcture

要部署我们的模型并设置预测服务,我们需要 3 个配置:

  • configmap.yaml — 这定义了部署过程中所需的各种变量。例如,这将包括在 GCS 上保存模型的位置(即通过环境变量 MODEL_PATH 访问)。

  • deployment.yamlDeployment 定义了 Pod 的规格(例如 CPU)和应该运行的容器。在这种情况下,我们将运行一个单独的容器,运行 tensorflow-serving 来提供位于 MODEL_PATH 的模型。

  • service.yamlService 是我们暴露在 Pod 中运行的 tensorflow-serving 应用的机制。例如,我们可以让它通过负载均衡器暴露我们的 Pod。

部署

我们首先查看 deployment 的规格:

spec:
  replicas: 1
  selector:
    matchLabels:
      app: stable-diffusion
  template:
    metadata:
      labels:
        app: stable-diffusion
    spec:
      containers:
      - name: tf-serving
        image: "tensorflow/serving:2.11.0"
        args:
        - "--model_name=$(MODEL_NAME)"
        - "--model_base_path=$(MODEL_PATH)"
        - "--rest_api_timeout_in_ms=720000"
        envFrom:
        - configMapRef:
            name: tfserving-configs
        imagePullPolicy: IfNotPresent
        readinessProbe:
          httpGet:
            path: "/v1/models/stable-diffusion"
            port: 8501
            scheme: HTTP
          initialDelaySeconds: 30
          periodSeconds: 15
          failureThreshold: 10
        ports:
        - name: http
          containerPort: 8501
          protocol: TCP
        - name: grpc
          containerPort: 8500
          protocol: TCP
        resources:
          requests:
            cpu: "3"
            memory: "12Gi"

我们可以做一些有趣的观察:

  • 我们在脚本中只声明了一个副本,缩放将在其他地方设置,并通过自动缩放策略进行控制。

  • 我们提供一个 selector,服务将会在部署中查找它,以确保它在正确的部署上提供服务。

  • 我们暴露了两个端口:8501(HTTP 流量)和 8500(GRPC 流量)。

  • 我们将为每个容器请求 3 个“CPU 时间”和 12Gi 的内存。

注意 1: 节点通常会运行 Kubernetes 需要的其他 Pods(例如 DNS、监控等)。因此,在规定 Pod 的计算资源时需要考虑这些因素。您可以看到,尽管节点上有 4 个 CPU,我们只请求了 3 个(您也可以请求分数的 CPU 资源 — 例如 3.5)。您可以在 GCP 上查看每个节点的可分配 CPU/内存(GCP 控制台 → 集群 → 节点 → 单击节点)或使用 kubectl describe nodes

如果您的节点无法满足您指定的计算资源,Kubernetes 将无法运行 Pods 并抛出错误(例如 PodUnschedulable)。

注意 2:你需要特别注意的一个关键参数是--rest_api_timeout_in_ms=720000。处理一个请求大约需要 250 秒,所以我们这里将超时时间设置为大约三倍的时间,以应对并行请求时任何排队的请求。如果你将其设置为过小的值,你的请求将在完成之前超时。

定义服务

在这里,我们定义了一个LoadBalancer类型的服务,我们将通过 GCP 负载均衡器暴露stable-diffusion应用。在这种方法中,你将获得负载均衡器的 IP 地址,负载均衡器将把流量路由到到达它的副本。用户将向负载均衡器的 IP 地址发起请求。

metadata:
  name: stable-diffusion
  namespace: default
  labels:
    app: stable-diffusion
spec:
  type: LoadBalancer
  ports:
  - port: 8500
    protocol: TCP
    name: tf-serving-grpc
  - port: 8501
    protocol: TCP
    name: tf-serving-http
  selector:
    app: stable-diffusion

自动扩展

我们一直拖延的一个重要话题是:扩展我们的服务。在现实世界中,你可能需要服务数千、数百万甚至数十亿的客户。为了做到这一点,你的服务需要能够根据需求上下扩展集群中的节点/副本数量。幸运的是,GCP 提供了多种选项,从完全托管的自动扩展到半托管/完全用户管理的自动扩展。你可以通过这个视频了解更多信息。

在这里,我们将使用水平副本自动扩展器(HPA)。水平副本自动扩展器将根据你提供的一些阈值(例如 CPU 或内存使用情况)扩展副本的数量。这是一个示例。

kubectl autoscale deployment stable-diffusion --cpu-percent=60 --min=1 --max=2

在这里,我们将 HPA 的最小副本数设置为 1,最大副本数设置为 2,并要求它在当前副本集的平均 CPU 超过 60%时添加更多副本。

应用更改

我们现在已经准备好所有的构建块来启动我们的服务。只需运行以下命令。

 gcloud container clusters get-credentials sd-cluster --zone us-central1-c && \
kubectl apply -f tf-serving/configmap.yaml && \
kubectl apply -f tf-serving/deployment.yaml && \
kubectl autoscale deployment stable-diffusion --cpu-percent=60 --min=1 --max=2 && \
kubectl apply -f tf-serving/service.yaml

从服务模型中预测

为了进行预测,你只需向正确的 URL 发起一个 POST 请求,负载中包含模型的输入。

顺序预测

作为第一个示例,我们展示了如何一个接一个地发起一系列请求。

def predict_rest(json_data, url):
    json_response = requests.post(url, data=json_data)
    response = json.loads(json_response.text)
    if "predictions" not in response:
      print(response)
    rest_outputs = np.array(response["predictions"])
    return rest_outputs

url = f"http://{stable_diffusion_service_ip}:8501/v1/models/stable-diffusion:predict"

tokens_list = [
    generate_tokens(tokenizer, "A wine glass made from lego bricks, rainbow colored liquid being poured into it, hyper realistic, high detail", MAX_PROMPT_LENGTH).numpy().tolist(),
    generate_tokens(tokenizer, "A staircase made from color pencils, hyper realistic, high detail", MAX_PROMPT_LENGTH).numpy().tolist(),
    generate_tokens(tokenizer, "A ferrari car in the space astronaut driving it, futuristic, hyper realistic, high detail", MAX_PROMPT_LENGTH).numpy().tolist(),
    generate_tokens(tokenizer, "a dragon covered with weapons fighting an army, fire, explosions, hyper realistic, high detail", MAX_PROMPT_LENGTH).numpy().tolist(),
    generate_tokens(tokenizer, "A sawing girl in a boat, hyper realistic, high detail", MAX_PROMPT_LENGTH).numpy().tolist(),

]
negative_tokens = generate_tokens(tokenizer, "ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, mutation, mutated, extra limbs, extra legs, extra arms, disfigured, deformed, cross-eye, body out of frame, blurry, bad art, bad anatomy, blurred, text, watermark, grainy", MAX_PROMPT_LENGTH).numpy().tolist()

all_images = []
all_data = []
for tokens, negative_tokens in zip(tokens_list, [negative_tokens for _ in range(5)]):
    all_data.append(generate_json_data(tokens, negative_tokens))

all_images = [predict_rest(data, url) for data in all_data]

当我运行实验时,这花费了超过 1600 秒。正如你想象的,这种设置相当低效,无法利用集群的扩展能力。

并行预测

你可以使用 Python 的多处理库来进行并行请求,这更贴近真实用户请求的情况。

def predict_rest(input_data, url):
    json_data, sleep_time = input_data["data"], input_data["sleep_time"]

    # We add a delay to simulate real world user requests
    time.sleep(sleep_time)
    print("Making a request")
    t1 = time.perf_counter()
    json_response = requests.post(url, data=json_data)
    response = json.loads(json_response.text)
    result = np.array([])
    try: 
        result = np.array(response["predictions"])
    except KeyError:
        print(f"Couldn't complete the request {response}")
    finally:
        t2 = time.perf_counter() 
        print(f"It took {t2-t1}s to complete a single request")
        return result

t1 = time.perf_counter()

with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor:
    all_images_gen = executor.map(
        functools.partial(predict_rest, url=url), 
        [{"data": data, "sleep_time": min(i*20, 60)} for i, data in enumerate(all_data)]
    )
    all_images = [img for img in all_images_gen]

t2 = time.perf_counter()    
print(f"It took {t2-t1}s to complete {n_requests} requests")

这运行了 900 秒。因此,通过将集群扩展到最多 2 个副本,我们实现了约 180%的加速。

关于设置超时的说明

在设置并行请求时要小心。如果你一次性发送所有并行请求(因为这只有 6 个请求),它们可能会超时。这是因为创建新节点和初始化新副本需要时间。所以如果所有请求瞬间发出,负载均衡器可能甚至没有时间看到第二个节点,最终会尝试将所有请求服务于单个节点。

上述定义的超时时间是从请求接收的时间(即进入*tensorflow-serving*队列)开始计算的,而不是从开始处理请求的时间开始计算。因此,如果请求在队列中等待时间过长,也会计入超时。

你可以在 GCP 上监控计算指标,如 CPU 使用率和内存消耗(GCP → Kubernetes Engine → Services & Ingress → 选择你的服务)。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

顺序请求的使用图(上)并行请求的使用图(下)

结论

在这个两部分的教程中,我们,

  • 使用 terraform(一个 IaaS 工具)设置基础设施,主要包括一个集群和一个节点池(第一部分)

  • 部署了一个模型,并创建了一个预测服务来处理用户请求,使用了一个 Stable Diffusion 模型(第二部分)

我们设置了这个教程,使得即使是免费用户也能运行。我们设置了一个包含 2 个节点的集群,并为每个节点创建了 1 个 pod。然后我们进行了顺序和并行预测,发现并行预测带来了约 180%的吞吐量提升。

下一步:

  • 模型预热tensorflow-serving 提供了一种简单的方法来预热模型。你可以解析示例请求,它们将被加载并发送到模型中,实际处理用户请求之前。这将减少初始用户请求的延迟。

  • 动态批处理 请求 — 你可以选择动态批处理传入的请求。这将允许模型对一批输入进行预测,而不是对每个输入进行预测。只要有足够的内存,这可能会提供吞吐量的提升,让你在合理的时间范围内处理大量请求。

附录

在 pods 中进行调试

当我尝试启动它时,遇到的一个痛苦问题是遇到了以下砖墙。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

在 Workloads → Deployment 部分显示的错误

当我进入部署中的一个 pod 时,我得到了一个更合理的(仍然不显眼的)错误。但仍然不足以明确指出到底哪里出了问题。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

单个 pod 产生的事件

所以我必须找到一种方法来微观地调查根本原因。为此,我首先登录到相关 pod 的容器中,

kubectl exec --stdin --tty <container name> -- /bin/bash

一旦我进入,就可以利用 Linux 所赖以生存的“一切皆文件”这一范式。换句话说,你可以简单地访问一个文件以查看进程的输出/错误流。例如,在我的情况下,tensorflow-serving 进程的 PID 是 7,因此,/proc/7/fd/2 给出了该进程的错误流。

tail -n 10  /proc/7/fd/2

在这里,我能够准确地看到为什么这没有启动。这是因为容器没有必要的权限来访问MODEL_PATH中指定的 GCS 桶。

使用tf.print进行调试

正如你所知,TensorFlow 提供了两种执行风格:命令式和声明式。由于我们使用__call__()来调用模型(即self.model(<inputs>)),这些调用作为图操作执行。你可能已经知道,图执行因内部图造成的模糊性而难以调试。TensorFlow 提供的一种解决方案是使用tf.print语句。

你可以在模型调用中放置tf.print语句,这些打印语句会作为操作添加到图中,因此你可以查看执行的张量的值等,这样你可以更好地调试代码,而不是盲目尝试。

确保你的tf.print语句打印的输入出现在你希望它被打印的时间之前。如果你添加了独立/虚拟的tf.print语句,它们不会被正确地嵌入到图中。这可能会给你一种误导性的感觉,认为某些计算正在非常快速地进行,这是由于图的错误放置所导致的。

关于机器类型的说明

对于这个练习,你可以使用两种主要的机器类型n1n2N2 实例使用第三代 Xeon 处理器,这些处理器配备了特殊的指令集(AVX-512),以加速诸如矩阵乘法等操作。因此,CPU 密集型 TensorFlow 代码在 n2 机器上运行得比在 n1 上更快。

致谢

我想感谢ML Developer Programs及其团队提供的 GCP 积分,使这次教程得以成功。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值