TowardsDataScience 2023 博客中文翻译(三百零五)

原文:TowardsDataScience

协议:CC BY-NC-SA 4.0

教人工智能玩棋盘游戏

原文:towardsdatascience.com/teaching-ai-to-play-board-games-77e5d1749dd9

使用从零开始的强化学习教计算机玩井字棋

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Heiko Hotz

·发表于 Towards Data Science ·18 分钟阅读·2023 年 12 月 12 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由作者提供(由 ChatGPT 创建)

这是什么内容?

目前,人工智能领域似乎每个人都在提升他们的强化学习(RL)技能,特别是在 Q-learning 方面,跟随关于 OpenAI 新 AI 模型 Q* 的最新传闻,我也参与其中。然而,我决定用我对棋盘游戏的热情来介绍 Q-learning 🤓,而不是对 Q* 进行猜测或重温 Q-learning 的旧论文和示例。

在这篇博客文章中,我将从头开始创建一个简单的程序,教一个模型如何玩井字棋(TTT)。我将避免使用任何强化学习库,比如 GymStable Baselines;所有内容都是用原生 Python 手动编写的,脚本只有 100 行。如果你对如何指导人工智能玩游戏感到好奇,请继续阅读。

你可以在 GitHub 上找到所有代码,链接为 github.com/marshmellow77/tictactoe-q

为什么这很重要?

教人工智能玩井字棋(TTT)可能看起来并不那么重要。然而,它确实提供了一个(希望)清晰且易于理解的 Q-learning 和 RL 的介绍,这在生成式人工智能(GenAI)领域可能是重要的,因为有人猜测像 GPT-4 这样的独立 GenAI 模型对于显著的进步是不够的。它们的局限性在于只能预测下一个标记,而无法进行任何推理。RL 被认为能够解决这个问题,并可能增强 GenAI 模型的响应能力。

无论你是为了迎接这些进展而提升你的 RL 技能,还是仅仅寻求一个有趣的 Q 学习入门教程,这个教程都适合这两种情况🤗

理解 Q 学习

从本质上讲,Q 学习是一种算法,它学习特定状态下一个动作的价值,然后利用这些信息找到最佳动作。让我们考虑Frozen Lake游戏的例子,这是一款用于演示 Q 学习的流行单人游戏。

在 Frozen Lake 中,玩家(从单元格 0 开始)在冰和水的网格上移动,目标是到达目标(单元格 15)而不掉入水中。每个单元格代表一个状态,玩家可以向四个方向移动:上、下、左或右。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图片(使用 Stable Diffusion 创建)

在游戏开始时,代理(这就是 AI 玩家通常的称呼)没有任何信息,只会随机尝试一些动作。在 Q 学习的背景下,这个探索阶段至关重要。代理通过根据其动作获得奖励或惩罚来学习。在 Frozen Lake 中,达到目标会获得高奖励,而掉入水中则会受到惩罚。这种奖励和惩罚的系统引导代理学习最有效的到达目标的路径。

Q 学习使用一个表格,称为 Q-表,用于记录每个状态下每个动作的价值。随着代理探索环境,这个表格会不断更新。Q-表条目,称为 Q 值,表示在给定状态下采取某个动作的预期效用,它们通过贝尔曼方程进行更新。这个方程考虑了动作的即时奖励和可能的最高未来奖励(稍后会详细讲解)。

基本上,Q-表是代理的备忘单或查找表:根据游戏的状态,代理会查找该状态,确定哪个动作具有最高效用(即哪个是最佳动作),然后执行该动作。以下是 Q-表可能的示例:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图片

在这个例子中,如果玩家处于状态 1(即在单元格 1),他会选择动作,因为这是具有最高价值的动作。

随着时间的推移,代理探索环境并更新 Q-表,它在导航 Frozen Lake 时变得更加熟练,最终学会了一种最佳或接近最佳的策略,以可靠地到达目标。Q 学习在这种情况下的美妙之处在于它的无模型性质,这意味着它不需要环境模型,可以仅通过交互学习,使其广泛适用于各种 RL 问题。

存在许多教程演示了如何利用和实现 Q-learning 来解决 Frozen Lake 游戏,例如 towardsdatascience.com/q-learning-for-beginners-2837b777741。然而,正如前面提到的,作为一个棋盘游戏爱好者,我对将这种方法适用于双人游戏,甚至更多玩家的游戏更感兴趣。

双人游戏中的挑战

将 Q-learning 应用于双人游戏,如井字棋,需要进行一些小的修改。在 Frozen Lake 游戏中,下一状态仅由代理的行动决定。然而,在井字棋中,尽管玩家可能采取一个回合,但随后的状态还依赖于对手的行动。例如,如果我在左上角放置一个‘X’,那么下一状态是不确定的,因为我的对手有几个潜在的移动:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图片

可以采用几种方法来解决这个问题。一种方法是模拟对手所有可能的行动及其相应结果。这需要生成所有潜在后续状态的概率分布,并根据这些状态的预期结果更新 Q 值。然而,这种方法可能计算量较大。在本教程中,我们将采用一种更简单的方法,为对手随机采取一个动作,并根据这个动作的实际结果更新 Q 表。这很好地反映了对手的不可预测性,正如我们后面将看到的那样。通过这种方法,Q-learning 可以有效地适应双人游戏,使 AI 不仅能够学习最佳移动,还能(最终)适应人类对手的策略。

这种方法原则上与训练 AlphaGo Zero 的方法类似。该 AI 程序在快速连续的对弈中自我对弈了 490 万局围棋。在这个过程中,它不断提高自己的技能,自主学习和调整策略。这种自学习方法,绕过了模拟对手每一个可能的移动的需求,为 AI 系统提供了一种高效且有效的学习和适应复杂任务的方法。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

李世石与 AlphaGo 的第 2 局比赛,AlphaGo 的著名第 37 步。图片来源:commons.wikimedia.org/wiki/File:Lee_Sedol_(W)_vs_AlphaGo_(B)_-_Game_2.svg(许可 CC BY-SA 4.0)

在接下来的部分中,我们将深入探讨这些原则如何在井字棋的具体情况下应用,展示在双人环境中 Q-learning 的实现。

井字棋的 Q-Learning

当我们开始将 Q-learning 应用于井字棋时,了解我们程序的设置以及 AI 代理将要操作的环境非常重要。

概述

这段代码旨在训练一个 AI(我们称之为玩家 1智能体),通过 Q 学习(一种强化学习形式)来玩类似井字棋的游戏。它首先设置学习参数并初始化一个 Q 表来存储不同状态下不同动作的值。脚本定义了几个函数来管理游戏机制,如确定可能的移动、检查胜者、更新游戏状态,以及在移动后计算下一个状态和奖励。

在脚本的主要部分中,实现了 Q 学习算法。它运行多个回合,模拟智能体与其对手(我们称之为玩家 2)之间的游戏。在每一回合中,AI 要么探索一个随机动作,要么利用 Q 表中的知识来进行决策,从结果中学习以更新 Q 表的值。这个过程涉及随着时间的推移调整探索率,从随机探索转向更具策略性的动作。

我们设置的一个关键方面是 AI 的对手。与对手可能拥有复杂策略的更复杂场景不同,我们的 AI 将与一个随机移动的对手进行游戏。这一选择简化了学习环境,使我们可以专注于 AI 的学习过程,而不是对手策略的复杂性。

Q 学习设置

我们的 Q 学习设置涉及一些关键参数,这些参数将影响 AI 的学习方式:

learning_rate = 0.2
discount_factor = 0.9
num_episodes = int(1e7)
epsilon = 1.0  # Exploration rate
epsilon_min = 0.01
epsilon_decay = 0.999
  • 学习率 (**learning_rate**): 这决定了新信息对现有知识的影响程度。较高的学习率加速了学习过程,但可能导致不稳定。学习率为 0.2 在学习新策略和保留之前的学习之间取得了平衡。

  • 折扣因子 (**discount_factor**): 这反映了未来奖励的重要性,影响 AI 策略的远见程度。折扣因子为 0.9 时,AI 会特别重视未来奖励,鼓励 AI 前瞻性思考,而不仅仅关注即时收益。

  • 回合数 (**num_episodes**): 这是 AI 学习的游戏数量,为 AI 提供了充足的机会来体验各种游戏场景。将其设置为 1000 万 (1e7) 允许广泛的训练,为 AI 提供了从各种游戏场景中学习的充足机会。

  • 探索率 (**epsilon**): 探索率(epsilon)最初设置较高,以允许 AI 探索各种动作,而不是仅仅利用已知策略。最初,AI 会更多地进行探索(由于 epsilon 为 1.0)。随着时间的推移,epsilon 逐渐减小到 epsilon_min,AI 将开始更多地利用其学习到的策略。

关于探索率的附注

在 Q 学习中,探索率通常用符号 ε(epsilon)表示,这是一个关键参数,决定了探索(尝试新动作)和利用(使用已知最佳动作)之间的平衡。最初,智能体对环境了解不多,因此它需要广泛探索,通过尝试不同的动作。探索率通常在开始时设定为较高的值(例如 1 或接近 1),决定了智能体选择随机动作而不是根据 Q 表选择最佳已知动作的概率。

然而,随着智能体对环境的了解越来越多,Q 表变得更加可靠,探索的必要性减少,利用已获得的知识变得更加有益。这时,探索率衰减就发挥作用了。探索率衰减是一个随着时间推移而减少探索率的因素。它确保智能体在学习和收集更多信息的过程中,逐渐从探索环境转向利用 Q 表中学到的值。

这种平衡在 Q 学习中很重要,因为它可以避免两个主要问题:

陷入局部最优: 如果智能体只利用已知信息(低探索),可能会陷入局部最优。这意味着它会根据有限的信息反复选择看似最佳的动作,但可能错过发现能带来更好长期奖励的动作。

低效学习: 另一方面,如果智能体过度探索(高探索)且时间过长,可能导致低效学习。智能体可能会不断尝试次优动作而没有充分利用已经获得的知识,从而导致收敛到最优策略的速度变慢。

通过适当设置探索率及其衰减,Q-learning 算法可以有效地平衡这两个方面,使智能体能够最初探索环境,然后逐渐更多地专注于利用它所学到的最佳策略。这种平衡对于在复杂环境中学习的效率和有效性至关重要。

在接下来的章节中,我们将深入代码,看看 AI 如何使用 Q-learning 来做决策、更新策略,并最终掌握 Tic-Tac-Toe。

代码深度解析

训练脚本

这是 train.py 文件的详细解读。

训练从 for 循环开始(大致在脚本的中间),我们将在其中进行一定数量的回合:

for episode in range(num_episodes):
    state = [0] * 9  # Starting state - empty board

接着,我们随机确定起始玩家。一个更简单的方法是让我们的智能体总是作为起始玩家。然而,实现一个随机起始玩家并不比直接总是让智能体作为起始玩家多花费多少精力,并且这种方法使 Q 表模式更加通用,即我们的智能体将学习如何作为起始玩家以及非起始玩家进行游戏。

如果玩家 2 开始游戏,那么我们将为玩家 2 进行随机移动:

 # If Player 2 starts, make a random move
    if current_player == 2:
        actions = get_possible_actions(state)
        random_action = random.choice(actions)
        state = update_state(state, random_action, 2)
        current_player = 1  # Switch to Player 1

现在我们进入实际的 TTT 游戏训练循环,只有在游戏结束时才会停止。一个关键机制是之前讨论的开发 vs 探索机制。它的实现如下:

if random.uniform(0, 1) < epsilon:
    # Explore: choose a random action
    action = random.choice(actions)
else:
    # Exploit: choose the best action based on Q-table
    action = max(actions, key=lambda x: Q_table[state_str][x])

epsilon 值越低,智能体通过随机移动进行的探索越少,它将更多地利用 Q 表。

一旦选择了智能体的动作,我们将执行它并确定下一状态(以及适用的奖励):

# Take action and observe new state and reward
new_state, reward = get_next_state_and_reward(state, action)

处理所有这些操作的函数值得更仔细地查看:

def get_next_state_and_reward(state, action):
    new_state = update_state(state, action, 1)  # Player 1's move
    if is_winner(new_state, 1):
        return (new_state, 1)  # Reward for winning
    elif 0 not in new_state:
        return (new_state, 0.1)  # Draw
    else:
        # Player 2 (random) makes a move
        actions = get_possible_actions(new_state)
        random_action = random.choice(actions)
        new_state = update_state(new_state, random_action, 2)
        if is_winner(new_state, 2):
            return (new_state, -1)  # Penalty for losing
        else:
            return (new_state, 0)  # No immediate reward or penalty

在这个函数中,我们首先更新棋盘的状态并检查我们的智能体是否赢得了游戏。如果没有,我们为对手进行随机移动,并再次检查对手是否赢得了游戏。根据结果,我们返回 0(游戏仍在进行中)、0.1(平局)、+1(智能体获胜)或 -1(对手获胜)。我们选择 0.1 作为平局的奖励是为了激励智能体尽快结束游戏。

现在我们已经确定了奖励,接下来是整个程序中最关键的部分:通过 Bellman 方程更新 Q 表:

Q_table[state_str][action] += learning_rate * (
            reward + discount_factor * max(Q_table[new_state_str]) - Q_table[state_str][action])

这个 Bellman 方程在其他博客文章中解释得更好(再次参考 towardsdatascience.com/q-learning-for-beginners-2837b777741)。但简要解释如下:

如前所述,Q 表本质上是一个大的备忘单:它跟踪游戏中的所有可能状态以及从该状态开始的每个可能移动的价值。它告诉智能体在给定情况下每个移动的好坏,基于它迄今为止学到的知识。

Bellman 方程更新这个 Q 表。它通过查看智能体收到的即时奖励(赢、输、平局)和它可以移动到的未来状态(即未来奖励)的质量来实现。因此,在每局游戏后,智能体使用 Bellman 方程来修订其 Q 表,学习哪些移动可能导致胜利、失败或平局。

最后,我们调整探索率,以便在未来的游戏中,智能体更多地使用 Q 表而较少进行探索。

epsilon = max(epsilon_min, epsilon_decay * epsilon)

运行训练

一旦训练脚本准备好,我们就可以执行它。幸运的是,这个过程计算需求不高,完成得非常快,不需要特别的计算能力。例如,我在 MacBook M1 Air 上执行了这个过程,它在 1000 万局游戏中不到 5 分钟就完成了。训练完成后,我们将保存 Q 表(它不是特别大),以便我们可以用它来测试智能体,与 AI 对战,并可能在稍后的阶段继续训练,以进一步增强表格。我们来看看吧 🧐

Q 表的人工检查

这个表格相对容易理解:每一行代表了棋盘状态、可采取的行动及其质量。让我们来看看一些有趣的状态。请注意,你的表格可能会有不同(但希望是相似)的值:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由作者提供

棋盘状态显示了每个玩家已经放置的位置(前 3 个数字代表第一行,接下来的 3 个代表第二行,最后 3 个代表最后一行。动作对应棋盘上的位置,每个动作的数字表示该动作的质量。在这个例子中,我们看到一个状态,似乎只有一个动作(动作 7)被认为是好的,其他所有动作都显得较差。

注意:棋盘位置的索引如下:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由作者提供

所以,让我们来可视化 Q 表中的这个特定条目的棋盘状态:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由作者提供

的确,在这个位置,代理(玩家 1)唯一的好选择是选择位置 7。所有其他移动可能会导致输掉比赛(请记住,玩家 2 将在下一轮随机移动,因此输掉比赛并非必然)。

让我们再看一个例子:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由作者提供

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由作者提供

在这个例子中,显然最佳移动是选择位置 8(右下角)并赢得比赛。如果代理选择其他位置,它很可能会输掉比赛。因此,Q 表将指示我们的代理采取动作 8。

测试新代理

现在我们已经训练了模型,我们可以用 GH 仓库中的脚本test.py来测试它。在脚本中,我们将让代理与一个随机移动的对手进行若干局比赛,看看它的表现如何。我们首先初始化我们的代理并加载 Q 表以便在游戏环境中用于决策。play_game函数模拟了一场比赛,使用加载的 Q 表来指导代理的决策。这里的游戏环境是一个简单的 3x3 棋盘,每个状态代表棋盘的不同配置。

代理以玩家 1 的身份,根据 Q 表做出决策——选择当前状态下值最高的行动。如果在 Q 表中找不到状态,代理将做出随机移动。这种学习行为和随机性的结合有助于评估训练的鲁棒性。玩家 2 的移动完全随机,为代理提供了多样化的场景。

这些游戏的结果会被跟踪,量化胜利、失败和平局的数量。这有助于评估训练模型的效果。如果设置了log_lost_games标志,将保存详细的失败游戏日志,这对于进一步分析和改进模型是非常宝贵的。这一测试过程,通过进行大量游戏,提供了对训练后代理能力的全面了解。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图片

与 AI 对战

看起来对随机机器人进行的测试很成功。我们的 AI 赢得了超过 95%的比赛。现在,我们想亲自与 AI 对战。我们可以使用play.py来实现这一点。

在这个程序中,我们通过一个简单的控制台界面与 AI 互动。游戏板表示为一个 3x3 的网格,每个位置从 0 到 8 编号。当轮到我们时,我们会被提示输入一个数字,以选择我们想要移动的位置。

AI 使用从 CSV 文件加载的 Q 表来做出决策。这个 Q 表来源于之前的训练过程,引导 AI 根据当前的游戏板状态选择最佳可能的移动。如果 AI 遇到 Q 表中没有的状态,它将默认进行随机移动。

游戏在我们的回合和 AI 的回合之间交替进行。每次移动后,更新后的棋盘会被显示,程序会检查是否有赢家。如果玩家获胜或游戏结果为平局,游戏结束,结果将被宣布——无论是我们获胜、AI 获胜还是平局。

这个互动游戏提供了一个很好的机会来实时测试 AI 的能力。让我们开始吧:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图片

在这个游戏中,如果我们不选择动作 0(左上角),AI 将有机会赢得比赛。它会意识到这一点吗?

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图片

确实做到了!很好😊

结论

在这篇文章中,我们训练了我们的 AI 代理对抗一个进行随机移动的玩家。这已经足够好,能够在对抗进行随机移动的对手时达到超过 95%的胜率。但是,有方法可以改进训练过程,希望也能提高 AI 的表现。

参数调整的影响

将 Q 学习应用于井字游戏揭示了强化学习的一个关键方面:调整参数的艺术。正确设置这些参数,如开发与探索之间的平衡、学习率和折扣因子,是 RL 代理成功的关键。

  • 探索与利用:epsilon值控制,这一平衡决定了智能体尝试新策略的频率与依赖已知策略的比例。高探索率鼓励智能体尝试新事物,可能导致创新策略,而高利用率使智能体依赖现有知识,虽然可能更高效,但可能会错过更好的策略。

  • 学习率: 高学习率意味着智能体迅速采纳新信息,这在动态环境中可能有利,但如果智能体过快地覆盖有用的学习,可能导致不稳定。相反,低学习率意味着智能体更多依赖过去的知识,导致稳定但可能较慢的学习。

  • 折扣因子: 这个参数影响智能体对未来奖励的重视程度。高折扣因子使智能体更具前瞻性,考虑其行动的长期后果。相反,低折扣因子则使智能体目光短浅,专注于即时奖励。

这些参数的变化可以显著改变 RL 智能体的行为。例如,折扣因子低的智能体可能会以攻击性方式玩井字棋,专注于即时胜利,而不是制定未来的策略。相反,折扣因子高的智能体可能会更具策略性,考虑每一步对游戏未来状态的影响。

同样,高学习率的智能体可能迅速适应新策略,不断发展其游戏玩法,而低学习率的智能体可能坚持经过验证的策略,游戏中的变化较小。

轮到你来实验了

这就是强化学习真正的激动所在。每个参数都可以进行微调,以观察它如何影响 AI 智能体的学习和表现。我邀请你深入这个实验的世界。调整学习率、探索率和折扣因子,观察这些变化如何影响 AI 在井字棋游戏中的策略。

更高级的技术

为了进一步提高模型的表现,实施自我对弈机制,即 AI 与来自不同训练阶段的自身版本对弈(而不是与进行随机移动的对手对弈),可能是一种有效的策略。这种方法在 AlphaGo 等系统中成功应用过,并可能导致更强大和适应性更强的 AI 玩家。

对于更复杂的游戏,如国际象棋和围棋,维持一个 Q 表将不再可行,因为它变得过于庞大。在这些游戏中,采用像深度 Q 学习这样的技术可以显著增强 AI 的学习能力。通过使用神经网络来逼近 Q 表,AI 可以处理超出简单 3x3 井字棋网格的更复杂状态,使其在更复杂的游戏中具备可扩展性。

总之,目前的设置已经展示了有希望的结果。然而,这些建议的改进可能会进一步提升 AI 的表现,将其从一个合格的井字棋玩家转变为一个能够应对更复杂战略游戏的高级 AI。

进一步的相关资料

如果你对学习更多关于强化学习如何应用于棋盘游戏感兴趣,可以查看下面的两个视频。第一个视频非常简短,深入探讨了现代象棋 AI 机器人如何进行游戏

第二个视频是电影AlphaGo(在 YouTube 上免费观看),讲述了 AlphaGo 模型的开发过程以及它如何击败当时的世界冠军:

Heiko Hotz

👋 在MediumLinkedIn关注我,阅读更多关于生成 AI、机器学习和自然语言处理的内容。

👥 如果你在伦敦,可以参加我们的NLP London Meetups

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

教学 CLIP 时尚

原文:towardsdatascience.com/teaching-clip-some-fashion-3005ac3fdcc3?source=collection_archive---------3-----------------------#2023-03-07

训练 FashionCLIP,一个专门用于时尚的 CLIP 模型

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Federico Bianchi

·

关注 发表在 Towards Data Science ·10 分钟阅读·2023 年 3 月 7 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由 Domenico Loia 提供,发布在 Unsplash 上。

这是一篇简短的博客文章,描述了 FashionCLIP。如果你是数据科学家,你可能需要处理图像和文本。然而,你的数据将非常特定于你的领域,标准模型可能效果不佳。本文解释了如何在领域特定的环境中使用领域特定的视觉和语言模型,以及为何使用这些模型可能是创建搜索引擎或(零样本)分类器的一个有前景的方式。

FashionCLIP 是一种用于时尚行业的新视觉和语言模型,支持从业者解决两个任务:

  • 分类:产品图像的零样本分类;

  • 搜索:根据查询高效检索产品。

尽管 FashionCLIP 是许多人努力工作的结果,这篇博客文章主要是我在构建过程中获得的惊人经验的总结和个人观点,并不一定代表所有其他作者及其组织的观点。

模型

我们目前以两种不同的格式发布模型:

我们还有一个 colab 教程 介绍了使用 FashionCLIP 可以做的大部分事情。

介绍

时尚是可以从 AI 产品中受益最多的行业之一。实际上,由于领域的性质、不同的目录和客户特定的数据集,通常很难构建可以无缝应用于不同问题的解决方案。

想象一下在一家大型时尚公司工作的两位数据科学家:Mary 和 Luis。他们必须应对不断变化的系统,其操作需要持续的关注:

  • Mary 正在构建一个 产品分类器 以帮助大规模分类:她的模型接收一个产品并从一系列类别中选择一个(鞋子、连衣裙等);

  • Luis 正在研究 产品匹配 以改善搜索体验:他的模型接受一种支持的语言中的查询(例如,“一件红色连衣裙”),并返回匹配该查询的产品列表。

正如每个从业者所知道的,任何新的生产模型都会带来复杂的生命周期和某种程度的脆弱依赖:

  • 随着库存的增长和类别的变化,Mary 的模型需要不断重新训练;

  • Luis 的模型依赖于产品元数据的质量。

同一公司,不同用例,不同模型。

如果有另一种方法呢?

今天我们尝试向前迈出一步,展示如何构建一个用于时尚数据的通用模型。我们描述了 FashionCLIP,它是著名的 CLIP 模型的微调版本,专门处理时尚数据。我们最近的关于 FashionCLIP 的论文已在《自然科学报告》中发布。

Chia, P.J., Attanasio, G., Bianchi, F. 一般时尚概念的对比语言与视觉学习Sci Rep 12, 18958 (2022)。 doi.org/10.1038/s41598-022-23052-9

FashionCLIP 的诞生源于与Farfetch的合作,这是一家在纽约证券交易所上市的巨大(且真实的)奢侈品电商。FashionCLIP 是与来自业界(Coveo、Farfetch)和学术界(斯坦福、博科尼、比科卡)的人们共同完成的工作。模型权重可以在线获得,格式为HuggingFace。使用示例可以在Patrick 的 repo中找到。

我们将首先介绍用例,并解释一些模型的更深入细节。最后,我们将分享我们用来训练模型的代码以及如何获取权重。

FashionCLIP: 故事

FashionCLIP 是一个通用模型,用于将时尚产品的图像及其描述嵌入到同一个向量空间中:每个图像和每个产品将由一个单独的稠密向量表示。

为什么我们要把它们放在同一个向量空间中? 这样它们才能进行比较。 这个原则是像 CLIP 这样的模型成功的关键。

FashionCLIP 源自原始的 CLIP。这个想法非常简单。如果你:

  • 大量带有标题的图像;

  • 一个图像编码器(这可以是 CNN 或 ViT);

  • 一个文本编码器(这可以是基于 transformers 的语言模型)。

你可以训练一个模型(使用对比损失)来使图像的嵌入接近其标题嵌入,并远离不相关的标题。在 GIF 中,你展示了一个二维的例子。这个概念可以推广到 N 维。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

FashionCLIP 将描述和图像嵌入到同一个向量空间中。这对于零-shot 分类和图像检索非常有用。图片由作者使用 Farfetch 目录提供。

最终结果是一个多模态空间,允许你在视觉和文本交互之间移动,使用新的图像和新的文本描述:如果你有一些文本,你可以检索到对应的图像(如产品搜索);如果你有一些图像,你可以排序标题基于语义相似性(如分类)。

要微调 CLIP,你需要一个好的数据集。我们与 Farfetch 合作,使用高质量的图像和标题来训练 CLIP。这个数据集(即将公开发布)包含了超过 80 万的样本。

我们训练模型几个周期,并检查在多个基准上的表现,包括零-shot 分类、探测和检索。在查看结果之前,让我们深入了解一下现在有了训练好的 FashionCLIP 后我们可以做什么。

我们不会深入探讨 CLIP 本身。如果你想了解更多关于 CLIP 的内容,我这里有一篇专门的博客文章:

[## 如何训练你的 CLIP]

介绍 CLIP 以及我们如何在 HuggingFace 社区周期间为意大利语言微调它。

towardsdatascience.com

FashionCLIP 可以处理的两个关键任务是:

  • 图像检索

  • 零-shot 分类

检索:从文本到图像

我们首先从文本到图像:我们使用 FashionCLIP 文本编码器对搜索查询(“一件红色连衣裙”)进行编码,并通过简单的点积检索最接近的图像向量。点积的值越大,文本和图像之间的相似度越高。在下面的 GIF 中,搜索以 4 个产品图像为例进行。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

对于检索,我们可以在目标目录上预先计算图像嵌入。在运行时,我们编码查询并通过简单的点积对图像进行排名。图片由作者使用 Farfetch 目录提供。

虽然“红色连衣裙”是一个简单的查询,搜索引擎可能不需要额外的输入,但稍微模糊一些的查询,如“浅红色连衣裙”与“深红色连衣裙”则变得有趣,其中“浅”和“深”是同一颜色的修饰词:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

FashionCLIP 有助于消歧义几何特征。图片由作者使用 Farfetch 目录提供。

更有趣的是 FashionCLIP 捕捉到衣物中代表的物品的能力。产品描述通常未能明确提及具象图案,FashionCLIP 能够识别印刷的物品,即使是类似卡通的形状,如下面 T 恤上挂着的

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

FashionCLIP 识别印刷在 T 恤上的具象物品。图片由作者使用 Farfetch 目录提供。

虽然我们尚未详细评估这一能力,但我们相信这可能来自原始 CLIP 所具备的“知识”,在微调过程中部分保留。

当然,信息在描述中(例如,品牌通常在描述中提及)比 FashionCLIP 可能捕获的任何语义细微差别编码得更好。然而,它在增强标准学习排名信号而没有行为数据方面的能力可能大大改善搜索体验,特别是在冷启动场景下。

分类:从图像到文本

我们现在从图像到文本进行分类:我们使用 FashionCLIP 的图像编码器对要分类的时尚物品图像进行编码,并通过点积检索最接近的标签向量:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

对于零-shot 分类,我们计算查询项的图像嵌入和目标标签的文本嵌入。图片由作者使用 Farfetch 目录提供。

CLIP-like 模型的技巧在于将标签视为语义上有意义的标签,而不是类别变量。

换句话说,当我们“分类”时,我们在问“这些文本中哪个是这个图像的最佳标题?”。

得益于 CLIP 的预训练和自然语言的无限可能性,我们现在拥有一个不局限于任何特定标签、类别或属性的分类器:当然,首要应用可能是在 Farfetch 目录中的新产品上使用该分类器,我们还可以在具有不同标签或用途的其他数据集上重复使用相同的模型,例如:

  • 如果供应商没有将鞋子分类为“高跟鞋”与“平底鞋”,我们可以添加该属性;

  • 如果商品管理员在目录中创建新的视图——例如,将项目匹配到风格——我们可以根据新的维度(“优雅”、“街头风”等)对现有产品进行分类。

CLIP 的泛化能力当然是以某些精度为代价的:也就是说,如果我们以监督方式训练一个新的分类器来解决上述用例,它们都会比 FashionCLIP 更好。像往常一样,真实世界的机器学习没有一刀切的方案,模型之间的权衡可以根据用例的重要性、训练时间、标注成本等不同方式进行评估。

性能

我们在两个不同任务和多个数据集上将 FashionCLIP 与 CLIP 进行比较。有关设置的更多细节请参阅论文,本节的范围只是为了展示在时尚相关任务中使用 FashionCLIP 替代 CLIP 时性能的提升。

对于零样本分类,我们使用了三个不同的数据集(KAGL、DEEP 和 FMNIST),这些数据集应作为分布外数据集(我们知道从其他实验中我们在领域内数据上表现比 CLIP 好得多,但这是预期中的)。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

不同数据集上的加权宏 F1 分数(领域外数据)。FashionCLIP 在这些数据集上显示出相对于 CLIP 的显著提升。

Zero-shot 结果确认我们的模型表现如预期!

对于图像检索,我们使用了在训练时遗漏的原始数据集的一部分。需要注意的是,这显然使我们相对于 CLIP 有优势,因为这个数据集对于我们来说是领域内的。然而,这仍然是一个有趣的实验。以下结果确认我们的模型表现最佳:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

在我们内部测试集上的前 5 和前 10 精度(领域内数据)。FashionCLIP 的检索性能明显更好。

Torch 实现和 HuggingFace 权重

由于 Patrick 的工作,FashionCLIP 使用起来非常简单。你只需加载模型并使用简单的方法进行零样本分类,所有这些都可以用 Python 完成!

fclip = [...load FCLIP ...]

test_captions = [
    "nike sneakers", "adidas sneakers", "nike blue sneakers", 
    "converse", "nike", "library", "the flag of italy",
    "pizza", "a gucci dress"
]
test_img_path = 'images/16790484.jpg'

fclip.zero_shot_classification([test_img_path], test_captions)

你还可以进行图像检索!

candidates = fclip.retrieval(['shoes'])
print(candidates)

告别

漫长旅程的总结

构建 FashionCLIP 是一段长时间且有趣的冒险,涉及到来自地球上一些最酷地方的老朋友和新朋友。结果总是更美好,当你和朋友一起获得它们时。此外,我们中的一些人已经合作多年,实际上从未在现实生活中见过面!

从更务实的角度来看,我们希望 FashionCLIP 能为快速迭代内部和外部时尚用例的公司开辟前所未有的机会:例如,虽然你可能会最终构建一个专注的风格分类器,但使用 FashionCLIP 进行概念验证将大大证明该功能的价值,而无需在新的模型生命周期支持上进行前期投资

当我们考虑零售领域日益增长的智能 API SaaS 服务提供商——如 Coveo、Algolia、Bloomreach——时,垂直模型的重要性不可低估:由于 B2B 公司以账户为基础增长,稳健性和可重用性比纯粹的精准度更为重要。我们展望不久的将来,FashionCLIP —— 以及 DIYCLIP、ElectronicsCLIP 等 —— 将成为 B2B 机器学习参与者的标准组件,使得迭代迅速、数据标准化,并在完全不同于目前的水平上实现规模经济。

我去年也在 Pinecone 上做了一个关于 FashionCLIP 的演讲:

我在 Pinecone 上关于如何构建像 FashionCLIP 这样的模型的演讲。

另一个演示

开源的力量是什么?Pablo 看到这个模型并联系了我们,提供了一个用户界面来帮助我们测试标准的 HuggingFace CLIP 与我们刚刚发布的 FashionCLIP 之间的差异——然后我使用了 Objective Search 来测试使用 FashionCLIP 的几个查询(您可以在这里亲自查看):

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

使用 FashionCLIP 进行搜索。GIF 由作者提供,图片来自 H&M 数据集。

很酷,不是吗?

局限性、偏见与公平性

我们承认 FashionCLIP 存在某些限制,并预计它继承了原始 CLIP 模型中的一些局限性和偏见。我们不期望我们的微调会显著增加这些限制:我们承认,我们使用的时尚数据对性别的概念做出了明确假设,例如“女性的蓝色鞋子”,这不可避免地将服装的某些方面与特定的人联系在一起。

我们的调查还表明,所使用的数据在 FashionCLIP 中引入了某些限制。从文本模态来看,鉴于大多数来自 Farfetch 数据集的标题较长,我们观察到 FashionCLIP 在处理较长查询时可能比短查询表现更好。

从图像模态来看,FashionCLIP 对标准产品图像(居中、白色背景)也存在偏见。这意味着模型可能在不具备相同结构的图像上表现不佳。

我们做的更多事情

FashionCLIP 的发展经历了漫长的过程,但在等待正式发布期间我们做了一些事情。

GradedRecs

我们在 FashionCLIP 的基础上进行了探索,通过遍历潜在空间来研究推荐。如果你感兴趣,请查看我们的 论文

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

GradedRec。图片由作者提供。

推荐系统评估中的公平性

如果你对相关行业任务感兴趣,例如推荐系统,我们去年进行了一项关于推荐系统全面评估的挑战。

这个挑战旨在理解如何构建不仅仅关注点对点度量(例如准确率)的评估。你可以在这里找到一些细节和介绍性的博客文章

[## 关于推荐系统的全面评估

EvalRS:在多个测试中评估推荐系统

fede-bianchi.medium.com](https://fede-bianchi.medium.com/a-rounded-evaluation-of-recommender-systems-b9fa101ef79a?source=post_page-----3005ac3fdcc3--------------------------------)

教学很难:如何训练小模型并超越大型对手

原文:towardsdatascience.com/teaching-is-hard-how-to-train-small-models-and-outperforming-large-counterparts-f131f9d463e1

|模型蒸馏|人工智能|大型语言模型|

蒸馏大型模型的知识是复杂的,但一种新方法显示出惊人的性能

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Salvatore Raieli

·发表于 Towards Data Science ·阅读时间 12 分钟·2023 年 11 月 11 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由 JESHOOTS.COM 提供,来源于 Unsplash

大型语言模型(LLMs)和少样本学习已经证明我们可以将这些模型用于未见过的任务。然而,这些技能是有代价的:大量的参数。这意味着你还需要一个专业化的基础设施,并且将最先进的 LLMs 限制在只有少数几家公司和研究团队中。

  • 我们真的需要为每个任务设计一个独特的模型吗?

  • 是否有可能创建专门的模型来替代它们用于特定的应用?

  • 我们如何才能拥有一个在特定应用中与大型 LLMs 竞争的小模型?我们是否确实需要大量的数据?

在这篇文章中,我对这些问题给出了答案。

“教育是人生成功的关键,教师在学生的生活中留下了深远的影响。” ——所罗门·奥尔蒂斯

匹配冠军!

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由 Fauzan Saari 提供,来源于 Unsplash

教学的艺术是协助发现的艺术。——马克·范·多伦

大型语言模型(LLMs)展现了革命性的能力。例如,研究人员对像上下文学习这样的难以捉摸的行为感到惊讶。这导致模型规模的增加,越来越大的模型寻找新能力,这些能力超出了参数的数量。

## 关于上下文学习的一切

什么是大型语言模型,它是如何工作的,以及是什么使大型语言模型如此强大

[towardsdatascience.com ## 人工智能中的涌现能力:我们是否在追逐一个神话?

对大型语言模型出现特性的观点变化

[towardsdatascience.com

但这会有代价;例如,GPT-3(超过 175 万亿个参数)至少需要 350 GB 的 GPU 来运行。这意味着你需要专门的基础设施来训练和使用它进行推理。将这样的模型部署以使其公开访问需要克服重大挑战和成本(尤其是如果你想减少延迟)。因此,只有少数公司能够负担得起为实际应用部署一定规模的模型。

拥有超过 100 B 参数的模型具有大型建模能力,但这些能力分散在许多技能上。相比之下,少于 10 B 的模型建模能力较弱,但可以将这种能力集中于单一任务。例如,推理是超过 100 B 参数模型展示的一种能力,但在小型模型中缺失。这项研究的作者表明,推理只是大型 LLM 中的众多能力之一。因此,将小型模型的训练重点放在推理上,即使模型小于 100 B,也可以获得显著的结果。

当然,专注于小型模型会有代价:对其他任务的表现。但通常你只对一个任务感兴趣,因此可以使用小型模型。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

以前的研究表明,推理能力随着规模的增加而突然出现(左侧图)。这项研究的作者表明,通过专注于推理任务(专业化),你可以在推理方面取得良好的结果。图片来源:这里

因此,几家公司专注于仅对特定任务表现良好的小模型。此外,微调的使用使得为特定应用创建小型专业模型成为可能。对于一些任务,如分类,微调需要一个带注释的数据集。收集这些带注释的数据集是昂贵的,因此使用的另一种技术是蒸馏

蒸馏是一种技术,通过它你可以利用从更大模型生成的标签来训练一个小模型。收集这些未标记的数据集可能同样昂贵(例如,在医疗领域)。性能要求越高,成本也就越高。因此,使用微调或蒸馏来实现与大型语言模型(LLM)相同的性能可能在计算上是昂贵的。

因此,我们如何才能使小模型以数据和时间高效的方式从 LLM 中学习呢?

如何让 LLM 成为高效的教师

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

照片由 ThisisEngineering RAEng 提供,来源于 Unsplash

我不能教任何人任何东西;我只能让他们思考。——苏格拉底

当我们想训练一个小模型时,LLM 要么用来为未标记的文本生成标签,要么用于数据增强(从 LLM 生成的示例数据集中提取)。直观上,这可能不足以使模型学习高效。

例如,如果我想让我的小模型学习如何对推文进行排序(积极、消极或中立),我可以下载大量推文,通过 LLM 生成标签,然后用这些标签训练小模型。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

蒸馏的示意图。图片由作者提供

然而,虽然这对于像推文分类这样的简单任务有效,但对于更复杂的任务来说是不够的。 我们可能会从互联网上下载谜题并让 LLM 解决它们,但解决方案本身并未提供关于解决过程的任何信息。一个通过解决方案训练的小模型不会学会如何解谜。

确实,要学会解决困难的任务(例如解谜),你需要比仅仅解决方案更多的信息。

实际上,这对于 LLM 也是如此,对于推理任务(算术、常识和符号推理),提供链式思维的上下文有助于模型得出解决方案而不会产生幻觉。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片来源:这里

基于这一意图,一些谷歌研究人员甚至训练了在特定任务上超越 LLM 的小模型(770M 参数与 540BPaLM)。他们随后在最近发表的一篇论文中描述了这种方法。

## 逐步提炼!用更少的训练数据和更小的模型超越大型语言模型…

部署大型语言模型(LLMs)具有挑战性,因为它们在内存使用上效率低下,并且计算密集型……

arxiv.org

简而言之,作者利用了 LLM 进行推理的能力(超越单纯生成标签)。通过使用一个未标记的数据集,他们要求 LLM 生成正确的标签和推理(为什么这是最合适的标签的自然语言解释)。之后,他们使用标签和推理来训练小模型。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

该方法的示意图。图像由作者提供

通过这种方式,他们不仅向小模型提供了问题的解决方案,还提供了老师(LLMs)如何得出该解决方案的过程。 此外,推理不仅包含解释,还包含理解任务的有用元素(这些元素从简单的输入中不易推断出,特别是对于参数有限的模型)。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片来源:这里

逐步提炼

更详细地说,作者使用了与链式思维(CoT)相同的提示。一个提示包括一个问题、背景或推理,以及问题的答案。之后,将推理附加到问题上,模型必须给出答案。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片来源:这里

小模型通过简单的多任务方法进行训练:它不仅需要预测正确的标签,还需要生成相应的推理。损失函数也会考虑生成推理时是否出现错误。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

通过这种方式,作者迫使模型生成中间推理步骤,从而引导模型找到正确的答案。

从比喻的角度来看,这就像是一个老师强迫学生写下所有的推理步骤,而不是直接给出答案。这种方法的优点是,在测试时,模型将不再需要老师模型(LLM),而应该学会进行推理。

我们能否教会学生推理?

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由Element5 Digital提供,来源于Unsplash

你告诉我,我会忘记。你教我,我会记住。你让我参与,我会学习。 – 本杰明·富兰克林

作者使用 PaLM (540 B 参数)作为 LLM 生成理由。他们选择使用T5作为小模型,使用现有的预训练权重检查点。有趣的是,作者使用了一个已经训练过的非常小的模型。通过这种方式,他们使用一个已经具备一般语言知识的模型,但可以适应特定任务。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

模型比较,以更好地理解大小差异(圆圈按比例)。图片由作者提供,生成图片的脚本可以在这里找到

他们选择了三个特定的自然语言处理任务:

如所示,这些任务和数据集要求模型展示推理能力。

在文章中,该方法与两种经典方法进行了比较:

  • 微调 其中预训练模型在带有正确标签的注释示例上进行训练。

  • 蒸馏 在该方法中,LLM 用于生成真实标签。

结果显示,新的方法(逐步提炼)在所有基准数据集和任务中都优于标准微调,同时所需示例也远少于达到更好表现的标准。因此,这种方法性能更佳,同时成本更低(仅有 12.5%的示例表现超过传统微调)。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片来源:这里

对于标准蒸馏而言,同样的新方法在性能上更优,并且所需的示例数量也少得多。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片来源:这里

作者们随后使用不同版本的模型(220M、770M、11B)采用相同的方法,并与 LLM 基线(PaLM)进行比较。结果表明,新方法根据规模提高了性能(更大的模型表现更好)。此外,逐步蒸馏在某些任务上似乎甚至超越了 LLM 基线。换句话说,770M 模型在 ANLI 中超越了一个大 700 倍的模型。更令人印象深刻的是,在 e-SNLI 中,一个 220M 的模型超越了一个大 2000 倍的模型。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片来源: 这里

在标准微调中,我们使用人工标注的数据,而在蒸馏中,我们使用未标注的数据。结果类似,显示模型即使从 LLM 标注的数据中也能学习。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片来源: 这里

这些结果本身已经很令人印象深刻,但令人难以置信的是你不需要整个数据集。 即使仅用 0.1% 的数据集,该方法仍然有效。对于标准的微调和任务蒸馏,您需要更多的示例才能看到显著的性能。在 ANLI 中,对于 T5-770M,80% 的示例足以超越 PaLM 540B。即使使用完整的数据集,标准微调也无法达到 LLM 基线。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片来源: 这里

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片来源: 这里

正如作者所提到的,尽管这种方法也适用于其他模型(如 20B GPT-NeoX 模型),但结果不如预期。这是因为 PaLM 提供了更高质量和更详细的推理。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片来源: 这里

在一个消融研究中,他们注意到多任务训练效果更好。换句话说,让模型生成推理有助于它的学习。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片来源: 这里

作者们也发布了供社区测试的代码:

## GitHub - google-research/distilling-step-by-step

通过在 GitHub 上创建帐户,您可以为 google-research/distilling-step-by-step 的开发做出贡献。

GitHub - google-research/distilling-step-by-step

结束语

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

照片由 Saif71.com 提供,来自 Unsplash

教育是创造所有其他职业的唯一职业。 – 无名氏

本文展示了如何利用 LLM 教导较小的模型解决特定任务。 超越结果,本文还展示了即使是较小的模型,通过提供上下文也能得出解决方案。因此,这种方法使用户能够用更少的数据提炼出一个小模型,并超越大型 LLM:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

文章的示意图。图片来源:这里

作者在本文中展示了比 LLM 小 2000 倍的模型能够学习并在复杂任务(如推理任务)上超越教师模型。此外,与经典的逐步提炼方法相比,它需要的数据要少得多。

一般来说,近年来模型学习研究发生了范式转变,试图将记忆与实际学习分开。

[## 理解:学习是泛化而非记忆

理解神经网络如何学习可以帮助我们避免模型忘记所学内容。

levelup.gitconnected.com

确实,本文表明,要执行特定任务,你并不需要大容量(记忆)。你可以教导一个小模型通过提供解决问题的信息来学习任务(泛化)。

这项工作很重要,因为用少量数据,可以训练出一个更小的模型在任务上表现出色。这些模型可以以更低的成本更容易地部署。此外,这种方法适用于任何模型,因此用户可以使用开源模型(如 LLaMA)或专有模型(GPT-4 或 PaLM)的 API 进行逐步提炼,创建自己的专业模型。

这项工作开辟了许多令人兴奋的可能性,如以低成本开发适用于多个应用的专业模型,并且其性能优于巨型模型。这些模型不仅可以在线部署,还可以在桌面计算机或手机应用中使用。因此,拥有一个小而专有的数据集,你可以用有限的资源开发和部署专家模型。

例如,你可以设想一个用户开发一个专门解决谜题的小模型。你只需与 LLM 创建推理,使用逐步提炼来训练你的专家模型,然后甚至可以将其部署到手机应用上。

TL;DR

  • Google 公布了一种新的简单方法来从大型模型中提取知识。通过使用推理和答案,你可以教导一个小模型(甚至小 2000 倍)在推理任务中超越 LLM。

  • 这种方法超越了之前的最新技术。

  • 这种方法只需要一个小的训练集和较小的模型尺寸

  • 这种方法使得可以为专业任务部署独立的语言模型。现在模型尺寸与网页应用兼容,并且可以在设备上进行推理,无需复杂的基础设施。

你怎么看?在评论中告诉我

如果你觉得这很有趣:

你可以查看我的其他文章,你也可以 订阅 以便在我发布文章时收到通知,也可以在LinkedIn上联系或找到我。

这是我 GitHub 仓库的链接,我计划在这里收集与机器学习、人工智能等相关的代码和资源。

[## GitHub - SalvatoreRa/tutorial: 机器学习、人工智能、数据科学的教程…

关于机器学习、人工智能、数据科学的教程,包括数学解释和可重复使用的代码(用 Python 编写…

github.com](https://github.com/SalvatoreRa/tutorial?source=post_page-----f131f9d463e1--------------------------------)

或者你可能对我的一篇近期文章感兴趣:

## 无需重新训练即可重塑模型的记忆

擦除大型语言模型所学到的有问题内容的任何回响

towardsdatascience.com [## 超越语言:用 AI 解码脑波中的言语

AI 能够从非侵入性脑记录中解码语言

levelup.gitconnected.com](https://levelup.gitconnected.com/beyond-words-unraveling-speech-from-brain-waves-with-ai-7ff81862dfff?source=post_page-----f131f9d463e1--------------------------------)

参考文献

这是我在撰写本文时参考的主要文献列表(只引用了每篇文章的第一作者姓名)。

  1. 傅, 2023, 《将较小的语言模型专门化为多步骤推理》, 链接

  2. 辛顿, 2015, 《提炼神经网络中的知识》, 链接

  3. 霍华德, 2018, 《通用语言模型微调用于文本分类》, 链接

  4. 卡普兰, 2020, 《神经语言模型的规模定律》, 链接

  5. 韦, 2022, 《链式思维提示在大型语言模型中引发推理》, 链接

  6. Hsieh, 2023, 逐步提炼!以更少的训练数据和更小的模型尺寸超越更大的语言模型,链接

  7. Chowdhery, 2022, PaLM: 通过路径扩展语言建模,链接

  8. Raffel, 2019, 使用统一的文本到文本转换器探索迁移学习的极限,链接

教授语言模型使用工具

原文:towardsdatascience.com/teaching-language-models-to-use-tools-7fd58916c66b

使用工具让我们作为人类更具能力。LLMs 是否也是如此?

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Cameron R. Wolfe, Ph.D.

·发表于 Towards Data Science ·阅读时间 17 分钟·2023 年 8 月 27 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

(照片由 Barn Images 提供,来自 Unsplash

随着我们对大语言模型(LLMs)了解的深入,这些模型变得越来越有趣。这些模型能够准确解决各种复杂任务。然而,与此同时,它们在某些我们人类认为基本的功能上却存在困难!例如,LLMs 常常犯算术错误,缺乏获取当前信息的能力,甚至难以理解时间的进程。鉴于这些局限性,我们可能会想,如何才能使 LLMs 更具能力?LLMs 注定要永远受到这些局限的困扰吗?

人类历史上的许多进步都由获得新的创新工具(例如 印刷机计算机)所推动。相同的发现是否适用于 LLMs? 在这篇概述中,我们将研究一个最新的研究方向,旨在教会 LLMs 如何使用外部工具,这些工具通过简单的文本到文本的 API 提供。通过使用这些工具,LLMs 可以将执行算术或查找当前信息等任务委派给专门的工具。然后,这些工具返回的信息可以被 LLM 在生成输出时用作上下文,从而产生更准确和有依据的响应。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

(来自 [1] 和 ChatGPT Plus)

使 LLMs 更具能力

为 LLM 提供外部工具是一种可靠的方法,可以解决这些模型面临的一些限制。然而,LLM 不会自然地知道如何使用工具,这就提出了一个问题:我们如何教会模型利用外部工具? 在本节中,我们将探讨我们拥有的一些选项,并列举对构建 LLM 应用程序有用的各种工具。

不同类型的学习

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

LLM 的不同学习形式(作者创建)

教会 LLM 利用工具与学习如何解决其他任务没有什么不同。由于这些模型以几种不同的方式学习,我们将在这里讨论 LLM 的主要学习形式。本文之外,网上还有详细解释

预训练。 LLM 的第一个和最基本的学习形式是预训练。在预训练过程中,模型在大量未标记的文本数据上进行训练,使用语言建模目标。预训练过程从随机初始化开始,计算成本相当高。通常,由于计算成本,预训练只执行一次——我们不希望频繁重复预训练过程!值得注意的是,预训练的计算成本解释了像 ChatGPT 这样的 LLM 中存在知识截止点的原因。这些模型在预训练期间学习所有知识,因此知识截止点仅与最近预训练期间存在的数据相关。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

LLM 的微调方法(来自[11])

微调。 在预训练之后,LLM 可以准确地执行下一个标记预测,但这并不总是意味着它们实际上有用。例如,如果我们玩一下GPT-2的演示,只需 2 分钟,我们立刻会发现准确预测下一个标记可能会导致一些相当无聊和无用的输出!因此,我们通常在预训练之后对 LLM 进行微调,通常通过监督微调(SFT)或从人类反馈中进行强化学习(RLHF);详情见上面的图片和这里。虽然这些技术的细节超出了本文的范围,但基本的思路是:

  1. 筛选更多的训练数据(例如,针对我们要解决的任务的领域数据、正确对话的示例、人类对 LLM 输出的反馈等)。

  2. 使用强化学习或带有(自我)监督目标的梯度下降对模型参数进行训练。

通过这样做,我们可以完成很多事情!例如,使用 RLHF 对 LLM 进行微调 [11] 已被证明可以使 LLM 更有趣、更准确、更有帮助。更进一步,Meta 最近的 LIMA 出版物显示,通过仅 1,000 个高质量对话示例进行 SFT,可以生成一个与 GPT-4 质量相媲美的模型 [12]。简单来说,微调将我们从一个普通的 LLM 提升到真正特别且有用的水平。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

(来自 [7])

上下文学习。 我们应当了解的最终学习形式是上下文学习;见上文。上下文学习不同于预训练和微调,它并不会实际修改底层模型的参数。相反,我们通过修改提示来教 LLM 更有效地解决问题!特别是,我们可以通过使用特定的提示技术重新表述提示,甚至将数据插入提示中以进行少样本学习。微调和上下文学习之间的区别如下所示。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

(来自 [7])

上下文学习是极其强大的,因为它允许我们使用单一模型解决各种不同的任务。我们可以将有用的数据插入到 LLM 的提示中,而不是微调模型或修改其底层参数。LLM 可以从这些数据中学习,更准确地解决任务,而无需修改模型本身!此外,我们可以使用预训练模型和微调模型进行上下文学习。要了解可以与 LLM 配合使用的提示技术,请查看下面的概述:

  • 实用提示 [link]

  • 高级提示 [link]

  • 思维链提示 [link]

  • 提示集合 [link]

对 LLM 有用的工具有哪些?

尽管将 LLM(大语言模型)与外部工具连接的想法很诱人,但我们可能会想:哪些工具最有用? 为了回答这个问题,我们应当关注 LLM 的常见局限性,例如:

  • 缺乏访问最新信息 [2]

  • 有产生幻觉的倾向(即,输出不正确的信息)

  • 处理数学表达式的困难

  • 低资源语言的理解不完全

  • 无法理解时间的推移[8]

如果我们想解决这些问题,我们有几个选项。我们可以专注于通过SFT 或 RLHF对模型进行微调和完善——彻底微调模型以避免上述行为。实际上,大量资源已经投入到通过目标人类反馈来完善像GPT-4这样的模型,这也取得了相当令人印象深刻的结果。然而,我们也可以选择将重点放在让模型采取间接但通常更可靠的方法,而不是在模型内部解决这些问题。特别是,我们可以教会模型如何使用外部工具来帮助回答问题!

工具如何提供帮助? 在解决问题时,LLM 通常会通过查询一个可以提供更多上下文的外部工具来获得帮助。值得注意的有用工具包括(但不限于):

  • 能够返回当前日期的日历应用

  • 能够评估数学表达式的计算器

  • 向量数据库用于存储(可能)相关但无法直接存储在提示中的大量信息。

  • 将数据转换为不同语言的翻译模块

总的来说,工具在提供额外信息或上下文来帮助 LLM 解决问题时极为有用。超越这些简单的工具,我们甚至可以将 LLM 连接到外部代码解释器,使其能够编写和执行任意程序。结合支持代码的 LLM(例如,Codex [10]),这种方法实际上可以非常强大!更多信息请见这里

工具非常受欢迎!

尽管本概述将主要关注最近研究的工具与 LLM 集成,但通过外部工具增强模型(如 GPT-4)已成为近期关注的主题。例如,OpenAI 最近发布了一个模型插件扩展,使这些强大的 LLM 能够利用大量外部工具;见下文。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

ChatGPT Plus 插件商店中的热门应用(来自 ChatGPT Plus)

截至撰写时,GPT-4 有近 130 种不同的插件可用,这展示了将各种工具与强大的 LLM 集成的巨大兴趣。超越第三方插件,OpenAI 最近为 GPT-4 发布了代码解释器和互联网搜索工具。互联网搜索工具对于减轻 LLM 中的幻觉非常有用,因为模型提供的答案可以通过从互联网获取的相关、最新信息进行情境化。除了使 LLM 更具事实性和基础性外,代码解释器工具能够处理大量代码和数据文件并对这些数据进行准确分析,以提供有价值的见解。

TL;DR: 主要结论是,工具正在成为 LLM 的一个常见特性。除了 OpenAI 的产品外,我们甚至看到像 Bard 这样的模型正在增强类似功能,而像 LangChain 这样的开源库可以用来轻松构建多种工具类功能供现有 LLM 使用。

教授 LLM 使用工具

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

(来自[1])

在[1]中,作者探讨了一种名为 Toolformer 的方法,它i) 教授 LLM 如何利用外部工具,并且ii) 在过程中保持 LLM 的通用性质。这些工具通过一组简单的文本到文本的 API 提供给 LLM(即模型提供文本作为输入,API 返回文本输出)。有趣的是,我们在[1]中看到 LLM 可以完全端到端地学习如何利用这些工具。模型决定调用哪些 API,向这些 API 传递哪些参数,并且如何最佳地利用返回的信息,而无需任何硬编码的控制流。

“语言模型可以学习控制各种工具,并自行选择何时、如何使用哪个工具。” — 来源于[1]

为了做到这一点,我们策划了一个训练数据集,展示了这些工具的正确使用。在[1]中,这个数据集是使用自监督启发式方法自动创建的——意味着不需要人工干预——只需为每个工具提供几个使用示例。然后,我们在这些数据上微调 LLM,使其学习每个工具的正确使用方法。结果是一个高性能的 LLM,它可以将简单但困难的子任务(如语言翻译、算术运算、访问当前信息等)委托给专门的外部工具,这些工具返回相关且准确的数据供 LLM 生成输出。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

(来自[1])

使用了哪些工具? 在[1]中,Toolformer 使用了以下固定的一组工具:

  • 问答工具: 基于 Atlas [13],一种针对回答简单、基于事实的问题进行微调的 LLM。

  • 计算器: 用于数值运算的基本计算器。

  • 维基百科搜索工具: 一个搜索引擎,给定搜索词返回来自维基百科的简短文本片段。

  • 翻译器: 一个可以将任何语言的文本翻译成英文的语言翻译系统(但不能反向翻译!)。

  • 日历: 一个在查询时只返回当前日期的工具。

这些工具都通过一个简单的文本到文本结构的 API 提供;见上文。要使用这些工具,LLM 必须学习* i)* 识别需要工具的场景,ii) 指定使用哪个工具,iii) 向工具的 API 提供相关的文本输入,以及iv) 使用从 API 返回的文本来制作响应。值得注意的是,这些 API 简单的文本到文本结构允许我们轻松地将工具使用示例直接插入到文本序列中;见下文。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

对外部 API 的调用以文本格式呈现,并与现有文本序列内嵌在一起(来自[1])

相较于以前工作的改进。 让 LLM 使用外部工具并不是一个新想法。例如,许多研究者尝试通过让 LLM 访问外部计算器来提高其在算术——特别是大数计算——方面的能力(见[4]的附录 B)。然而,主要问题是:我们应该如何教 LLM 使用这样的工具? 以前的方法严重依赖于人类标注的数据集。例如,LaMDA[3]使用外部搜索工具来减少幻觉;见下文。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

(来自[3])

然而,我们在[3]中看到,教会 LaMDA 利用外部工具——在这个例子中是外部的信息检索系统——需要大量的人类标注数据。更具体地说,[3]中的作者让大量的众包工人手动编写对话,利用与 LLM 相同的搜索工具,从而提供了 LLM 应如何行为和回应的示例。相关出版物往往依赖于类似的人类中心方法[2]。创建这样的数据集困难、昂贵且耗时,这促使[1]中的作者开发了更高效的解决方案。

自动学习使用工具。 在[1]中,我们看到一个用于教 LLM 如何利用外部工具的数据集——为了简单起见,我们称之为“工具跟随数据集”——可以通过利用现有的、预训练的 LLM 的提示方法自动创建。我们从一个初始(正常)数据集开始,例如用于预训练的文本语料库。然后,我们提示一个预训练的 LLM 用外部 API 调用来增强这些数据。在这里,我们依赖于通用预训练 LLM 的上下文学习能力,来策划一组 API 调用,展示如何正确使用可用工具。下面展示了一个生成请求到问答工具 API 的示例提示。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

(来自 [1])

在我们用每个工具的示例用法扩充了数据集之后,我们需要执行过滤步骤。这一步骤是必要的,因为我们只希望在工具实际上对 LLM 有帮助时才使用外部工具!我们不应该在不需要时总是依赖外部工具——使用工具通常会有延迟(甚至是经济)成本。为了捕捉这个想法,我们可以这样做:

  1. 使用工具测量 LLM 的性能(即,交叉熵损失在 API 调用之后的标记上)。

  2. 测量 LLM 在没有工具情况下的性能。

  3. 丢弃那些使用工具未能使 LLM 的性能超越某个阈值的示例。

在这里,我们假设可以访问一个演示 LLM 应产生正确输出的数据集。通过这种方法,我们可以自动构建一个包含示例的数据集,说明何时以及如何利用工具来实际改善 LLM 的输出。在实践中,实际过程要复杂一些。具体来说,为了在没有工具的情况下测量 LLM 的性能,我们观察两个独立的情况——一个是完全不使用工具的情况,另一个是执行 API 调用但不提供响应的情况。这种方法确保了工具及其数据对 LLM 的有用性。

“如果提供此调用的输入和输出使得预测未来的标记更容易,则 API 调用对[语言模型]是有帮助的”— 来自 [1]

此外,我们没有将 API 调用插入到文本序列中,而是将其作为前缀附加,这样可以避免 LLM 损失的波动。记住,这样的 API 调用在 LLM 的原始预训练语料库中不存在,这意味着直接将 API 调用插入文本序列可能会扭曲用于过滤的结果。模型并不期望在数据中看到这样的 API 调用! 此外,在测量性能时,我们为 API 调用空间上接近的标记分配更高的权重,确保 API 调用发生在所需的地方,而不是在生成输出时的随机时刻。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

(来自 [1])

[1]中使用的工具跟随数据集的完整构建过程如上所示。与以前的工作不同,这个过程不需要人工劳动。相反,我们利用 LLM 的上下文学习能力和一些巧妙的启发式方法来自动构建数据集。尽管这个过程并不完美(即,某些无用的 API 调用可能会避免过滤),但在实践中效果相当好!

学习使用工具。 一旦我们构建了数据集,教会 LLM 如何利用外部工具是很容易的——我们只需使用标准语言建模目标对模型进行微调。在[1]中,工具跟随数据集来源于预训练语料库。因此,尽管微调后的 LLM 能够利用外部工具,但它仍然是一个通用模型。此外,由于[1]中的筛选过程会去除那些不利于性能的 API 调用,LLM 会在隐含中学习何时以及如何使用每个工具以提升其输出。这种简单的方法取得了相当酷的结果!

工具是否有影响?

在[1]中分析的模型基于GPT-J [5],这是一个拥有 60 亿参数的语言模型,并且采用了CCNet作为训练数据集。Toolformer 与多个基准模型进行了比较,包括禁用 API 调用的 Toolformer 模型、原始的 GPT-J 模型、在 CCNet 上微调的 GPT-J 版本,以及其他一些 LLM,如OPT [6]和GPT-3 [7]。与研究少样本学习的先前工作不同,这些模型使用零样本方法进行评估,这种方法只是简单地向模型描述任务而不提供任何示例,并且使用了贪婪解码策略。在 Toolformer 中,只要<API>(即 API 调用的起始标记)出现在模型的k个最可能标记之一中,就会利用工具。

Toolformer 在多个不同领域中进行了评估。在基于事实的数据集上,我们发现问答工具被大量利用,相比基准模型的准确率显著提高。同样,在数学推理数据集上,计算器工具也被发现非常有用;见下文。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

(来自[1])

在(多语言)问答基准上,模型的表现并不像预期那样令人印象深刻(即,Toolformer 在某些情况下不及 GPT-3 或 GPT-J 的表现)。然而,某些工具,如日历工具,被发现对提升 LLM 在时间推理等任务上的表现非常有用。有趣的是,作者还进行了一些分析,修改了 LLM 解码策略中 API 调用的概率。通过这项分析,我们了解到更频繁地利用外部工具并不总是好事——如果工具使用过于频繁,性能会下降;见下文。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

(来自[1])

这样的发现突显了[1]中使用的过滤策略的重要性。工具使用不仅有成本,而且可能会降低性能。LLM 必须学习在何种场景下调用工具最为重要。[1]中采取的方法明确地使 LLM 在仅在显著提升模型性能时才利用外部工具。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

(摘自[1])

保持通用。 除了上述下游评估,[1]中的作者在工具跟随数据集微调后,在预训练数据集的留出部分上评估了 Toolformer,发现模型在微调前后都达到可比的困惑度;如上所述。换句话说,Toolformer 在学习如何利用外部工具时不会丧失作为通用语言模型的任何能力,这意味着—与先前以任务特定方式接近工具跟随的工作不同[8]—该模型仍然是一个基础模型,能够解决各种不同的任务。

使用工具变得越来越简单

尽管[1]中提出的方法具有突破性并且信息量巨大,但它仍然需要一个广泛的微调过程。与大多数最近应用的 LLM 相比,这确实是一个麻烦!我们是否可以利用仅通过提示的方法来教会 LLM 使用外部工具? 最近围绕 GPT-4 的进展表明,这个问题可能通过提高 LLM 的指令跟随能力来解决。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

(作者创建)

GPT-4 插件工作流程。 例如,GPT-4 可以通过插件商店访问各种工具。然而,模型并没有明确地针对商店中的每个插件进行微调。相反,它只是使用上下文学习。特别是,OpenAI 在提升 GPT-4 的可控性方面投入了大量资金,这使得模型能够非常详细地跟随指令和提示。因此,教会 GPT-4 如何使用插件只需要:

  1. 描述插件目的的文本描述

  2. 描述插件 API 的输入/输出格式的架构

使用这些信息,模型可以自行决定何时使用插件,进行格式正确的 API 调用,并将结果信息整合到对话中。这一切都是通过文本描述完成的,没有任何明确的微调,这表明教会 LLM 利用外部工具可能会随着时间的推移变得更加容易。要更详细地了解这一过程,我们可以查看 开源插件实现OpenAI 插件开发文档

结语

类似于人类在使用工具(例如,锤子、计算机、飞机等)后变得更好,LLMs 在获得一组可以提供有用信息或执行简单任务的简单 API 时也变得更有能力。为什么我们要完全依赖 LLM 解决一切问题,而不是将困难的任务委派给更准确、更专业的工具? 我们可以使用这种方法来缓解这些模型常常遇到的问题,例如输出中的不正确信息或缺乏时间推理能力。通过 Toolformer [1],我们看到 LLM 可以通过对工具跟随示例的数据集进行微调来学习利用外部工具。但是,最近的趋势表明,仅通过上下文学习可能就能教会 LLM 使用外部工具。这个领域还有很多未被揭示的内容,观察这些主题和相关应用随时间的发展将会很有趣!

与我联系!

非常感谢你阅读这篇文章。我是 Cameron R. WolfeRebuy 的人工智能总监。我研究深度学习的实证和理论基础。如果你喜欢这个概述,请订阅我的 Deep (Learning) Focus 新闻通讯,在这里我通过从基础开始概述相关主题,帮助读者理解 AI 研究。你还可以在 XLinkedIn 上关注我,或者查看我在 Medium 上的 其他文章

参考文献

[1] Schick, Timo 等人. “Toolformer: 语言模型可以自我学习使用工具。” arXiv 预印本 arXiv:2302.04761 (2023)。

[2] Komeili, Mojtaba, Kurt Shuster 和 Jason Weston. “互联网增强的对话生成。” arXiv 预印本 arXiv:2107.07566 (2021)。

[3] Thoppilan, Romal 等人. “Lamda: 对话应用的语言模型。” arXiv 预印本 arXiv:2201.08239 (2022)。

[4] Wei, Jason 等人. “思维链提示引发大型语言模型的推理。” arXiv 预印本 arXiv:2201.11903 (2022)。

[5] Wang, Ben 和 Aran Komatsuzaki. “GPT-J-6B: 一种 60 亿参数的自回归语言模型。” (2021)。

[6] 张苏珊等,“Opt: 开放预训练变换器语言模型。” arXiv 预印本 arXiv:2205.01068 (2022)。

[7] 布朗·汤姆等,“语言模型是少样本学习者。” 神经信息处理系统进展 33 (2020): 1877–1901。

[8] 帕里西·亚伦、姚赵和诺亚·费德尔,“Talm: 工具增强语言模型。” arXiv 预印本 arXiv:2205.12255 (2022)。

[9] 丁格拉·布万等,“时间感知语言模型作为时间知识库。” 计算语言学学会会刊 10 (2022): 257–273。

[10] 陈马克等,“评估基于代码训练的大型语言模型。” arXiv 预印本 arXiv:2107.03374 (2021)。

[11] 欧阳龙等,“训练语言模型以遵循人类反馈的指令。” 神经信息处理系统进展 35 (2022): 27730–27744。

[12] 周春婷等,“Lima: 对齐的少即是多。” arXiv 预印本 arXiv:2305.11206 (2023)。

[13] 伊扎卡德·戈蒂埃等,“Atlas: 带检索增强的语言模型的少样本学习。” arXiv 预印本 arXiv 2208 (2022)。

时间差学习及探索的重要性:图解指南

原文:towardsdatascience.com/temporal-difference-learning-and-the-importance-of-exploration-an-illustrated-guide-5f9c3371413a?source=collection_archive---------2-----------------------#2023-09-23

在动态网格世界中比较无模型和有模型的强化学习方法

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Ryan Pégoud

·

关注 发表在 Towards Data Science · 15 分钟阅读 · 2023 年 9 月 23 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片来源:Saffu 供图于 Unsplash

最近,强化学习(RL)算法因解决诸如蛋白质折叠、在无人机竞速中达到超人类水平,甚至在你喜欢的聊天机器人中整合人类反馈等研究问题而受到广泛关注。

的确,RL 为各种顺序决策问题提供了有用的解决方案。时间差分学习(TD 学习)方法是 RL 算法中的一个流行子集。TD 学习方法结合蒙特卡洛动态规划方法的关键方面,以加速学习而不需要完美的环境动态模型。

在这篇文章中,我们将比较不同类型的TD 算法在自定义网格世界中的表现。实验设计将展示持续探索的重要性以及被测试算法的个体特征Q-learningDyna-QDyna-Q+。

本文的概要包括:

  • 环境描述

  • 时间差分(TD)学习

  • 无模型 TD 方法(Q-learning)和基于模型的 TD 方法(Dyna-Q 和 Dyna-Q+)

  • 参数

  • 性能比较

  • 结论

允许重现结果和图表的完整代码可以在这里找到: github.com/RPegoud/Temporal-Difference-learning

环境

我们将在此实验中使用的环境是一个具有以下特征的网格世界:

  • 网格是 12 x 8 单元格。

  • 代理从网格的左下角开始,目标是到达位于右上角的宝藏(一个终端状态,奖励为 1)。

  • 蓝色传送门是相连的,通过位于单元格**(10, 6)的传送门到达单元格(11, 0)**。代理在第一次过渡后不能再次使用该传送门。

  • 紫色传送门仅在100 个剧集后出现,但能使代理更快到达宝藏。这鼓励持续探索环境。

  • 红色传送门陷阱(终端状态,奖励为 0),并结束剧集。

  • 碰到墙壁会导致代理保持在同一状态。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

网格世界不同组件的描述(由作者制作)

本实验旨在比较 Q-learning、Dyna-Q 和 Dyna-Q+ 代理在变化环境中的行为。的确,在100 个剧集之后,最优策略必定会发生变化,成功剧集中的最优步骤数将从17减少到12

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

网格世界的表示,最优路径依赖于当前的剧集(由作者制作)

时间差分学习介绍:

时间差分学习是蒙特卡洛(MC)和动态规划(DP)方法的组合:

  • 与 MC 方法类似,TD 方法可以从经验中学习而不需要环境动态模型。

  • 与 DP 方法类似,TD 方法在每一步后更新估计基于其他学习到的估计,而不是等待结果(这称为 自举)。

TD 方法的一个特点是,它们在每个时间步都更新其价值估计,而 MC 方法则等到回合结束。

确实,这两种方法有不同的更新目标。MC 方法旨在更新回报Gt,它仅在一个回合结束时可用。而 TD 方法则针对:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

TD 方法的更新目标

其中V真实价值函数 Vπ估计

因此,TD 方法结合MC采样(通过使用真实价值的估计)和DP自举(通过基于进一步估计的估计更新 V)。

时间差分学习的最简单版本称为**TD(0)**或一步 TD,实际实现 TD(0)看起来像这样:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

TD(0)算法的伪代码,摘自《强化学习导论》[4]

当从状态S转移到新状态S’时,TD(0)算法将计算备份值并相应地更新V(S)。这个备份值称为 TD 误差,即观察到的奖励R加上新状态**γV(St+1)的折扣值与当前价值估计V(S)**之间的差异:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

TD 误差

总之,TD 方法具有若干优点:

  • 它们不需要环境动态的完美模型p

  • 它们以在线方式实现,在每个时间步后更新目标

  • 如果α(学习率步长)遵循随机逼近条件,TD(0)保证会在任何固定策略π下收敛(更多细节请参见[4]第 55 页*“追踪非平稳问题”*)

实现细节:

以下各节探讨了多个 TD 算法在网格世界中的主要特性和性能。

为了简化起见,所有模型使用了相同的参数:

  • **Epsilon (**ε) = 0.1:在ε-贪心策略中选择随机动作的概率

  • **Gamma (**γ) = 0.9:应用于未来奖励或价值估计的折扣因子

  • Aplha (α) = 0.25:限制 Q 值更新的学习率

  • Planning steps = 100:对于 Dyna-Q 和 Dyna-Q+,每次直接交互执行的规划步骤数量

  • **Kappa (κ) = 0.001:对于 Dyna-Q+,在规划步骤中应用的奖励加权

每个算法的性能首先在单次运行 400 个回合的基础上进行展示(部分:Q 学习Dyna-QDyna-Q+),然后在“总结与算法比较”部分对 100 次运行 250 回合的数据进行平均。

Q 学习

我们在这里实现的第一个算法是著名的 Q 学习(Watkins, 1989):

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

Q 学习被称为离策略算法,因为其目标是直接逼近最优值函数,而不是代理遵循的策略π的值函数。

实际上,Q 学习仍然依赖于一个策略,通常称为‘行为策略’,以选择哪些状态-动作对被访问和更新。然而,Q 学习是离策略的,因为它基于未来奖励的最佳估计来更新其 Q 值,无论所选动作是否遵循当前策略π

与之前的 TD 学习伪代码相比,有三个主要区别:

  • 我们需要初始化所有状态和动作的 Q 函数,并且 Q(terminal)应为 0

  • 动作是从基于 Q 值的策略中选择的(例如相对于 Q 值的ϵ-贪心策略)

  • 更新的目标是动作值函数 Q 而非状态值函数 V

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

Q 学习算法的伪代码,摘自《强化学习导论》[4]

现在我们有了第一个算法读取用于测试,我们可以开始训练阶段。我们的代理将使用其ε-贪心策略在网格世界中导航,相对于 Q 值。该策略以**(1 - ε)的概率选择最高 Q 值的动作,并以ε的概率选择随机动作**。每次行动后,代理将更新其 Q 值估计。

我们可以使用热图可视化每个网格世界单元的估计最大动作值 **Q(S, a)**的演变。这里代理器进行 400 个回合。由于每个回合只有一次更新,Q 值的演变较慢,大部分状态仍未映射:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

训练过程中学习到的每个状态的 Q 值的热图表示(作者提供)

完成 400 个回合后,对每个单元总访问次数的分析为我们提供了代理平均路径的合理估计。如下面右侧图所示,代理似乎已收敛到一个次优路径避免了单元(4,4),并且始终沿着下墙行进。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

(左)每个状态的最大动作值估计,(右)每个状态的访问次数(作者提供)

由于这种次优策略,代理在每回合达到最少21 步,遵循“总访问次数”图中勾画的路径。步骤数量的变化可归因于ε-贪心策略,该策略引入了 10%的随机动作概率。鉴于这一策略,沿下墙行进是一种限制随机动作带来的潜在干扰的不错策略。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

训练最后 100 回合的步数(300–400)(作者提供)

总结来说,Q 学习代理如前所述收敛于次优策略。此外,Q 函数仍有一部分环境是未被探索的,这阻止了代理在第 100 集后出现紫色传送门时找到新的最佳路径。

这些性能限制可以归因于相对较少的训练步骤(400),这限制了与环境互动的可能性以及 ε-贪婪策略引发的探索。

规划,作为基于模型的强化学习方法的一个基本组成部分,特别有助于提高样本效率动作价值的估计。Dyna-Q 和 Dyna-Q+ 是结合了规划步骤的 TD 算法的良好示例。

Dyna-Q

Dyna-Q 算法(动态 Q 学习)是基于模型的强化学习TD 学习的结合体。

基于模型的强化学习算法依赖于环境模型,将规划作为其更新价值估计的主要方式。相比之下,无模型算法依赖于直接学习。

“环境模型是代理可以用来预测环境如何对其动作做出响应的任何东西” — 强化学习:导论。

在本文的范围内,模型可以被视为对转移动态 p(s’, r|s, a) 的近似。这里,p 返回一个单一的下一个状态和奖励对,给定当前状态-动作对。

随机的环境中,我们区分分布模型和样本模型,前者返回下一个状态和动作的分布,而后者返回从估计分布中抽样得到的单一对。

模型特别有助于模拟情节,因此通过用规划步骤替代现实世界的互动来训练代理,即与模拟环境的互动。

实施 Dyna-Q 算法的代理是规划代理的一部分,这些代理结合了直接强化学习模型学习。它们使用与环境的直接互动来更新它们的价值函数(如 Q 学习所示),同时也学习环境的模型。在每次直接互动之后,它们还可以执行规划步骤,通过模拟互动来更新它们的价值函数。

一个快速的国际象棋示例

想象一下玩一局好的国际象棋。每次走一步棋后,你对手的反应让你评估你的走棋质量。这类似于收到正面或负面的奖励,这让你可以“更新”你的策略。如果你的走棋导致了失误,你可能不会再这样做,前提是棋盘的配置相同。到目前为止,这与直接强化学习是类似的。

现在,让我们加入规划。假设在你每次移动后,当对手思考时,你在脑海中回顾你的每一次移动重新评估它们的质量。你可能会发现最初忽视的弱点,或发现某些移动比你想象的更好。这些思考还可能让你更新策略。这正是规划的意义,在不与真实环境交互的情况下更新值函数,而是对环境的模型

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

计划、行动、模型学习和直接强化学习:一个规划代理的时间表(由作者制定)

因此,Dyna-Q 相比 Q 学习包含了一些额外的步骤:

在每次直接更新 Q 值后,模型会存储观察到的状态-动作对、奖励和下一个状态。这个步骤称为模型训练。

  • 在模型训练后,Dyna-Q 执行n规划步骤:

  • 从模型缓冲区中选择一个随机的状态-动作对(即这个状态-动作对是在直接交互中观察到的)

  • 模型生成模拟的奖励和下一个状态

  • 值函数通过模拟观察进行更新(s, a, r, s’

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

Dyna-Q 算法的伪代码,摘自《强化学习简介》[4]

我们现在使用n=100来复制 Dyna-Q 算法的学习过程。这意味着在每次与环境的直接交互后,我们使用模型执行 100 次规划步骤(即更新)。

下图热力图展示了 Dyna-Q 模型的快速收敛。事实上,算法只需约10 个回合即可找到最优路径。这是因为每一步会导致 Q 值的 101 次更新(而 Q 学习只更新 1 次)。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

训练期间每个状态的学习 Q 值的热力图表示(由作者制作)

规划步骤的另一个好处是更好地估计网格中的动作值。由于间接更新针对的是存储在模型中的随机过渡,距离目标较远的状态也会被更新。

相比之下,Q 学习中的动作值会从目标点缓慢传播,导致网格的映射不完整。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

(左)每个状态的最大动作值估计,(右)每个状态的访问次数(由作者制作)

使用 Dyna-Q,我们找到一个最优路径,允许在17 步内解决网格世界,如下图红条所示。尽管为了探索偶尔会有ε-贪婪行为的干扰,最佳表现仍然会定期达到。

最终,虽然 Dyna-Q 由于引入了规划,可能看起来比 Q-learning 更具说服力,但需要记住的是,规划带来了权衡,在计算成本现实世界探索之间。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

训练的最后 100 集(300–400)的步骤数(作者制作)

Dyna-Q+

到目前为止,测试的算法没有一个能找到第 100 步之后出现的最优路径(紫色传送门)。实际上,这两个算法都迅速收敛到一个在训练阶段结束前保持固定的最优解决方案。这突显了持续探索在训练过程中的必要性。

Dyna-Q+与 Dyna-Q 大致相似,但在算法上增加了一个小变化。实际上,Dyna-Q+不断跟踪自每个状态-动作对在与环境的真实交互中尝试以来所经过的时间步数。

特别地,考虑一个奖励r的转移,该转移在τ时间步中没有被尝试。Dyna-Q+会进行规划,假设该转移的奖励为r + κτ**,其中κ足够小(实验中为 0.001)。

这种奖励设计的变化鼓励智能体持续探索环境。它假设状态-动作对未被尝试的时间越长,这对的动态发生变化或模型不正确的可能性就越大。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

Dyna-Q+算法的伪代码,摘自《强化学习导论》[4]

如下热图所示,与之前的算法相比,Dyna-Q+在更新方面更加活跃。在第 100 集之前,智能体探索了整个网格,找到了蓝色传送门和第一个最优路线。

网格其余部分的动作值在减少后再缓慢增加,因为左上角的状态-动作对在一段时间内没有被探索。

当紫色传送门在第 100 集出现时,智能体找到新的捷径,整个区域的值上升。在完成 400 集之前,智能体将不断更新每个状态-动作对的动作值,同时保持对网格的偶尔探索。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

训练过程中每个状态的学习 Q 值的热图表示(作者制作)

多亏了对模型奖励的额外奖金,我们最终得到了Q 函数的完整映射(每个状态或单元都有一个动作值)。

结合持续探索,智能体能够找到出现的新最佳路线(即最优策略),同时保留以前的解决方案。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

(左)每个状态的最大动作值估计,(右)每个状态的访问次数(作者制作)

然而,Dyna-Q+ 中的探索与利用权衡确实带来了成本。当状态-动作对在足够长时间内未被访问时,探索奖励会鼓励代理重新访问这些状态,这可能会暂时降低其即时性能。这种探索行为优先更新模型以改善长期决策。

这解释了为什么 Dyna-Q+ 有些回合可以长达 70 步,而 Q 学习和 Dyna-Q 最多为 35 步和 25 步。Dyna-Q+ 中较长的回合反映了代理愿意投入额外的步数进行探索,以获取更多关于环境的信息并完善其模型,即使这会导致短期性能下降。

相比之下,Dyna-Q+ 经常实现最佳性能(如下图中的绿色条形图所示),这是以前的算法未能达到的。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

训练最后 100 回合的步数(300–400)(作者提供)

总结与算法比较

为了比较算法之间的关键差异,我们使用了两个指标(请注意,结果依赖于输入参数,为简化起见,所有模型的输入参数均相同):

  • 每回合步数:该指标描述了算法向最优解收敛的速度。它还描述了算法在收敛后的行为,特别是在探索方面。

  • 平均累计奖励:指导致正奖励的回合百分比。

分析每回合的步数(见下图)揭示了基于模型和非基于模型的方法的几个方面:

  • 基于模型的效率:在这个特定的网格世界中,基于模型的算法(Dyna-Q 和 Dyna-Q+)往往更具样本效率(这一特性在 RL 中也较为普遍)。这是因为它们可以利用环境的学习模型进行前瞻性规划,从而更快地收敛到接近最优或最优的解决方案。

  • Q 学习收敛:Q 学习虽然最终会收敛到接近最优解,但需要更多的回合(125)。需要强调的是,Q 学习每步仅执行 1 次更新,这与 Dyna-Q 和 Dyna-Q+ 执行的多次更新形成对比。

  • 多次更新:Dyna-Q 和 Dyna-Q+ 每步执行 101 次更新,这有助于它们更快地收敛。然而,这种样本效率的权衡是计算成本(见下表的运行时间部分)。

  • 复杂环境:在更复杂或随机的环境中,基于模型的方法的优势可能会减弱。模型可能引入错误或不准确,从而导致次优策略。因此,这种比较应被视为不同方法的优缺点概述,而不是直接的性能比较。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

平均每集步骤数的比较(由作者制作)

现在我们引入平均累计奖励(ACR),它表示代理达到目标的集数百分比(因为达到目标的奖励为 1,而触发陷阱的奖励为 0),因此 ACR 计算方式为:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

其中 N 是集数(250),K 是独立运行次数(100),Rn,k 是运行 k 中第 n 集的累计奖励。

以下是所有算法性能的详细分析:

  • Dyna-Q 收敛迅速,达到最高的总体回报,ACR 为 87%。这意味着它在很大一部分集数中能够高效地学习并达到目标。

  • Q-learning 也达到了类似的性能水平,但需要更多的集数才能收敛,这解释了其稍低的 ACR,为 70%。

  • Dyna-Q+ 能够迅速找到一个良好的策略,在仅经过 15 集后达到累计奖励 0.8。然而,奖励的变异性和探索性降低了其性能,直到第 100 步之后才开始改善,因为它发现了新的最优路径。然而,短期的探索会妨碍其性能,导致其 ACR 为 79%,低于 Dyna-Q,但高于 Q-learning。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

平均每集累计奖励的比较(由作者制作)

结论

在本文中,我们介绍了时序差分学习的基本原理,并将 Q-learning、Dyna-Q 和 Dyna-Q+ 应用于自定义网格世界。这个网格世界的设计有助于强调持续探索的重要性,以发现和利用在变化环境中新的最优策略。通过每集步骤数和累计奖励的表现差异,展示了这些算法的优缺点。

总结来说,基于模型的方法(Dyna-Q、Dyna-Q+)相较于基于模型的方法(Q-learning)在样本效率上有优势,但计算效率较低。然而,在随机或更复杂的环境中,模型的不准确性可能会阻碍性能并导致次优策略。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

参考文献:

[1] Demis Hassabis, AlphaFold 揭示了蛋白质宇宙的结构 (2022), DeepMind

[2] Elia Kaufmann, Leonard Bauersfeld, Antonio Loquercio, Matthias Müller, Vladlen Koltun & Davide Scaramuzza, 冠军级无人机竞速使用深度强化学习 (2023), Nature

[3] Nathan Lambert, Louis Castricato, Leandro von Werra, Alex Havrilla, 从人类反馈中阐述强化学习(RLHF), HuggingFace

[4] Sutton, R. S. 和 Barto, A. G. . 强化学习:导论 (2018), 剑桥(马萨诸塞州):麻省理工学院出版社。

[5] Christopher J. C. H. Watkins 和 Peter Dayan, Q-learning (1992), 《机器学习》,Springer Link

Python 中的时序差分:第一个基于样本的强化学习算法

原文:towardsdatascience.com/temporal-differences-with-python-first-sample-based-reinforcement-learning-algorithm-54c11745a0ee

使用 Python 编写并理解 TD(0)算法

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Eligijus Bujokas

·发表于 Towards Data Science ·13 分钟阅读·2023 年 1 月 27 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

Kurt CotoagaUnsplash上的照片

这是我之前文章的续集:

## Python 强化学习中的第一步

Python 的原始实现,展示了如何在强化学习的基本世界之一中找到最佳位置……

towardsdatascience.com

在这篇文章中,我想让读者熟悉强化学习中的基于样本的算法逻辑(RL)。为此,我们将创建一个带有洞的网格世界(类似于缩略图中的那个),并让我们的代理在创建的世界中自由遍历。

希望在代理的旅程结束时,他能学会在世界上哪个地方是好的地方,哪些位置应该避免。为了帮助我们的代理学习,我们将使用著名的**TD(0)**算法。

在深入算法之前,让我们定义一下我们想要解决的目标。

在这篇文章中,我们将创建一个 5 行 7 列的网格世界,这意味着我们的代理将能够处于 35 个状态中的一个。移动规则如下:

  • 代理不能离开网格世界的边界。

  • 在每个时间步,代理只能向上、向下、向左或向右移动。

  • 代理从我们网格世界的左上角开始。

  • 如果代理达到目标或掉入洞里,游戏结束,代理会被返回到起始状态。

  • 每次移动都会获得-1 的奖励。

  • 掉入洞里会获得-10 的奖励。

  • 达到目标会获得 10 的奖励。

我们代理的**终极目标是尽可能准确地评估它可能处于的每一个状态。换句话说,我们代理希望根据给定的移动策略评估每个状态的价值。

以下代码片段初始化了前一节中描述的环境:

import numpy as np 

def init_policy(S: np.array, weight_dict: dict = {'right': 1}) -> dict:
    # Saving all the unique states to a vector 
    states = np.unique(S)

    # Getting the number of rows and columns of the S matrix
    n_row = S.shape[0]
    n_col = S.shape[1]

    # Dictionary to hold each action for a given state
    P = {}
    for s in states: 
        s_dict = {}

        # Checking which index is the current state in the S matrix 
        s_index = np.where(S == s)

        # If the state is in the top left corner, we can only move right and down
        if s_index == (0, 0):
            s_dict['right'] = 0.5 * weight_dict['right']
            s_dict['down'] = 1 - s_dict['right']

        # If the state is in the top right corner, we can only move left and down
        elif s_index == (0, n_col - 1):
            s_dict['left'] = 0.5
            s_dict['down'] = 0.5

        # If the state is in the bottom left corner, we can only move right and up
        elif s_index == (n_row - 1, 0):
            s_dict['right'] = 0.5 * weight_dict['right']
            s_dict['up'] = 1 - s_dict['right']

        # If the state is in the bottom right corner, we can only move left and up
        elif s_index == (n_row - 1, n_col - 1):
            s_dict['left'] = 0.5
            s_dict['up'] = 0.5

        # If the state is in the first row, we can only move left, right, and down
        elif s_index[0] == 0:
            s_dict['right'] = 0.333 * weight_dict['right']
            s_dict['left'] = (1 - s_dict['right']) / 2
            s_dict['down'] =  (1 - s_dict['right']) / 2

        # If the state is in the last row, we can only move left, right, and up
        elif s_index[0] == n_row - 1:
            s_dict['right'] = 0.333 * weight_dict['right']
            s_dict['left'] =  (1 - s_dict['right']) / 2
            s_dict['up'] = (1 - s_dict['right']) / 2

        # If the state is in the first column, we can only move up, down, and right
        elif s_index[1] == 0:
            s_dict['right'] = 0.333 * weight_dict['right']
            s_dict['up'] = (1 - s_dict['right']) / 2
            s_dict['down'] = (1 - s_dict['right']) / 2

        # If the state is in the last column, we can only move up, down, and left
        elif s_index[1] == n_col - 1:
            s_dict['up'] = 0.333
            s_dict['down'] = 0.333
            s_dict['left'] = 1 - s_dict['up'] - s_dict['down']

        # If the state is in the middle, we can move in all directions
        else:
            s_dict['right'] = 0.25 * weight_dict['right']
            s_dict['up'] = (1 - s_dict['right']) / 3
            s_dict['down'] = (1 - s_dict['right']) / 3
            s_dict['left'] = (1 - s_dict['right']) / 3

        # Saving the current states trasition probabilities
        P[s] = s_dict

    return P

def generate_holes(nrow: int, ncol: int, start_coords: list, hole_coords: list, nholes: int = 1) -> list:
    """
    Function that generates nholes in a gridworld 

    The holes cannot be: 
        - in the start state
        - in the goal state
    """
    # Generating the hole coordinates 
    # The hole cannot be in the start or goal state
    hole_coords = []
    for _ in range(nholes):

        hole_row = np.random.randint(0, nrow - 1)
        hole_col = np.random.randint(0, ncol - 1)

        while (hole_row, hole_col) in start_coords or (hole_row, hole_col) in hole_coords:
            hole_row = np.random.randint(0, nrow - 1)
            hole_col = np.random.randint(0, ncol - 1)

        # Appending to the hole coordinates list
        hole_coords.append((hole_row, hole_col))

    return hole_coords

def init_env(
        n_rows: int, 
        n_cols: int,
        step_reward: float = -1, 
        goal_reward: float = 10,
        hole_reward: float = -10,
        n_holes: int = 1,
        random_seed: int = 42, 
        policy_weights: dict = {'right': 1}
        ) -> np.array: 
    """
    Functionat that returns the initial environment: 
        S - the state matrix indexed by [row, col]
        V - the initial value matrix indexed by [row, col]
        R - the reward matrix indexed by [row, col]
        A - the action matrix indexed by [row, col]
        P - the probability dictionary where for each state, the keys are the actions and the values are the probabilities of the next state
    """
    # Setting the random seed
    np.random.seed(random_seed)

    # Initiating the S matrix 
    S = np.arange(0, n_rows * n_cols).reshape(n_rows, n_cols)

    # Creating the initial V matrix
    V = np.zeros((n_rows, n_cols))

    # The start state will be always the top left corner 
    # The goal state will be always the bottom right corner
    # We will generate a random holes that our agent can fall in
    # Any other state that is not the hole or the goal state will receive a step reward 
    goal_coord = (n_rows - 1, n_cols - 1)
    R = np.zeros((n_rows, n_cols))
    R.fill(step_reward)
    R[0, 0] = step_reward
    R[goal_coord] = goal_reward

    # Generating the hole coordinates 
    # The hole cannot be in the start or goal state
    hole_coords = generate_holes(n_rows, n_cols, [(0, 0)], [goal_coord], n_holes)

    # Setting the hole reward
    for hole_coord in hole_coords:
        R[hole_coord] = hole_reward

    # Initiating the policy 
    P = init_policy(S, weight_dict=policy_weights)

    return S, V, R, P, hole_coords, [goal_coord]

我们需要开始学习的对象是:

  • 状态矩阵 S

  • 值矩阵 V

  • 奖励矩阵 R

  • 策略字典 P

默认情况下,上述代码片段初始化了一个随机策略的世界。

随机策略意味着我们的代理通过均匀概率分布选择从一个状态转移到另一个状态。

让我们创建我们的世界,更详细地探索这些矩阵:

S, V, R, P, hole_coords, goal_coard = init_env(5, 7, n_holes=4, random_seed=3)

以下代码片段用于绘制矩阵:

def array_index_to_matplot_coords(i: int, j: int, n_cols: int) -> Tuple[int, int]:
    """
    Converts an array index to a matplot coordinate
    """
    x = j
    y = n_cols - i - 1
    return x, y

def plot_matrix(
    M: np.array, 
    goal_coords: list = [],
    hole_coords: list = [],
    img_width: int = 5, 
    img_height: int = 5, 
    title: str = None,
    ) -> None: 
    """
    Plots a matrix as an image.
    """
    height, width = M.shape

    fig = plt.figure(figsize=(img_width, img_width))
    ax = fig.add_subplot(111, aspect='equal')

    for x in range(height):
        for y in range(width):
            # By default, the (0, 0) coordinate in matplotlib is the bottom left corner,
            # so we need to invert the y coordinate to plot the matrix correctly
            matplot_x, matplot_y = array_index_to_matplot_coords(x, y, height)

            # If there is a tuple of (x, y) in the goal_coords list, we color the cell gray 
            if (x, y) in goal_coords:
                ax.add_patch(matplotlib.patches.Rectangle((matplot_x - 0.5, matplot_y - 0.5), 1, 1, facecolor='gray'))
            # If there is a tuple of (x, y) in the hole_coords list, we color the cell salmon
            elif (x, y) in hole_coords:
                ax.add_patch(matplotlib.patches.Rectangle((matplot_x - 0.5, matplot_y - 0.5), 1, 1, facecolor='salmon'))

            ax.annotate(str(M[x][y]), xy=(matplot_x, matplot_y), ha='center', va='center')

    offset = .5    
    ax.set_xlim(-offset, width - offset)
    ax.set_ylim(-offset, height - offset)

    ax.hlines(y=np.arange(height+1)- offset, xmin=-offset, xmax=width-offset)
    ax.vlines(x=np.arange(width+1) - offset, ymin=-offset, ymax=height-offset)

    plt.title(title)
    plt.show()

def plot_policy_matrix(P: dict, S:np.array, terminal_coords: list = [], img_width: int = 5, img_height: int = 5, title: str = None) -> None: 
    """ 
    Plots the policy matrix out of the dictionary provided; The dictionary values are used to draw the arrows 
    """
    height, width = S.shape

    fig = plt.figure(figsize=(img_width, img_width))
    ax = fig.add_subplot(111, aspect='equal')
    for x in range(height):
        for y in range(width):
            matplot_x, matplot_y = array_index_to_matplot_coords(x, y, height)

            # If there is a tuple of (x, y) in the goal_coords list, we color the cell gray 
            if (x, y) in terminal_coords:
                ax.add_patch(matplotlib.patches.Rectangle((matplot_x - 0.5, matplot_y - 0.5), 1, 1, facecolor='gray'))

            else:
                try:
                    # Adding the arrows to the plot
                    if 'up' in P[S[x, y]]:
                        plt.arrow(matplot_x, matplot_y, 0, 0.3, head_width = 0.05, head_length = 0.05)
                    if 'down' in P[S[x, y]]:
                        plt.arrow(matplot_x, matplot_y, 0, -0.3, head_width = 0.05, head_length = 0.05)
                    if 'left' in P[S[x, y]]:
                        plt.arrow(matplot_x, matplot_y, -0.3, 0, head_width = 0.05, head_length = 0.05)
                    if 'right' in P[S[x, y]]:
                        plt.arrow(matplot_x, matplot_y, 0.3, 0, head_width = 0.05, head_length = 0.05)
                except Exception as e:
                    print(f"Error: {e}")
                    print(f"Current x and y: {x}, {y}")

    offset = .5    
    ax.set_xlim(-offset, width - offset)
    ax.set_ylim(-offset, height - offset)

    ax.hlines(y=np.arange(height+1)- offset, xmin=-offset, xmax=width-offset)
    ax.vlines(x=np.arange(width+1) - offset, ymin=-offset, ymax=height-offset)

    plt.title(title)

首先让我们可视化状态矩阵:

plot_matrix(S, goal_coords=goal_coard, hole_coords=hole_coords, title='State Matrix')

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

状态矩阵;作者拍摄的照片

红色状态表示洞的坐标——这些是我们的代理想要避免的状态。

灰色状态表示目标——这是我们的代理想要到达的地方。

我们的代理总是从状态 0 开始它的旅程。

奖励矩阵如下:

plot_matrix(R, goal_coords=goal_coard, hole_coords=hole_coords, title='Reward Matrix')

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

奖励矩阵;作者拍摄的照片

转移到某个状态的奖励矩阵在上面可视化。

例如:

  • 从状态 1 到 8 会获得-1 的奖励

  • 从状态 9 到 10 会获得-10 的奖励

  • 从状态 33 到 34 会获得 10 的奖励

依此类推。

我们的代理将遵循的策略是随机策略——进入每个状态的概率均等:

plot_policy_matrix(P, S, terminal_coords=hole_coords + goal_coard, title='Policy Matrix')

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

策略矩阵;作者拍摄的照片

策略矩阵中的灰色状态表示终端状态:如果代理选择进入该状态,剧集将结束,代理将被重置到状态 0。

TD(0)算法的目标是评估给定策略下每个状态的价值。

换句话说,我们想要填充值矩阵的值:

plot_matrix(V, goal_coords=goal_coard, hole_coords=hole_coords, title='Value Matrix')

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

初始值矩阵;作者拍摄的照片

TD(0)算法是单步时序差分算法的简称。为了开始建立直觉并广泛地说,在此算法中,我们的代理按照给定的策略执行一步,观察奖励,并在这种步骤后更新状态价值的估计。

从数学上讲,更新步骤如下:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

TD(0) 更新方程

这里:

  • s prime — 我们的代理从当前状态 s 转移到的状态。

  • 奖励 r 等于转移到 s prime 的奖励。

  • Gamma 是折扣率(大于 0,小于或等于 1)。

  • Alpha 是大小(大于 0,小于或等于 1)。

完整算法¹如下:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

完整 TD(0);作者照片

TD(0)算法是一种预测算法。在强化学习中,预测算法指的是一种尝试估计状态值的算法,同时改变给定的策略(转移概率)。

这也是一种自助算法,因为我们使用当前的价值函数估计来估计下一个状态的价值函数。

因此,我们只关心状态值——智能体从当前状态移动的总期望累计奖励:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

状态价值

现在让我们开始实现算法。

我们的智能体首先需要根据我们创建的策略进行移动:

def select_move(s, S, P) -> int:
    """
    Given the current state, returns the coordinates of the next state based on the current policy 
    """
    # Getting the current state index 
    s_index = np.where(S == s)

    # Getting the current state policy
    s_policy = P[s]

    # Selecting the next action based on the current policy
    next_action = np.random.choice(list(s_policy.keys()), p=list(s_policy.values()))

    # Getting the next state coordinates based on the next action
    try:
        if next_action == 'up':
            next_state = S[s_index[0] - 1, s_index[1]][0]
        elif next_action == 'down':
            next_state = S[s_index[0] + 1, s_index[1]][0]
        elif next_action == 'left':
            next_state = S[s_index[0], s_index[1] - 1][0]
        elif next_action == 'right':
            next_state = S[s_index[0], s_index[1] + 1][0]
    except Exception as e: 
        print(f"Current state: {s}")
        print(f'Next action: {next_action}')
        print(f'Error: {e}')

    return next_state

当智能体处于状态 s 时,它只能前往策略矩阵字典中存在的状态。例如,状态 1 中的所有动作是:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

状态 1 的所有可能动作

所有概率的总和等于 1,我们的智能体随机选择右、左或下(请参阅状态矩阵图以查看状态位置)。

上述动作是开始更新价值函数所需的全部。当智能体进行移动时,它转移到另一个状态并收集该状态的奖励。然后我们应用方程:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

TD(0)更新方程

def get_state_coords(s, S) -> tuple:
    """
    Returns the state coordinates given the state index
    """
    s_index = np.where(S == s)
    return s_index[0][0], s_index[1][0]

def update_value(s, s_prime, S, P, V, R, alpha: float = 0.1, gamma: float = 0.9) -> float: 
    """
    Updates the current value function based on the current policy
    """
    # Getting the CURRENT state's nrow and ncol index
    s_index_now = get_state_coords(s, S)

    # Getting the SELECTED state's nrow and ncol index
    s_index_prime = get_state_coords(s_prime, S)

    # Getting the reward by moving to the selected state 
    move_reward = R[s_index_prime[0], s_index_prime[1]]

    # Getting the current estimated value of the selected state 
    current_value = V[s_index_now[0], s_index_now[1]]

    # The next value 
    prime_value = V[s_index_prime[0], s_index_prime[1]]

    # Returning the TD(0) current state value 
    return current_value + alpha * (move_reward + gamma * prime_value - current_value)

最后一步是将所有内容封装到一个while 循环中,只有当我们的智能体转移到终止状态时才停止探索:

def episode_exploration(S, P, V, R, terminal_state_coords: list, alpha: float = 0.1, gamma: float = 0.9) -> None: 
    """
    Agent exploration and value updating using TD(0) equation until a terminal state is reached
    """
    # The starting state is 0 
    s = 0 

    # Keeping track of the number of moves
    n_moves = 0

    # Getting the coordinates of the s 
    s_coords = get_state_coords(s, S)

    while s_coords not in terminal_state_coords:
        # Selecting the next state based on the current policy
        s_prime = select_move(s, S, P)

        # Updating the current state value 
        V[s_coords] = update_value(s, s_prime, S, P, V, R, alpha, gamma)

        # Updating the current state 
        s = s_prime

        # Incrementing the number of moves
        n_moves += 1

        # Getting teh new s coords
        s_coords = get_state_coords(s, S)

    return n_moves

我们现在拥有了实施完整 TD(0)算法所需的一切。

让我们定义 10000 次实验,让我们的智能体进行学习吧!

# Defining the number of episodes to explore 
n_episodes = 10000

# We will plot the V matrix after each episode filling the same device plot to make an animation
number_of_walks = []
for _ in tqdm(range(n_episodes)):
    n = episode_exploration(S, P, V, R, terminal_state_coords=hole_coords + goal_coard, alpha=0.1, gamma=0.9)
    number_of_walks.append(n)

我们的智能体在终止之前所采取的动作数量:

# Ploting the distribution of the number of moves 
plt.figure(figsize=(10, 5))
sns.kdeplot(number_of_walks, fill=True)
plt.title(f'Number of moves distribution | Mean: {np.mean(number_of_walks):.2f} | Std: {np.std(number_of_walks):.2f}')
plt.xlabel('Number of moves')
plt.ylabel('Frequency')
plt.show()

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

移动次数;作者绘图

平均而言,我们的智能体在碰到终止状态之前进行了 10 次移动。

最终评估的状态价值矩阵:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

使用 TD(0)和随机策略评估的 V;作者绘图

正如我们所见,按照给定的策略,智能体开始旅程的状态非常糟糕。平均而言,从该状态开始,智能体仅获得-9.96 的奖励。然而,随着我们接近目标状态,价值会增加。

注意,目标状态和洞穴状态的值为 0,因为这些状态没有探索——每次智能体转移到这些状态,游戏就结束了。

如果我们选择了另一种策略会发生什么?例如,更频繁地选择“向右”方向:

# Assiging a different policy
S, V, R, P, hole_coords, goal_coard = init_env(5, 7, n_holes=4, random_seed=3, policy_weights={'right': 1.5}) 

# Defining the number of episodes to explore 
n_episodes = 10000

# We will plot the V matrix after each episode filling the same device plot to make an animation
number_of_walks = []
for _ in tqdm(range(n_episodes)):
    n = episode_exploration(S, P, V, R, terminal_state_coords=hole_coords + goal_coard, alpha=0.1, gamma=0.9)
    number_of_walks.append(n)

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

不同策略下的移动次数

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

不同策略的价值矩阵

随机策略的状态价值矩阵总和为**-249.29**,而更高概率向右的策略总和为**-213.51**。

从这个意义上说,我们可以说更频繁地向右移动是一种更好的策略!

在这篇文章中,我介绍了 RL 中的第一个基于样本的算法——一步时序差分算法或 TD(0)。

这是一种预测算法,即仅用于评估给定策略的状态。改变策略会得到不同的状态价值结果。

祝大家学习愉快,编程快乐!

[1]

  • 作者: 理查德·S·萨顿,安德鲁·G·巴托

  • 年份: 2018

  • 页码: 120

  • 书名: 强化学习:一种介绍

  • URL:http://archive.ics.uci.edu/ml

时间图基准

原文:towardsdatascience.com/temporal-graph-benchmark-bb5cc26fcf11?source=collection_archive---------2-----------------------#2023-12-09

挑战性和现实的时间图学习数据集

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Shenyang(Andy) Huang

·

关注 发表在 Towards Data Science · 10 分钟阅读 · 2023 年 12 月 9 日

近年来,静态图上的机器学习取得了显著进展,这得益于公共数据集和标准化评估协议的普及,例如广泛采用的开放图基准 (OGB)。然而,许多现实世界系统,如社交网络、交通网络和金融交易网络,随着时间的推移不断演变,节点和边不断添加或删除。这些系统通常被建模为时间图。到目前为止,时间图学习的进展受到缺乏大型高质量数据集以及缺乏适当评估的制约,导致了过于乐观的性能表现。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

现实世界的网络随着时间的推移而演变。图片来源:Armand Khoury on Unsplash

为了解决这一问题,我们推出了时序图基准测试(TGB),这是一个针对时序图的挑战性和多样化基准数据集的集合,用于现实的、可重复的、稳健的机器学习评估。受到 OGB 成功的启发,TGB 自动化了数据集下载和处理以及评估协议,并允许用户通过排行榜比较模型性能。我们希望 TGB 能够成为时序图社区的标准化基准,促进新方法的发展,并提高对大型时序网络的理解。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

针对时序图学习的挑战性和现实性基准

这篇文章基于我们的论文 时序图基准测试在时序图上的机器学习 (NeurIPS 2023 数据集和基准测试专题),由 Emanuele Rossi共同撰写。请在 我的网站查找更多时序图相关工作。想了解更多关于时序图的内容?加入 时序图阅读小组 NeurIPS 2023 时序图学习研讨会 ,了解最前沿的 TG 研究。

目录:

  1. 动机

  2. 问题设定

  3. 数据集详情

  4. 动态链接属性预测

  5. 动态节点属性预测

  6. 开始使用 TGB

  7. 结论与未来工作

动机

近年来,静态图的机器学习领域得到了显著提升,这主要归功于公开数据集的出现和已建立的基准测试,例如开放图基准(OGB)、长程图基准TDC 基准。然而,许多现实世界的系统,如社交网络、交通网络和金融交易网络,都是时间性的:它们随着时间的发展而演变。直到现在,时间图的发展由于缺乏大型、高质量的数据集和全面的评估框架而受到显著阻碍。这种稀缺性,加上评估限制,导致了在流行数据集(如 Wikipedia 和 Reddit)上的几乎完美的 AP 或 AUROC 分数,导致了对模型性能的过于乐观的评估,并且在区分竞争模型方面面临挑战。

数据集的缺乏。 常见的 TG 数据集仅包含几百万条边,远远小于实际时间网络中的规模。此外,这些数据集大多限制在社交和互动网络领域。由于网络属性在不同领域间通常变化显著,因此在多个领域上进行基准测试也很重要。最后,缺乏节点级任务的数据集,导致大多数方法仅关注链接预测。为了解决这个挑战,TGB 包含了来自个不同领域的nine个数据集,这些数据集在节点、边和时间戳的数量上都是数量级更大的。此外,TGB 还提出了四个数据集用于新的节点亲和预测任务。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

TGB 数据集显著大于常见的 TG 数据集

简化的评估。 动态链接预测通常被框架化为二分类任务:正(真实)边标记为 1,负(不存在)边标记为 0。在评估时,通过保持源节点固定并随机选择目标节点来采样每个正边的一个负边。这种评估仅考虑少量容易预测的负边,导致模型性能被夸大,许多模型在 Wikipedia 和 Reddit 上获得了>95%的 AP(Poursafaei et al. 2022Rossi et al. 2020Wang et al. 2021Souza et al. 2022)。在 TGB 中,我们将链接预测任务视为排序问题,并使评估更加稳健。我们展示了改进的评估结果能提供更现实的性能表现,并突出了不同模型之间的明显差距。

问题设定

在 TGB 中,我们专注于连续时间的时间图,如 Kazemi et al. 2020 定义的那样。在这种设置中,我们将时间图表示为带时间戳的边流,由*(源节点, 目标节点, 时间戳)*三元组组成。请注意,时间边可以是加权的、有向的,同时节点和边可以选择性地具有特征。

此外,我们还考虑了流式设置,在这种设置中,模型可以在推理时纳入新信息。特别地,在时间t预测测试边时,模型可以访问[1]所有发生在t之前的边,包括测试边。然而,不允许使用测试信息进行反向传播和权重更新。

数据集详情

TGB 包含 个数据集,其中七个是为此工作专门整理的,两个来自以前的文献。这些数据集在时间上分为训练集、验证集和测试集,比例为 70/15/15。数据集根据边的数量分类:小型(<5 百万)、中型(5–25 百万)和大型(> 25 百万)。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

TGB 数据集的统计信息

TGB 数据集还具有不同的领域和时间粒度(从 UNIX 时间戳到年度)。最后,数据集的统计信息也非常多样化。例如,惊讶指数,由训练集中从未观察到的测试边的比例定义,在不同的数据集中差异显著。许多 TGB 数据集中还包含许多测试集中出现的新节点,这需要归纳推理。

TGB 数据集也与现实世界任务相关。例如,tgbl-flight 数据集是一个从 2019 年到 2022 年的众包国际航班网络,其中机场建模为节点,而边则是给定日期的机场之间的航班。任务是预测未来某个日期两特定机场之间是否会发生航班。这对于预测潜在的航班中断(如取消和延误)非常有用。例如,在 COVID-19 大流行期间,为了遏制 COVID-19 的传播,许多航班路线被取消。预测全球航班网络对研究和预测疾病(如 COVID-19)向新地区传播也很重要,如 Ding et al. 2021 中所见。详细的数据集和任务描述在论文第四部分中提供。

动态链接属性预测

动态链接属性预测的目标是预测在未来时间戳下,节点对之间链接的属性(通常是存在性)。

负边采样。 在实际应用中,真实的边在事先并不为人知。因此,查询大量节点对,仅将得分最高的节点对视为边。受到这一点的启发,我们将链接预测任务框架化为排名问题,并对每个正边采样多个负边。具体而言,对于给定的正边*(s,d,t),我们固定源节点s和时间戳t*,并采样q个不同的目标节点d。对于每个数据集,q的选择基于评估完整性和测试集推断时间之间的权衡。在q个负样本中,一半是均匀随机采样的,另一半是历史负边(在训练集中观察到但在时间t时不存在的边)。

性能指标。 我们使用过滤后的平均倒数排名(MRR)作为本任务的指标,因为它专为排名问题设计。MRR 计算真实目标节点在负样本或伪目标中的倒数排名,通常用于推荐系统和知识图谱文献中。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

tgbl-wiki 和 tgbl-review 数据集上的 MRR 表现

小数据集的结果。 在小数据集tgbl-wikitgbl-review上,我们观察到最佳表现的模型有很大差异。此外,在tgbl-wiki上的顶级模型,如 CAWN 和 NAT,在tgbl-review上的性能显著下降。一个可能的解释是,与tgbl-wiki数据集相比,tgbl-review数据集具有更高的惊讶指数。高惊讶指数表明,测试集边的高比例从未在训练集中观察到,因此tgbl-review需要更多的归纳推理。在tgbl-review中,GraphMixer 和 TGAT 是表现最佳的模型。由于其较小的规模,我们能够为tgbl-wiki采样所有可能的负样本,为tgbl-review每个正边采样一百个负样本。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

tgbl-coin、tgbl-comment 和 tgbl-flight 数据集上的 MRR 表现。

大多数方法在这些数据集上运行时耗尽了 GPU 内存,因此我们对 TGN、DyRep 和 Edgebank 进行了比较,因为它们的 GPU 内存需求较低。注意,某些数据集如tgbl-commenttgbl-flight跨越多年,因此可能导致其长期跨度上的分布变化。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

负样本数量对 tgbl-wiki 的影响

洞察。tgbl-wiki中所示,用于评估的负样本数量可以显著影响模型性能:我们看到,当负样本数量从 20 增加到所有可能的目标时,大多数方法的性能显著下降。这验证了确实需要更多的负样本来进行稳健的评估。有趣的是,像 CAWN 和 Edgebank 这样的算法性能下降相对较小,我们将其作为未来的工作来调查为何某些方法受影响较小。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

TG 模型的总训练和验证时间

接下来,我们观察到 TG 方法的训练和验证时间差异高达两个数量级,其中启发式基线 Edgebank 始终是最快的(因为它简单地实现为哈希表)。这表明,提高模型效率和可扩展性是未来的重要方向,以便可以在 TGB 中提供的大型数据集上测试新的和现有的模型。

动态节点属性预测

动态节点属性预测的目标是在任何给定的时间戳t预测节点的属性。由于缺乏具有动态节点标签的大型公共 TG 数据集,我们引入了节点亲和性预测任务来研究时间图上的节点级任务。如果您希望贡献具有节点标签的新数据集,请与我们联系。

节点亲和性预测。 该任务考虑节点子集(例如用户)对其他节点(例如项目)的亲和性及其随时间自然变化的方式。这个任务在推荐系统中很相关,在那里,通过建模用户对不同项目的偏好随时间的变化来为用户提供个性化推荐非常重要。在这里,我们使用前 10 项的归一化折扣累积增益(NDCG@10)来比较预测项目与真实值的相对顺序。标签是通过统计用户在未来一段时间内与不同项目的互动频率生成的。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

节点亲和性预测任务的实证结果。

结果。 对于这个任务,我们将 TG 模型与两种简单的启发式方法进行比较:持久性预测,即预测当前时间点的最近观察到的节点标签,以及移动平均,即过去几步中的节点标签的平均值。这里的关键观察是,在这个任务中,像持久性预测和移动平均这样的简单启发式方法是 TG 方法的有力竞争者,并且在大多数情况下,它们的表现超过了 TG 方法。这突显了未来需要开发更多针对节点级任务的 TG 方法。

开始使用 TGB

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

TGB 的机器学习管道

如何使用 TGB?上面展示了 TGB 的 ML 流程。你可以自动下载数据集,并将其处理为 numpyPyTorchPyG 兼容的数据格式。用户只需设计自己的 TG 模型,这些模型可以通过 TGB 评估器 进行 标准化评估*。* 最后,公开的 TGB 排行榜帮助研究人员跟踪时间图领域的最新进展。你可以轻松安装该软件包:

pip install py-tgb

最后,你可以将你的模型性能提交到 TGB 排行榜。我们要求你提供代码链接和描述你方法的论文以确保可重复性。要提交,请填写 google 表单

结论与未来工作

为了实现对时间图进行现实、可重复和鲁棒的评估,我们推出了时间图基准(Temporal Graph Benchmark),这是一个包含挑战性和多样化数据集的集合。通过 TGB 数据集和评估,我们发现模型性能在不同数据集上差异显著,这显示了在多样的时间图领域进行评估的必要性。此外,在节点亲和度预测任务中,简单的启发式方法优于 TG 方法,从而激发了未来开发更多节点级 TG 模型的动机。

集成到 PyG 中。 Matthias Fey(Kumo.AI),PyG 的核心负责人,在 斯坦福图学习研讨会 2023 上宣布,TGB 将集成到 PyG 的未来版本中。敬请关注!

TGX 库。 我们目前正在开发一个用于时间图的实用工具和可视化 Python 库,名为 TGX。TGX 支持来自 TGB 的 20 个内置时间图数据集以及 Poursafaei et al. 2022

社区反馈与数据集贡献。 TGB 是一个社区驱动的项目,我们感谢所有通过电子邮件或 Github 问题提供建议的社区成员。如果你有任何建议或希望向 TGB 贡献新的数据集,请通过 email在 Github 上创建问题 与我们联系。我们正在寻找大规模数据集,特别是用于动态节点或图分类任务的数据集。

2023 年的时间图学习

原文:towardsdatascience.com/temporal-graph-learning-in-2023-d28d1640dbf2?source=collection_archive---------1-----------------------#2023-01-16

目前为止的故事

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Shenyang(Andy) Huang

·

关注 发表在 Towards Data Science ·15 分钟阅读·2023 年 1 月 16 日

现实世界的网络,如社交网络、交通网络和引用网络,往往会随着时间演变,而**时间图学习(TGL)**领域旨在从这些不断演变的网络中提取、学习和预测。最近,TGL 在机器学习社区中受到越来越多的关注,相关论文数量激增,去年在 NeurIPS 2022 上举办了该领域的首个研讨会

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

时间图中的演变。图片由作者提供。

这篇文章由 Emanuele Rossi, Michael Galkin Kellin Pelrine 共同撰写。 感谢 Farimah Poursafaei 提供的有益反馈。

在这篇博客文章中,我们展示了 TGL 在 2022 年之前的主要进展,并讨论了有前景的未来方向。请注意,我们将“动态图”和“时序图”交替使用。如果你想学习或开始一个 TGL 项目,这篇文章将是一个很好的参考和起点。

请在评论区与我们分享您感兴趣的其他进展。

目录:

  1. 时序图学习简介

  2. 时序图网络的表达能力

  3. 重新思考时序图中的评估

  4. 时序知识图谱

  5. 库和数据集

  6. 利用时序图进行疾病建模

  7. 时序图中的异常检测

  8. 检测时序图中的虚假信息

  9. 加入时序图学习社区

时序图学习简介

在本节中,我们简要介绍了文献中一些著名的 TGL 方法。学习连续时间动态图(CTDGs)的方法主要分为两类:时序图网络和游走聚合方法。有关 CTDGs 的详细信息,请参阅 Kazemi 等 的这篇综述。

时序图网络 (TGNs) 将信息传递神经网络 (MPNNs) 推广到时序图。它们通过引入一个节点记忆来实现,该记忆表示节点在给定时间的状态,作为节点过去交互的压缩表示。每当两个节点参与交互时,它们会相互发送消息,这些消息然后用于更新它们的记忆。在计算节点嵌入时,会对节点的时序邻居进行额外的图聚合,使用该时刻的原始节点特征和记忆。以下是 TGN 计算的示意图。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

对一批训练边缘的 TGN 计算。

图片来源:Rossi 等

TGN 是一个通用框架,它将以前的模型,如联合动态用户-项目嵌入 (JODIE) 和时序图注意力 (TGAT),作为特例进行推广。有关 TGN 的更全面介绍,请参阅下面其中一位作者的博客文章。

## 时序图网络

一种用于动态图的新型神经网络架构。

[towardsdatascience.com

诸如 Causal Anonymous Walks (CAW) 这样的 Walk 聚合方法则依赖于(时间)随机游走。特别是,为了预测时间 t 上的一个链接 (u, v) 的存在,CAW 首先提取多个从 uv 开始的随机游走,使得游走中的边的时间戳只能单调递减。这些游走首先通过用节点在游走中每个可能位置出现的次数向量替换每个节点标识符来进行匿名化。然后,使用 RNN 对每个游走进行编码,并通过自注意力或简单平均来聚合编码。

时间图网络的表达力。

关于在静态图上运行的图神经网络(GNNs)表达能力的研究已有大量工作。Xu et al. 2019 首次通过将图神经网络(GNNs)与 Weisfeiler-Lehman (WL) 图同构测试关联起来,并展示了许多 GNNs 的能力不超过 1-WL 测试,从而描述了其区分能力。随后,出现了更具表达能力的模型,如 子图 GNNs,图变换器高阶 GNNs,这些模型被设计得比 1-WL 测试更具表达力(下面是 Michael Bronstein 关于如何超越 WL 测试的精彩博客文章的链接)。

Graph Neural Networks beyond Weisfeiler-Lehman and vanilla Message Passing

受物理启发的图上连续学习模型可以克服传统 GNNs 的局限性。

towardsdatascience.com

直到今年,关于 TGL 方法的表达力的研究仍然很少。第一个弥合这一差距的努力是由 Ribeiro et al. 提出的,其关键思想是将现有的 TGL 方法分为 时间-和-图时间-然后-图 框架。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

将 TG 转换为时间-然后-图表示。

图片来源: Ribeiro et al.

1️)。在 时间-和-图 中,GNNs 用于在每个时间快照图上生成节点嵌入,从而形成节点嵌入的序列。

2️)。在 时间-然后-图 中,TG 中的每条边被转换为一个时间序列,该序列指示边存在的时间,从而将时间边折叠为静态图中的边特征。

已证明 时间-然后-图 表示可以从任何给定的 时间-和-图 表示中构建,从而证明 时间-然后-图 至少与 时间-和-图 一样具备表达力。通过在 时间-然后-图 中的静态表示,我们可以直接将静态图的 WL 测试表达框架应用于 TGL 方法。这样,只要使用 1-WL GNN 作为主干模型,时间-然后-图 就比 时间-和-图 更具表达力。

Souza et al. 也旨在为 TGL 方法建立 1-WL 表达框架。值得注意的是,他们将 CTDG 视为一系列时间戳多图,其中在给定时间 t 的多图 G(t) 是通过顺序应用所有早于 t 的事件来获得的。这里的多图意味着两个节点之间可以有多条边,而边的属性是时间戳信息。

现在,时间 WL 测试可以通过对从 CTDG 构建的多图应用 WL 测试来定义。因此,更具表达力的 TGN 方法必须在其时间邻域上是单射的(即将两个不同的多集节点哈希为不同的颜色),称为单射 MP-TGNs。Souza et al. 还分析了基于游走的 TGNs,如 CAW,并显示 MP-TGNs 和 CAW 之间并没有比彼此更具表达力(如上所示)。他们提出的 PINT 方法结合了这两类方法的优点,因此是最具表达力的。下面的示例显示了 MP-TGNs 无法区分的两个时间图。颜色表示节点标签,边的时间戳从 t₁ 开始。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

MP-TGNs 无法区分的时间图示例,例如直径、环长和循环数量。

图片来源:Souza et al.

重新思考时间图中的评估

在很大程度上,TGL 中的评估程序相对未被充分探索,并且受到静态图学习的重大影响。例如,对动态图上的链路预测任务(或动态链路预测)的评估通常涉及:1)。固定的训练、测试拆分,2)。随机负边采样 和 3)。来自类似领域的小数据集。这样的评估协议往往导致结果表中报告的指标已经达到 95% 以上,很难区分新模型是否带来了实际的好处,还是只是重新使用现有方法。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

典型的时间链路预测结果表,报告了平均精度(AP)。即使基线模型也能达到 98%,我们真的在取得进展吗?

图片来源: Souza 等

You 等 讨论了当前 TGL 方法在离散时间动态图(DTDGs)中的模型设计、评估设置和训练策略的局限性。他们认为数据和模型的演变特性没有被考虑。在标准评估中,所有时间点按时间顺序划分为训练集、评估集和测试集。对于给定的数据集,这种划分是固定的。

他们指出,这种固定的划分意味着只有来自所选测试期的边会被评估,因此可能跨越训练、验证和测试期的长期行为将无法正确评估。此外,许多 TGL 方法在测试时是过时的,意味着模型表示在评估过程中没有得到更新。考虑一个示例交易图,如果前一天的信息可用,用户很可能希望利用这些信息来更新模型,以实现最佳性能。因此,提出了一种实时更新评估的方法,其中模型根据新观察到的数据进行微调,利用历史信息并预测未来的连接。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

灰色/红色条分别表示 Wikipedia / MOOC 数据集中的重复/新颖边。时间图中的许多边随时间重复出现。

图片来源: Poursafaei 等

近期工作由两位作者研究了如何选择负边进行 CTDG 方法的评估,并引入了来自不同领域的更多数据集。在动态链接预测中,负边通常是从任意节点对中随机抽取的。然而,时间图中的许多边会随着时间的推移而重复(如上图所示)。考虑到现实世界图的稀疏性,大多数节点对不太可能形成边。因此,随机 负边可以被视为容易的负边。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

TGL 方法的平均性能。使用更困难的负边显著影响模型性能。简单的基线 EdgeBank 的表现也出奇地好。

图片来源: Poursafaei 等

现在,什么可以被视为困难的负边?首先,我们介绍历史负边,即在训练集中出现但在当前测试步骤中缺失的边。我们还将归纳负边定义为在测试集中之前出现但在当前步骤中不存在的测试边。最后,我们提出了一个基线 EdgeBank,仅依靠记住过去的边(本质上是已见边的哈希表)。在上面的图中,我们看到,通过改变负边进行评估时,现有 TGL 方法在历史归纳设置下的平均性能显著降低,与标准设置相比。EdgeBank 在标准设置下也是一个出乎意料的强大基线。有关详细信息,请参见下方作者之一的博客。

[## 迈向更好的动态图链接预测

伴随博客文章,介绍了《迈向更好的动态链接预测》的评估,将在 NeurIPS 2022 数据集和…

medium.com](https://medium.com/@shenyanghuang1996/towards-better-link-prediction-in-dynamic-graphs-cdb8bb1e24e9?source=post_page-----d28d1640dbf2--------------------------------)

时间知识图谱

在知识图谱(KG)的领域中,时间设置与同质世界略有不同,即时间戳图快照并不常见。相反,一些(或所有)三元组具有一个(开始时间,结束时间)对属性,表示某个事实为真的时间范围。因此,三元组变成了五元组,或者在 Wikidata 中,时间属性成为 限定词 的一部分,更一般的 声明(主三元组 + 多个键值限定词),声明形成所谓的 超关系 KGs*.*

例如,(法国总统,职务持有者,尼古拉·萨科齐,2007,2012) 是一个五元组,描述了尼古拉·萨科齐担任法国总统的时间段。或者,每个三元组也可以只有一个时间戳(形成四元组)。最常见的预测任务是给定时间属性评分头/尾预测,例如,(法国总统,职务持有者,**???**,2007,2012) —— 这可以被视为超关系链接预测的特例,其中限定词仅为日期时间文字。一个经典的时间 KG 补全模型是 TNTComplex(ICLR 2020)。

Krause et al. 已经迈出了弥合时间知识图谱与同质图之间差距的第一步。在这项工作中,作者提出了一个框架,以形式化知识图谱中的各种时间方面。即,他们将 时间 知识图谱定义为局部扩展,即边上具有时间戳的图,而 动态 知识图谱定义为全局扩展,即随着时间的推移通过添加或删除节点和边而改变拓扑的图。更进一步,这些基本类型的组合是存在的,例如,时间和动态知识图谱的组合被称为 增量。我们希望这项工作能为时间知识图谱的繁杂文献带来更多秩序和清晰度,社区也能遵循这个良好的分类法。下一步:为这些图类型最终确定一个适当的评估协议。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

时间和动态知识图谱(及其组合)。

图片来源:Krause et al.

Wang et al. 解决了在时间 + 动态图上进行少样本链接预测的任务,其中边具有时间戳并且新节点可能在后续时间步出现(增量 ,如 Krause et al. 上述分类)。少样本场景使得任务更加具有挑战性——我们只能访问有限数量的训练和推理点(通常小于 5)来推理查询链接。在这里,作者提出了 MetaTKGR,这是一种基于元学习的方法,通过聚合一定 delta t 时间邻域内现有节点的特征来构建新节点的表示。时间戳之间的标量差异通过傅里叶变换进行向量化。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

MetaTKGR 的组件。

图片来源:Wang et al.

库和数据集

过去几年中,缺乏大规模数据集和具有挑战性的任务一直阻碍着时间图学习领域的研究。幸运的是,来自不同领域的新数据集正在涌现。例如,Poursafaei et al. 引入了六个新的公开可用的 TG 数据集,涵盖了交通、政治、经济和接近领域。然而,该领域仍然缺乏一致的努力来将基准和评估标准化到高质量,就像 OGB 对静态图所做的那样。我们希望在 2023 年,我们能看到更多关注实际应用的标准化 TG 基准。

关于库,著名的一个是 Pytorch Geometric Temporal,这是 Pytorch Geometric 的时间图扩展。然而,Pytorch-Geometric Temporal 似乎只包含离散时间方法和数据集。一个包含连续时间方法的库将为社区带来很大价值。最近,Zhou et al. 提出了 TGL,这是一个用于大规模离线时间图神经网络训练的统一框架。特别是在一台 4-GPU 机器上,TGL 可以在 1–10 小时内训练超过十亿条时间边的一轮。

我们列出了各种 TGL 库和数据集的链接如下。

使用时间图进行疾病建模

在近期的 COVID-19 大流行中,流行病建模对理解疾病传播以及设计相应的干预策略至关重要。人际接触网络实际上是时间图。通过将接触图与经典的基于隔离的模型如 SEIRSIR 相结合,我们可以更准确地预测 COVID-19 感染曲线,并超越同质混合假设(所有个体之间的接触概率相等)。

Chang et al. 从手机数据中推导出了时间移动网络,并将 9800 万人的小时移动从人口普查区块组(CBGs)映射到美国的特定兴趣点(POIs)。通过将小时接触网络与 CBG 层面的 SEIR 模型结合,他们能够准确拟合实际感染轨迹。特别是,模型显示一些‘超级传播者’POIs 如餐馆和健身中心占据了大多数感染。此外,不同种族和社会经济群体之间的流动差异导致这些群体之间的感染率不同。这项工作展示了利用大规模时间图进行疾病预测和制定干预策略的现实潜力。

除了人际接触网络,动态交通网络在 COVID-19 的传播中也扮演着重要角色。在一项研究中,我们将每日航班网络纳入 SEIR 模型,以估计输入的 COVID-19 病例。通过纳入航班网络,可以实现对疫情爆发的早期检测并预测旅行限制的影响。更多细节请见作者的博客文章

尽管基于时间图的疾病模型在实践中取得了成功,但回答诸如“接触网络结构如何影响疾病传播?”和“如何修改接触模式以减缓或阻止 COVID-19 的传播?”等问题也很重要。Holme 等比较了在八个网络数据集中使用时间、静态和完全连接网络的爆发特征差异,并研究了不同网络结构对疾病传播的影响。他们展示了将时间网络转换为静态网络可能导致对疾病爆发规模和消失时间的严重低估或高估。

TGL 在流行病建模方面的下一步是什么?

首先,预测整个接触或流动网络的快照以应对短期挑战是一个关键问题。通过预测的结构,我们可以应用基于网络的 SEIR 模型来估计感染曲线。

其次,定义和理解互动模式对接触网络的影响对于政策制定和可解释性至关重要。分析图结构与感染曲线之间的相互作用可以帮助我们确定最有效的干预策略。

时间图中的异常检测

异常检测是分析时间图中的一个基本任务,它识别出与其他实体显著偏离的实体。例如,欺诈检测可以被建模为在交易网络中检测异常边缘,而交通事故识别可以被视为在交通网络中检测异常事件。

对于利用时间图网络的表示能力进行异常检测的兴趣日益增长。蔡等人设计了一个端到端结构化时间图神经网络模型,用于检测异常边,称为StrGNN。首先基于感兴趣的边提取一个包围子图,一个以该边为中心的 k-hop 子图,以减少计算复杂性。然后使用图卷积神经网络(GCN)从子图中生成结构嵌入。接着使用门控递归单元(GRUs)来捕捉时间信息。异常检测的挑战之一是缺乏标记样本。因此,蔡等人提出通过替换正常边中的一个节点来生成“上下文相关”的负边,并用这些负边来训练模型。

与无监督的、非 GNN 基础的异常检测方法,如SEDANSPOTAnomRank相比,GNN 基础的方法可以轻松地结合任何给定的属性,并具有实现更强性能的潜力。然而,GNN 基础的方法面临两个重大挑战。

1). 首先,如何扩展到具有数百万条边和节点的动态图?这是一个开放性问题,既涉及到 GNN 模块在提取图特征时的挑战,也涉及到处理长期信息的时间模块,如 GRUs 和 transformers。

2️). 其次,如何为检测到的异常提供准确的解释?在实际应用中,检测到的异常通常会被验证,然后可能对这些检测到的实体采取惩罚措施。GNN 在动态图上的可解释性仍然是一个未解决的挑战。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

LAD 检测到 2013 年是加拿大 MP 投票网络中的一个变化点,原因是政治党派之间的边的数量异常。

图片来源:黄等人

变化点检测任务旨在检测动态图中时间点的变化,其中图结构或分布显著偏离之前观察到的状态。这种变化可能归因于外部事件(如交通中断和 COVID-19 相关的航班限制),或仅仅是动态图的自然演变。作者之一的近期工作利用了每个图快照的拉普拉斯矩阵的特征值来嵌入图结构,同时应用滑动窗口来比较图结构在长短期内的变化。在上述内容中,提出的拉普拉斯异常检测 (LAD) 方法检测到了由于政治党派之间边缘增加而导致的加拿大国会议员(MP)投票网络中的变化。这与贾斯廷·特鲁多在 2013 年被选为自由党领导人的事件相吻合。

在时间图上检测虚假信息

虚假信息的传播模式和速度与真实信息不同 (Vosoughi 等)。已有大量研究在静态图中研究这些网络模式,而动态图方法尚未得到充分探索 (Song 等)。然而,在过去的一年里,使用 TGL 方法进行虚假信息检测和理解的数量有所增加。例如,Zhang 等 开发了一种基于时间点过程的方法,而动态 GCN (DynGCN) 和 DGNF 是基于动态 GNN 的方法。

下图展示了 DynGCN 的架构。他们以均匀时间间隔构建图快照,通过 GCN 层处理每个快照,然后结合这些表示并使用注意力机制学习快照的演变模式。这是一种相对简单的方法,比起上述一些方法如 TGNCAW,它利用时间信息的方式更为简单,但在作者检查的数据集上,比之前的最先进技术在虚假信息检测方面表现更好。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

DynGCN 使用具有共享权重的 GCN 层处理单个图快照,然后通过注意力机制结合这些表示以获取时间上的演变。

图片来源: Choi 等

动态交互模式在虚假信息检测中被证明非常有用(Plepi 等)。随着 TGL 方法的显著进展,我们可以期待结合动态图的新型最先进的虚假信息检测方法。

加入时间图学习社区

2022 年,机器学习社区对时间图学习(TGL)的关注有所增加。首届TGL 研讨会于 NeurIPS 2022 上举办。会议的演讲和讨论会录像将很快在NeurIPS 虚拟网站上提供。接受的论文可以在研讨会网站上找到。请关注 TGL 研讨会的新版本公告,并加入研讨会 Slack(网站上有最新链接)以便与社区互动。今年,我们还计划组织一个 TGL 阅读小组,如果你希望分享你的工作或参与组织阅读小组,请发送邮件至 shenyang.huang@mail.mcgill.ca。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片来源:NeurIPS 2022 时间图学习研讨会的 logo。图片由作者提供。

Python 中的临时变量:可读性与性能

原文:towardsdatascience.com/temporary-variables-in-python-readability-versus-performance-f6708b5f293c

PYTHON 编程

临时变量可以使代码更清晰。那么这样的代码性能如何呢?

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Marcin Kozak

·发布于Towards Data Science ·8 分钟阅读·2023 年 6 月 1 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

Python 的快捷方式快吗?图片由Stefan Steinbauer提供,Unsplash

临时变量是生命周期很短的变量:

[## 临时变量 - 维基百科

来自维基百科,免费的百科全书。在计算机编程中,临时变量是生命周期很短的变量……

en.wikipedia.org

临时变量在编程中非常常用,你不需要知道这个术语也可以使用临时变量。一个最常见的用例是使代码更清晰,例如,在管道中:

input → tempvar_1 := func_1(input) →
        tempvar_2 := func_2(tempvar_1) →
        func_3(tempvar_2) → output

在这里,我使用了 Python 的海象运算符来直观地表示赋值,就像在 Python 代码中使用的一样。在这个管道中,我们有两个临时变量:tempvar_1tempvar_2。它们在数据流过代码的时间上生命周期很短,尽管在实际时间上可能很长。tempvar_1 仅用于一个目的:将管道第一步的结果传递到下一步。但请注意,从技术上讲,它是没有必要的:

input → func_3(func_2(func_1(input))) → output

虽然这两个版本的功能相同,但后者的可读性可能大大降低。因此,前者在编程中被广泛使用,唯一的原因是使代码更清晰。

注意,如果tempvar_1tempvar_2在代码后续中使用,它们就不再是临时变量,因为它们的生命周期不会很短。为了简单起见,我们可以假设临时变量是你只使用一次的变量,用于将一个可调用的输出作为输入传递给另一个。

你是否曾经思考过在管道中使用临时变量是否比直接——和最短——的计算方式更好?比如,以下两个代码片段哪个更好?

# snippet 1
third_function(second_function(first_function(x)))

# snippet 2
x1 = first_function(x)
x2 = second_function(x1)
x3 = third_function

或者,这次使用简单的算术运算:

# snippet 1
x1 = 2.056 * x
x2 = x1 / (1 + x1)
y = 2.3 / (- x2 - 7.33)

# snippet 2
y = 2.3 / (- 2.056 * x / (1 + 2.056 * x) - 7.33)

你会选择哪种方式?这重要吗?

Python 因各种原因而非常受欢迎,其中之一是其代码的可读性。同时,Python 也因其性能差而闻名——尽管它并不像许多人声称的那样糟糕,正如我在下面的文章中所写的:

[## Python 的速度:并没有那么糟糕!

我一直听到 Python 太慢了。这是真的吗?

medium.com](https://medium.com/pythoniq/the-speed-of-python-it-aint-that-bad-9f703dd2924e?source=post_page-----f6708b5f293c--------------------------------)

很多时候,你可以——也需要——在可读性和性能之间进行选择。有时你可能需要哪怕是最微小的性能提升,即使这意味着可读性降低。其他时候,小幅度的性能提升意味着没有副作用,并且代码既可读又易懂;为何不选择它呢?

当性能的提升带来一些成本时,你应该小心。你应该问自己——或你所在的开发团队——以下问题:这种微小的性能提升是否值得降低代码的可读性?

在这篇文章中,我想向你展示一个通过避免临时变量实现的改进的例子。去除它们可以稍微提高性能,但通常会以降低可读性为代价。是的,通常是这样,所以不总是:如果你运气好,去除临时变量可以帮助你同时提高性能可读性。这是完美的情况,不是吗?

Python 代码中的临时变量

想象一下你想实现一个计算一系列事物的函数。为了简单起见,我们将进行一些基本的算术计算,使例子变得简单。然而,在现实生活中,这样的管道可能包含多个函数执行各种操作,甚至相当复杂。

def calc_with_tempvar(x):
    y = x**2
    z = y/2
    f = z + 78
    g = f/333.333
    return g

因此,我们从 x 开始,然后计算 yzf,最终得到 gg 是最终输出,因此返回。这类似于 函数组合,不同之处在于这里我们不是组合函数,而是组合计算。然而,在许多场景中,你将会有实际的函数;例如,代替 y = x**2,你可以有 y = some_function(x)。在 Python 中类似的一个完美示例就是生成器管道:

## 在 Python 中构建生成器管道

本文提出了一种优雅的构建生成器管道的方法。

towardsdatascience.com

及其一般版本,理解管道:

## 在 Python 中构建理解管道

理解管道是 Python 特有的构建管道的概念。

towardsdatascience.com

在简单的情况下,例如我们的 calc_with_tempvar() 函数中,这种方法似乎有些多余。相反,我们可以简单地做如下操作:

def calc_without_tempvar(x):
    return ((x**2)/2 + 78)/333.333

这些测试表明,两者产生了完全相同的结果:

>>> for x in (1, 2.3, 0.0000465, 100_000_000.004):
...     assert calc_with_tempvar(x) == calc_without_tempvar(x)

没有输出意味着这确实是正确的。

首先,让我们拆解这两个函数,看看它们如何转换为 Python 字节码:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

使用 dis.dis() 函数对两个函数进行拆解。图片由作者提供

即使不分析这两个函数的字节码,我们也可以看到,使用临时变量的函数比不使用临时变量的函数在 Python 中需要做更复杂的工作。这不奇怪,对吧?函数的定义方式本身表明 calc_with_tempvar() 要达到结果需要做更多工作,而不是 calc_without_tempvar()

临时变量:性能

然而,这如何转化为性能呢?为了了解这一点,让我们使用 [perftester](https://github.com/nyggus/perftester) Python 包,它专门用于基准测试和测试 Python 函数的时间和内存性能:

## 轻松基准测试 Python 函数:perftester

你可以使用 perftester 轻松地基准测试 Python 函数。

towardsdatascience.com

对于基准测试,我在 Windows 10 机器上的 WSL 1 中使用了 Python 3.11,配备 32GB 的 RAM 和四个物理(八个逻辑)核心。然而,在我们的案例中,原始时间并不那么重要;我们将重点关注相对比较。

首先,让我更改基准测试的默认设置。我将使用 2000 万次函数调用重复 7 次;从中选择最快的一次作为基准结果。

>>> import perftester
>>> perftester.config.set_defaults(
...     "time",
...     Number=20_000_000,
...     Repeat=7,
... )

现在实际的基准测试,对于一个float数值:

>>> x = 1.67
>>> t1 = perftester.time_benchmark(calc_with_tempvar, x)
>>> t2 = perftester.time_benchmark(calc_without_tempvar, x)

然后让我们看看结果¹:

>>> perftester.pp({
...     "1\. composition": t1["min"],
...     "2\. one-shot": t2["min"],
...     "3\. composition--to--one-shot ratio": t1["min"] / t2["min"]
... })
{'1\. composition': 2.063e-07,
 '2\. one-shot': 1.954e-07,
 '3\. composition--to--one-shot ratio': 1.056}

如预期的那样,一次性版本(不使用临时变量)更快——大约快 5%。一方面,这并不多。另一方面,这仅仅是通过如此微小的改变——如此小的变化就达到了 5%!

上述计算是快速的。然而,对于较长的计算,差异可能接近于不可见。

你注意到我们可以稍微改进一下calc_with_tempvar()函数吗?我们需要最后一个对象g吗?有时候,像这样的对象通过一个好的名字可以提高函数的可读性,但在这种情况下并不需要——所以我们不需要g。让我们看看去掉它是否会提高性能:

def calc_with_tempvar_shorter(x):
    y = x**2
    z = y/2
    f = z + 78
    return f/333.333
>>> t3 = perftester.time_benchmark(calc_without_tempvar_shorter, x)
>>> t3["min"]
1.998e-07

一个微小的改进,因为组合版本比这个版本慢1.032倍,而这个版本比一次性版本慢1.023倍。但再次强调,这种改进是通过如此微小的变化实现的!既然如此,这个微小的改变值得使用吗?

结论

对我来说——绝对值得,但并非总是。

关键是,当性能不重要时,优先考虑可读性。如果程序运行多一分钟、10 秒钟或甚至半秒钟都不会改变任何事情——干脆不要考虑通过这些技巧来提高性能。为什么要这样做?为什么要为了微小的改进而降低可读性呢?只需关注可读性。

当然,有时去掉临时变量会提高函数的可读性。在这种情况下,为什么我们还要讨论这个?再次强调,优先考虑可读性,当这也意味着提高性能时——那就完美了。

有时性能确实很重要。即使是秒的分割也可能产生差异。如果是这种情况,你应该分析你的代码并找出瓶颈。其他时候,你可能希望优化代码的每一个部分。一个例子是为他人使用的框架进行工作,对于其中一些人,性能将很重要。在这种情况下,作为框架作者的有责任提供尽可能快的工具。否则,你有可能会冒着一些用户不会使用你的框架的风险。

总结

  • 如果性能很重要,避免使用像calc_with_tempvar()中的临时变量。如果性能重要性较低(如果有的话),则优先考虑可读性——这意味着是否使用临时变量的决定应该完全基于代码的可读性。

  • 临时变量并不总是增加可读性。例如,假设你有一个数学函数 y(x) = ((x**2)/2 + 78)/333.333。你认为 calc_with_tempvar(),那个包含所有临时变量的函数,会提高可读性吗?我不这么认为。

因此,有时临时变量会提高代码的可读性,有时则不会。如果性能至关重要,请记住,临时变量可能会增加一些轻微的开销。更多时候,这些开销是微不足道的——但在一些项目中,即使是那些秒数的分割也可能很重要。

总之,始终双重检查是否值得在你的代码中去除临时变量——或者是否值得使用它们。

注释

¹ 代码使用了 perftester.pp() 函数,该函数以标准库函数 pprint.pprint() 的方式美观地打印 Python 对象,并将其中的所有数字四舍五入到四位有效数字。它使用了rounder包:

## GitHub - nyggus/rounder: 用于在复杂 Python 对象中对浮点数和复杂数字进行四舍五入的 Python 包

rounder 是一个轻量级的包,用于在复杂的 Python 对象(如字典、列表、元组等)中对数字进行四舍五入。

github.com

感谢阅读。如果你喜欢这篇文章,你可能也会喜欢我写的其他文章;你可以在这里查看。如果你想加入 Medium,请使用下面的推荐链接:

## 使用我的推荐链接加入 Medium - Marcin Kozak

阅读 Marcin Kozak 的每一个故事(以及 Medium 上成千上万其他作家的故事)。你的会员费直接支持…

medium.com

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值