TowardsDataScience 2023 博客中文翻译(二百七十七)

原文:TowardsDataScience

协议:CC BY-NC-SA 4.0

语言障碍的革命:掌握多语言音频转录和语义搜索

原文:towardsdatascience.com/revolutionizing-language-barriers-mastering-multilingual-audio-transcription-and-semantic-search-5540f038778d

利用先进的转录和语义搜索技术,解锁跨语言信息访问的潜力

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Luís Roque

·发表于 Towards Data Science ·12 min 阅读·2023 年 12 月 13 日

这篇文章由 Rafael Guedes 共同撰写。

介绍

在我们这个互联的世界中,信息没有边界,使其对每个人都可访问,不论他们的母语是什么或他们是否有能力学习新语言,这一能力变得非常重要。无论你是内容创作者还是全球组织的负责人,能够快速而轻松地帮助你的追随者/客户在多种语言中搜索特定信息都有很多好处。例如,它可以帮助客户找到用不同语言已经回答过的相同问题。

考虑一个不同的使用场景,你经常需要参加公司会议。通常,你可能无法参与,而讨论的许多话题可能与你无关。如果你能够搜索感兴趣的主题并收到总结,包括相关讨论的开始和结束时间,这将多么方便?这样,你可以用十到十五分钟的时间获取所需的信息,而不是花费一个小时在会议上,这将显著提高你的生产力。此外,你可能有用葡萄牙语和英语录制的会议。然而,你仍然希望用英语进行搜索。

在本文中,我们将展示如何实现多语言音频转录和多语言语义搜索,以便你可以将其应用于你的使用场景。对于多语言音频转录,我们将解释 Whisper 和 WhisperX 的工作原理、它们的局限性以及如何在 Python 中使用它们。

然后,我们介绍多语言语义搜索模型如何训练,以及为何您可以从向量数据库中获取相同的信息,无论您用什么语言查询。我们还提供了使用 Postgres 和 PGVector 进行语义搜索的详细实现。

最后,我们展示了上述结果在两个用例中的表现。我们使用了两个视频,一个是葡萄牙语的,另一个是英语的,并用葡萄牙语和英语提出相同的问题,以检查是否能得到相同的答案。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 1: 多语言音频转录和多语言语义搜索有无尽的应用场景待探索(i图像由作者使用 DALLE 制作)

一如既往,代码可以在我们的 GitHub 上找到。

WhisperX:一个强大的音频转录架构

WhisperX [1] 是 Whisper [2] 的进化形式,Whisper 是由 OpenAI 开发的模型。但它们之间有什么区别呢?

Whisper 和 WhisperX 是能够进行多语言语音识别、语音翻译、口语语言识别和语音活动检测的语音识别模型。它们依赖于转换器序列到序列架构,将各种语音处理任务表示为一系列由解码器预测的标记。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 2: Whisper 架构(i图像由作者提供)

尽管 Whisper 在不同领域和语言中表现出色,但在长音频转录方面还需要改进。这个问题的主要原因是训练期间使用的滑动窗口方法。这通常导致漂移和幻觉。它在将转录与音频时间戳对齐时也存在限制。

WhisperX 来解决这些问题:

  1. 漂移和幻觉是通过语音活动检测 (VAD) 和自定义的方法来解决的,用于剪切和合并音频片段。VAD 检测人声的存在或缺失,并根据该分类将输入音频分成段。之后,它将带有人声的片段剪切并合并为 30 秒的窗口。它尝试在语音概率较低的区域定义边界。这些片段被剪切成 30 秒的窗口,以匹配 Whisper 训练时使用的片段持续时间。

  2. 转录对齐是通过强制对齐来解决的,这是架构的最后一层。它使用音素识别模型来识别区分一个单词和下一个单词的最小语音单元,例如,‘t’‘nut’ 中的元素。然后,通过获取同一单词中第一个和最后一个音素的开始和结束时间来获得每个单词的开始和结束时间,以获得更可靠的对齐。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 3:WhisperX 架构 (i 图片由作者提供)

Whisper 和 WhisperX 的实际应用

我们可以使用 Whisper 或 WhisperX 通过几行代码转录音频。我们需要从 git+https://github.com/openai/whisper.gitgit+https://github.com/m-bain/whisperx.git 安装 Whisper 和 WhisperX。

安装完成后,我们首先导入 Whisper 或 WhisperX。然后,我们加载模型,最后,我们转录 .wav 格式的音频文件。结果将是一个包含三个键的字典:

  1. ‘text’ 是一个包含完整转录文本的字符串。

  2. ‘segments’ 是一个文本片段的列表,包含开始和结束时间以及其他一些元数据。

  3. ‘language’ 是一个表示音频语言的字符串。

### ----- WHISPER ----- ###
import whisper
model = whisper.load_model("large", "cpu")
result = model.transcribe("<YOUR AUDIO FILE>.wav")### ----- WHISPERX ----- ###
import whisperxmodel = whisperx.load_model("large-v2", "gpu", compute_type="float16")
audio = whisperx.load_audio("<YOUR AUDIO FILE>.wav")
result = model.transcribe(audio)

如前所述,Whisper 在将转录与音频时间戳对齐时存在一些局限性。因此,我们使用 WhisperX 来解决这个问题。

我们加载对齐模型,并根据 Whisper 或 WhisperX 的结果,修正其对齐。

from whisperx import load_align_model, align
model_a, metadata = load_align_model(language_code=result['language'], device="cpu")
result_aligned = align(result['segments'], model_a, metadata, "<YOUR AUDIO FILE>.wav", "cpu")

语义搜索:一种多语言方法

语义搜索是一种搜索引擎技术,它匹配查询的含义,而不是传统搜索方法匹配查询的关键词。

语义搜索通过使用 Transformers 来实现有效性,Transformers 对于将自由文本形式的文档转换为数值表示至关重要。这些表示,称为嵌入,实际上是存储在像 PGVector 这样的向量数据库中的向量。这个过程使语义搜索能够基于含义或意图匹配查询,从而显著提高搜索结果的准确性和相关性。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 4:语义搜索的内部工作原理(图片由作者提供)

当用户提交查询时,它会被转换成一个嵌入。这个嵌入随后被向量数据库的内置检索系统利用,通常基于 k 最近邻(kNN)算法。该系统使用这个算法来识别和排序与用户查询最相关的 k 个最相似的文档。这个过程确保检索到的结果与用户的搜索意图紧密对齐。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 5:语义搜索嵌入被应用于文本、音频或图像。这些嵌入可以存储在由 kNN 支持的向量数据库中,以根据用户查询检索最相关的文档。 (i 图片由作者提供)

最新的 NLP 进展,特别是在语义搜索方面,使得在不同语言中为相同句子创建相同的嵌入成为可能 [3]。这对全球运营的组织带来了巨大的优势,因为他们可以快速且低成本地将语义搜索扩展到更多语言。这是可能的,因为所需的样本相对较少,硬件要求较低,正如我们将遵循的方法的作者所提到的。

扩展通常基于英语的单语模型涉及使用教师模型学生模型。这些模型在使语言模型能够有效处理多种语言方面扮演着不同但互补的角色。

教师模型: 该模型作为参考点或标准。它通常是一个在源语言(通常是英语)中经过充分训练的高性能模型。教师模型对语言有深刻的理解,能够生成准确代表各种文本含义的高质量嵌入向量。

学生模型: 学生模型旨在从教师模型中学习。与仅在源语言中操作的教师模型不同,学生模型同时处理源语言和翻译语言。学生模型的主要目标是在新语言环境中复制教师模型的性能。

这些模型的使用及其有效性的原因在于它们的训练方法:

  1. 嵌入对齐: 学生模型的训练目标是最小化其嵌入与教师模型生成的嵌入之间的均方误差。这个过程确保了学生模型在源语言和翻译语言中生成的嵌入与教师模型的嵌入紧密匹配。

  2. 语言适应: 这种训练方法使学生模型能够适应新语言,同时保持原始模型的质量和特征。通过与教师模型的理解对齐,学生模型能够有效处理和理解翻译语言。

  3. 高效学习: 学生模型不必从头开始学习。通过利用教师模型已经成熟的理解,学生模型可以在新语言中以潜在更少的数据和训练时间实现高性能。

  4. 跨语言一致性: 这种方法确保模型在不同语言中的性能一致。它在保持嵌入质量方面特别有利,而嵌入对语义搜索、自然语言理解和翻译等任务至关重要。

尽管可以使用几种架构,但作者们为教师模型使用了 Sentence-BERT [4],为学生模型使用了 XLM-RoBERTa [5]。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 6:多语言嵌入创建的架构,其中给定两个不同语言的相同句子,学生模型可以生成与教师模型生成的向量相近的两个语言的向量(来源

使用 PGVector 实现多语言语义搜索

在本节中,我们介绍了在 Postgres 上实现 PGVector 的方法。我们还部署了一个 pgAdmin 应用程序来查询 Postgres,并检查我们的嵌入如何存储。

我们借助 LangChain 来编码来自 Whisper 或 WhisperX 的转录,将其插入 Postgres 中的一个表,并检索与用户查询最相似的文档。

由于在我们的使用案例中,我们需要能够检索信息,而不论音频的语言或用户查询的语言,因此我们使用sentence-transformers中的multi-qa-mpnet-base-dot-v1来编码转录。我们选择这个模型是因为它在多语言语义搜索中表现最佳(你可以在这里查看可用的多语言语义搜索模型)。

设置 PGVector

我们使用 Docker 部署由 PGVector 支持的 Postgres。我们首先定义 docker-compose.yml 文件,包含两个容器,postgrespgadmin

Postgres:

  • 镜像:ankane/pgvector允许我们部署带有 PGVector 扩展的 Postgres。

  • 端口:5432

  • 环境:与 Postgres 交互的用户名和密码,以及一个存储我们嵌入的数据库。

pgAdmin:

  • 镜像:dpage/pgadmin4

  • 端口:5050

  • 环境:登录用的电子邮件和密码。

version: '3.8'
services:
  postgres:
    container_name: container-pg
    image: ankane/pgvector
    hostname: localhost
    ports:
      - "5432:5432"
    environment:
      POSTGRES_USER: admin
      POSTGRES_PASSWORD: root
      POSTGRES_DB: postgres
    volumes:
      - postgres-data:/var/lib/postgresql/data
    restart: unless-stopped
  pgadmin:
    container_name: container-pgadmin
    image: dpage/pgadmin4
    depends_on:
      - postgres
    ports:
      - "5050:80"
    environment:
      PGADMIN_DEFAULT_EMAIL: admin@admin.com
      PGADMIN_DEFAULT_PASSWORD: root
    restart: unless-stopped
volumes:
  postgres-data:

一旦定义了 docker-compose 文件,我们可以通过在 docker-compose 文件所在的目录中运行docker-compose up -d命令来启动我们的应用程序。

应用程序运行后,是时候在 pgAdmin 中创建一个服务器,以便我们可以查询我们的嵌入和文档。为此,我们必须按照以下步骤操作:

  1. 在网页浏览器中打开 pgAdmin 的 Web 界面,访问localhost:5050/

  2. 使用我们在 docker-compose 文件中的PGADMIN_DEFAULT_EMAILPGADMIN_DEFAULT_PASSWORD环境变量中设置的电子邮件和密码进行登录。

  3. 右键点击服务器节点,选择注册 → 服务器

  4. 创建 — 服务器对话框中,在名称字段中输入服务器名称。

  5. 连接选项卡中,插入以下信息:

  6. 主机名/地址postgres

  7. 端口5432

  8. 维护数据库:你可以使用postgres数据库来完成这个任务。

  9. 用户名POSTGRES_USER环境变量,我们在 docker-compose 文件中设置了它。

  10. 密码POSTGRES_PASSWORD环境变量,我们在 docker-compose 文件中设置了它。

  11. 点击保存按钮以创建服务器。

创建好服务器后,我们来填充刚刚创建的名为postgres的数据库中的嵌入和文档。

注意:这个 pgAdmin 是可选的;如果你不想查询嵌入,可以跳过这一步。

使用 LangChain 填充 Postgres

一旦数据库设置好并准备好存储嵌入,就该定义编码器了。之后,我们使用 LangChain 来填充和检索数据库中与用户查询最相似的文档。

如上所述,编码器是多语言的,可以定义为:

from langchain.embeddings import HuggingFaceEmbeddings

encoder = HuggingFaceEmbeddings(
    model_name="sentence-transformers/multi-qa-mpnet-base-dot-v1",
    model_kwargs={"device": "cpu"},
)

LangChain 与 PGVector 集成。因此,要将 LangChain 连接到 Postgres,我们需要将字符串连接定义如下:

from langchain.vectorstores.pgvector import PGVector

CONNECTION_STRING = PGVector.connection_string_from_db_params(
     driver="psycopg2", # driver to connect with postgres
     host="localhost", # host defined in docker-compose.yml
     port="5432", # port defined in docker-compose.yml     
     database="postgres", # database defined in docker-compose.yml
     user="admin", # user defined in docker-compose.yml
     password="root", # password defined in docker-compose.yml
)

请注意,COLLECTION_NAME 必须唯一,因为 PGVector 将使用它作为键来识别从 Postgres 检索的文档。对于我们的用例,我们可以将 COLLECTION_NAME 视为会议 ID,这样可以从用户感兴趣的会议中检索信息。

COLLECTION_NAME = "Meeting ID"

在定义编码器、连接和集合名称之后,我们将来自 Whisper 或 WhisperX 的转录内容转换为文档(LangChain 所期望的格式)。我们还创建并填充了一个包含嵌入的表。

from langchain.docstore.document import Document

# Transform transcription into documents and add the start and end time of each sequence 
docs = [Document(page_content=f'start {item["start"]} - end {item["end"]}: {item["text"]}') for item in result['segments']]
db = PGVector.from_documents(
 embedding=encoder,
  documents=docs,
  collection_name=COLLECTION_NAME,
  connection_string=CONNECTION_STRING,
  pre_delete_collection=True,  # deletes previous records, useful for testing
)

在创建并填充表格后,我们可以查询数据库,并使用 LangChain 获取最相似的文档:

similar_docs = db.similarity_search("<USER QUERY>")

或者我们也可以去 pdAdmin 并查询 Postgres 以查看嵌入和文档的样子:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 7:在 pgAdmin 中查询嵌入向量和文档(作者提供的图像)

多语言语义搜索有效吗?

我们将 Luis 谈论两个不同主题的两个视频转换为音频。在第一个视频中,Luis 用葡萄牙语谈论他的前一家公司;在第二个视频中,他谈论了概率深度学习。然后,我们使用葡萄牙语和英语查询这两个视频,并比较检索到的文档。

对于葡萄牙语用例,我们使用了以下查询

  • 葡萄牙语:marcas e investimentos

  • 英语:brands and investments

因此,前四个最相关的文档在两个案例中都是相同的:

开始 81.401 — 结束 85.26:一个总部位于柏林的品牌,因此我们没有等待这个投资到来,

开始 111.902 — 结束 117.26:为了成熟产品、投资技术和进行业务开发。

开始 88.58 — 结束 93.039:这两轮投资我们已经做了,能力也有所不同开始

28.6 — 结束 32.64:因此,我们提供了品牌所需的所有物流基础设施组件。

对于英语用例,我们使用了以下查询

  • 葡萄牙语:modelos de aprendizagem profunda

  • 英语:deep learning models

我们必须检索八个文档以找到相关的文件。这是因为在葡萄牙语中,我们通常不翻译 Deep Learning;我们使用英语表达。因此,模型可能没有足够的数据进行训练。

开始 45.28 — 结束 51.9:当我们使用深度学习模型时,我们通常依赖于最大似然估计

另一方面,以下查询的前 4 个结果相同

  • 葡萄牙语:distribuição normal

  • 英文: normal distribution

这表明,对于经常翻译的术语,例如‘normal distribution’到*‘distribuição normal’*,我们的方法能够产生相关输出。

结论

多语言音频转录和语义搜索是构建更加互联世界的重要资产。我们的例子只是冰山一角;还有许多技术可以结合使用,以应对不同的应用场景。

考虑一种使用检索增强生成(RAG)系统进行客户支持的场景。通常,在客户支持系统中,客户用任何语言提问。我们可以用多语言模型对这些问题进行编码,并使用检索器从客户服务专家那里提取相关的过往回答作为上下文。大型语言模型(LLM)使用这些上下文生成翻译成客户语言的答案。该系统有效地减少了客户服务专家的工作负担,并提供了快速、实时的客户支持。

尽管我们的方法提供了广泛的可能性,但它并不是万能的解决方案。例如,在我们的实验中,检索器未能将“Deep Learning”与其葡萄牙语对应词“Aprendizagem Profunda”语义关联起来。克服这些限制需要使用特定数据进行微调或实施基于规则的机制,以提高文档检索的准确性,特别是在不同语言之间。

保持联系: LinkedIn, X/Twitter, Medium

参考文献

[1] Max Bain, Jaesung Huh, Tengda Han, Andrew Zisserman. WhisperX: 精确时间语音转录长篇音频。arXiv:2303.00747, 2023

[2] Alec Radford, Jong Wook Kim, Tao Xu, Greg Brockman, Christine McLeavey, Ilya Sutskever. 通过大规模弱监督的稳健语音识别。********arXiv:2212.04356, 2022

[3] Nils Reimers, Iryna Gurevych. 通过知识蒸馏将单语句子嵌入转换为多语种。arXiv:2004.09813, 2020。

[4] Nils Reimers, Iryna Gurevych. Sentence-BERT: 使用 Siamese BERT 网络的句子嵌入。arXiv:1908.10084, 2019。

[5] Alexis Conneau, Kartikay Khandelwal, Naman Goyal, Vishrav Chaudhary, Guillaume Wenzek, Francisco Guzmán, Edouard Grave, Myle Ott, Luke Zettlemoyer, Veselin Stoyanov. 大规模无监督跨语言表示学习。arXiv:1911.02116, 2019。

Rise Up! 使用数据和 Home Assistant 为我的站立式办公桌建立警报系统

原文:towardsdatascience.com/rise-up-building-an-alert-system-for-my-standing-desk-using-data-and-home-assistant-a7574236f579

将微处理器、Home Assistant、Grafana、InfluxDB 和 Telegram 集成,为桌子提供智能化和更健康的工作环境

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Juan De Dios Santos

·发表于 Towards Data Science ·11 分钟阅读·2023 年 5 月 25 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由 DALL-E 生成。提示:“一个带有笔记本电脑的站立式办公桌。笔记本电脑的图像需要是折线图。”

我们都很清楚长期坐着的健康风险。这可能导致肌肉退化、背部问题、糖尿病风险增加等(source)。是的,情况确实很糟糕。然而,尽管有这些有害影响,我们中的许多人——包括我自己——还是会长时间坐着。我们这样做是因为我们喜欢这样,或者因为我们的工作需要这样,就像我一样。

为了应对这些健康风险,我买了一个站立式办公桌。我非常喜欢这个桌子。它看起来很酷,并允许我配置高度预设,我在一天中切换这些预设。然而,我必须承认,有时候桌子会一直保持在最低设置,这反映了我缺乏运动。为了应对这个问题(同时也作为一个有趣的借口来启动一个新项目),我在桌子下安装了一个微处理器。这个微处理器监控桌子的高度,它是一个流的入口,最终会通过 Telegram 发送通知提醒我,如果桌子的高度保持在我定义的“坐着”预设太久的话,就提醒我站起来。

我给附加的微处理器配备了一个距离传感器,以便随着时间的推移跟踪其高度并将其记录在 Home Assistant 中,Home Assistant 是一个开源的家庭自动化平台,作为智能家居设备的中心枢纽。我使用的微处理器是来自 SparkFun 的 ESP32 Thing Plus。这种设备包括一个 WiFi 模块,支持开发网络服务器。对于这个项目,我设置了一个网络服务器,具有一个返回传感器测量距离的端点。在这种情况下,这个测量值是桌子到地板的距离——这些数据用于分析我的桌子使用情况,并在长时间坐着后发出警报。本文解释了我是如何做到的。

设置

这个项目的核心是我的站立桌。它是一款非常普通的站立桌,配有一个用于调整高度的控制器和四个用于设置特定高度预设的按钮。

SparkFun Thing Plus ESP32 微控制器是另一个关键组件。在所有功能中,与此项目相关的是 WiFi 收发器和 SparkFun 的 Qwiic 连接系统,它允许在不需要焊接的情况下连接传感器。我用来测量桌子高度的传感器是 SparkFun Distance Sensor — 1.3 Meter, VL53L4CD (Qwiic)。它通过发射红外激光并计时目标的反射来测量距离。

我将微处理器和传感器放置在桌子下方。由于我的地板和桌腿都不反光,我在传感器下方的桌腿上粘贴了一小块箔纸,以便反射传感器的激光。因此,我实际上是在测量桌子与其底座之间的距离,大约是距离地面 5 厘米。

我将传感器的数据存储在我的 Home Assistant 安装中。Home Assistant 是一个功能复杂的系统,具有许多功能和特性。就这个项目而言,我们需要知道的是,我已经将其安装在一个连接到我的家庭网络的 Raspberry Pi 上,它可以发出警报并将其发送到不同的平台,并且它可以托管 InfluxDB(一个处理高量时间戳数据的时间序列数据库)和 Grafana(一个数据可视化和监控工具)。

你可以在这里了解更多关于 Home Assistant 的信息。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

我的(杂乱的)设置。左侧是桌子(看看箔纸吗?),右侧是微处理器和传感器。我意识到它有多丑。

微处理器程序

桌子可能是这个项目的核心,但微处理器程序无疑是大脑。这个程序,我编写了用来跟踪桌子高度的程序,在 ESP32 板上设置了一个网络服务器,通过一个端点提供由 SparkFun VL53L1X 距离传感器采集的距离测量数据。

这里是代码:

#include <WiFi.h>
#include <WebServer.h>
#include <Wire.h>
#include "SparkFun_VL53L1X.h"

// Replace this with the WiFi's SSID
const char* ssid     = "SSID";
// Replace this with the WiFi password
const char* password = "PASSWORD";

// Listening on port 80
WebServer server(80);

#define SHUTDOWN_PIN 2
#define INTERRUPT_PIN 3

SFEVL53L1X distanceSensor(Wire, SHUTDOWN_PIN, INTERRUPT_PIN);

void setup() {
  Serial.begin(115200);
  // Connect to WiFi
  WiFi.begin(ssid, password);

  while (WiFi.status() != WL_CONNECTED) {
    delay(1000);
    Serial.println("Connecting to WiFi...");
  }

  Serial.println("Connected to WiFi");

  // Set the distance sensor
  Wire.begin();
  if (distanceSensor.init() == false) {
    Serial.println("Distance sensor is online.");
  }

  server.on("/distance", [](){
    // Measure the distance
    distanceSensor.startRanging();
    int distance = distanceSensor.getDistance(); // Distance is in mm.
    distanceSensor.stopRanging();
    server.send(200, "text/plain", String(distance));
  });

  // GET /ping is just a health check
  server.on("/ping", [](){
    server.send(200, "text/plain", "ok");
  });

  server.begin();
  Serial.println("HTTP server started");
}

void loop() {
  server.handleClient();
}

我开始编写代码,导入必要的库,包括 WiFi 模块和距离传感器库。在库之后,我们定义了两个常量:WiFi 的 SSID 和密码。接下来,我们创建 Web 服务器,指定其端口,并初始化传感器。然后是setup()函数,其中包含主要逻辑。这个函数的前半部分是启动 WiFi 连接和距离传感器。一旦两者都准备好,它将继续为 Web 服务器设置两个路径。

第一个路径,GET /distance,检索传感器的距离测量值并以纯文本形式返回(例如,695)。

第二条路径,GET /ping,是一个健康检查,以确保程序正在运行。尽管第一个端点也可以完成这个目的,但我专门创建了一个不同的端点用于健康检查。

现在,我们需要将程序上传到 Arduino。我在这里不会深入探讨如何做,但如果你需要帮助,可以参考这个指南support.arduino.cc/hc/en-us/articles/4733418441116-Upload-a-sketch-in-Arduino-IDE

一旦程序运行,你可以使用类似curl -X GET http://192.168.1.XXX/distance的 cURL 命令进行测试,只要执行命令的设备与微处理器在同一网络上。请注意,你需要将XXX替换为微处理器的实际 IP。找到 IP 的一种方法是查看路由器控制面板上的连接设备列表。

从 Home Assistant 消耗端点

下一步是使用sensor组件将/distance端点与 Home Assistant 集成。为了简便起见,我假设你已经安装了 Home Assistant 并熟悉其基础知识。

传感器组件监控实体的状态和条件,这可以是物理传感器或像我们创建的端点。要设置此功能,你需要通过文件编辑器或控制台访问 Home Assistant 的配置文件configuration.yaml。进入文件后,查找 YAML 文件中现有的sensor键,如果没有,则创建一个。在此键下,添加以下内容:

sensor:
  - platform: rest
      name: Desk Distance
      unique_id: desk_distance
      unit_of_measurement: "mm"
      resource: [`192.168.1.XXX/distance`](http://192.168.1.XXX/distance)

此配置在传感器组件中设置了一个RESTful平台——一个消耗 REST 端点的平台。它有四个值:

  • name:端点的描述性名称。

  • unique_id:传感器的唯一标识符。

  • unit_of_measurement:传感器的测量单位。在我们的案例中,它是“mm”,因为距离传感器以毫米为单位进行测量。

  • resource:要使用的端点的 URL。

还有一个名为method的可选字段,用于指定请求的 HTTP 方法。我没有使用它,因为它默认为GET,这是我们端点的方法。

现在保存文件并重新加载以应用更改。为确保其正常工作,考虑在仪表盘中创建一个新模块以显示传感器读取的值。或者,您可以在开发者工具的“states”选项卡中找到该实体。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

从开发者工具的“states”选项卡中可以看到传感器。

下一部分解释了如何使用 InfluxDB 和 Grafana 来可视化数据。

使用 Grafana 和 InfluxDB 可视化数据

在开始这个项目之前,我设想了用一个可视化工具显示桌面的实际升降。幸运的是,对于 Home Assistant 社区和我来说,有一个成员创建了一个插件,可以无缝地将InfluxDBGrafana集成到平台中。InfluxDB 是一个高性能的开源时间序列数据库,能够高效地存储和管理大量时间戳数据,非常适合跟踪桌面的高度变化。Grafana 是一个开源的数据可视化和监控平台,允许用户创建交互式、自定义的仪表盘。将 InfluxDB 和 Grafana 结合起来,可以实时收集、存储和可视化桌面的移动数据,提供了跟踪桌面高度的无缝体验。有关如何安装它们的说明,请参考InfluxDBGrafana的文档。

安装和设置(包括在 Grafana 中添加 InfluxDB 作为数据源)完成后,导航到 Grafana 的仪表盘以创建新的仪表盘和面板;这个面板就是我们将可视化桌面高度的地方。页面底部是查询区域,我们将在这里定义从 InfluxDB 获取数据的查询。首先,从“数据源”菜单中选择“InfluxDB”。然后,按照如下方式填写空白:

  • FROM:选择 InfluxDB 数据源的名称(我的为default)。在旁边的字段中,选择“mm”——这是存储单位为毫米的传感器数据的表的名称。

  • WHERE:使用此子句仅过滤实体desk_distance

  • SELECT:从菜单中选择value。我还应用了mean()聚合函数和数学公式math(/1000),将 Y 轴刻度转换为米。

  • GROUP_BY:使用 time($_interval)$_interval 分组数据,其中 $_interval 是由 Grafana 计算的时间间隔(参见解释这里),而 fill(linear) 使用线性插值填充指定时间范围内的任何缺失数据点。我使用这个方法来避免在我的可视化中出现间隙。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

上述查询。

配置这些参数后,你将得到一个类似于这样的图表:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

我的图表是这样的。

这个图表表示两天的数据。y 轴显示桌子的高度(单位:毫米),而 x 轴显示时间。你可以清楚地看到我最常使用的两个预设:坐着的预设,桌子高度约为 700 毫米,和站立的预设,将桌子升高到近一米的高度。

发送警报到 Telegram

这个项目的最终目标是开发一个提醒我在长时间坐着后站立的通知系统。通过使用从距离传感器收集的数据,Home Assistant 的警报集成可以在事件发生时发送这些提醒。这些通知通过另一个集成通知发送,该集成支持多种平台,包括 Twilio、电子邮件和 Telegram——我将在这个项目中使用它。

从 Home Assistant 设置一个 Telegram 的警报涉及两个关键步骤:创建一个 Telegram 机器人来接收警报,以及在 Home Assistant 中定义警报。

创建机器人

我使用了一个现有的 Telegram 机器人BotFather来创建我的警报机器人。要开始,打开一个新的聊天窗口并搜索 BotFather,开始对话。开始对话后,按照屏幕上的说明操作,这些说明大多是关于你的新机器人的问题,例如其名称。BotFather 然后会提供一个 API 令牌来控制你最近创建的机器人,以及一个启动 Telegram 对话的链接。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

BotFather 的二维码

接下来,你需要从与你的机器人对话中获取 chat ID。我通过 curl 请求 $ curl -X GET https://api.telegram.org/bot<YOUR_API_TOKEN>/getUpdates 获得了这个 ID(在此之前,你需要给机器人发送一条消息),其中 YOUR_API_TOKEN 是你从 BotFather 获得的 API 密钥。API 响应是一个 JSON 对象,其中包含一个 chat 对象和一个名为 id 的字段,包含 chat ID。

你也可以使用另一个名为GetIDs的机器人来获取 ID,它提供关于聊天和消息的信息。不过,我没有尝试这种方法。

定义警报

最后一步涉及在 Home Assistant 中定义警报。打开configuration.yaml文件,并按如下方式配置机器人:

telegram_bot:
  - platform: polling
    api_key: YOUR_API_KEY
    allowed_chat_ids:
      - CHAT_ID_1 # your chat_id
      - CHAT_ID_2 # Optional. You can also add another chat!

接下来,创建通知器,通过你刚刚配置的机器人发送通知:

notify:
  - platform: telegram
    name: NOTIFICATION_NAME
    chat_id: chat_id
  - platform: telegram # Optional. Add another telegram platform if you wish to notify another chat.
    name: NOTIFICATION_NAME_2
    chat_id: chat_id_2

最后,在automation键下定义警报本身(如果没有,则创建它):

automation:
  - alias: "Desk Distance Alert 45 minutes"
      trigger:
        platform: numeric_state
        entity_id: sensor.desk_distance
        below: 900
        for:
          minutes: 45
      action:
        service: notify.HA_NOTIFICATIONS
        data:
          message: "You've been sitting for 45 minutes. Get up!"

我的警报名称是桌面距离警报 45 分钟,正如名称所示,当我的桌面降低超过 45 分钟时,它会发送通知。

我在trigger中定义了“45 分钟”条件。在这里,我们需要一个numeric_state自动化,当实体的状态满足条件时触发一个动作。我的条件是在桌面高度below 900mm for 45 minutes 时触发。你可以根据需要自定义这个条件。另外,注意entity_idsensor这个词后跟你用来命名传感器的 id,例如sensor.desk_distance。我的触发器伴随一个action,当触发器启动时执行。这个action使用我们之前创建的NOTIFICATION_NAME通知服务来发送通知。

我还添加了一个message,名字带有鼓励性质,希望能激励我起身。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

通知

为了测试这个过程,创建一个在一分钟内触发的临时警报,或访问https://your_ha_location/config/automation/dashboard的自动化仪表板。如果警报出现在列表中,点击右侧的三点符号,手动选择“运行”以触发它。如果一切顺利,你应该会收到来自 Telegram 的通知。

结论

作为一个在桌子后面坐了大半天的人,我发现解决这种对健康的负面影响至关重要。认识到这一点,我被驱动去寻找一个解决方案——一种提醒自己休息、移动和站立的方法。

通过整合 ESP32 微处理器、SparkFun VL53L1X 距离传感器、Home Assistant、Grafana、InfluxDB 和 Telegram,我打造了一个实时跟踪桌面高度、可视化数据并在我坐得太久时提醒我的系统。

但这个项目不仅仅是关于我的升降桌;它还涉及利用技术工具(以及我放在箱子里很久的传感器)来做出小的改变,从而改善我们的日常习惯。无论你是沉浸在工作中还是迷失在自己的项目里,都不应忽视健康(当然是给自己提个醒)。幸运的是,借助合适的工具和一点数据,我们可以建立既有助于生产力又有助于健康的系统。

一张图表中的博弈论与风险管理

原文:towardsdatascience.com/risk-mapping-c6bb3eb4ae29?source=collection_archive---------13-----------------------#2023-03-27

数据可视化

风险管理原则介绍

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Gabriel de Longeaux

·

关注 发表在 Towards Data Science ·9 分钟阅读·2023 年 3 月 27 日

从数据科学家的角度出发,使用期望效用理论介绍基本的风险管理原则。风险管理策略在一张图表中进行总结(图表的详细解读在下文中),其中体现了等风险曲线。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

风险地图与风险对象(G. de Longeaux)

一个简单的风险情况

首先假设我们有一个总财富 W₀。如果我们在不久的将来面临丧失一个金额 L(显然小于 W₀)的风险,概率为 p,我们不能认为我们的财富仍然是 W₀(当 p = 1 时显然如此)。相反,我们应该计算在风险环境下,博弈论中所称的“确定等价物”。

如何减少财富的风险

确定性等价是指金额 Wₑ,使得 Wₑ 的效用与我们资产面临损失的效用相同。如果我们是风险中性的,使用的效用函数就是身份函数,因此确定性等价只是我们当前财富减去损失期望(即使没有效用函数的概念,这也很直观):证明是,如果 Wₑ 的效用(即 Wₑ)等于我们财富面临损失的效用(在博弈论中称为“彩票”),则 Wₑ = (1-p) W₀ + p (W₀-L) = W₀ - pL(在博弈论中,彩票的效用是彩票结果的期望效用),因此 Wₑ= W₀ - pL。如果我们是风险厌恶的,我们可以例如使用对数作为效用函数,因此 ln(Wₑ) 等于 (1-p) ln(W₀) + p ln(W₀-L),这意味着 Wₑ = W₀^(1-p) (W₀-L)^p。

风险成本是什么

注意,我们的表示非常简单,因为我们只有一个概率 p 在不久的将来损失正好为 L(没有应用贴现率),但这足以说明概念。现在我们需要问自己风险的成本是什么,这非常简单:即我们在暴露于任何损失之前的财富 W₀ 与因风险暴露而现在拥有的财富 Wₑ 之间的差异。对于风险中性的人,风险成本 r = W₀ - Wₑ = W₀ - (W₀ - pL) = pL,这很直观,因为它只是期望损失。对于风险厌恶的人(使用对数效用函数),风险成本会稍微复杂一些,r 等于 W₀-W₀^(1-p) (W₀-L)^p = W₀ [1 - (1 - L/W₀)ᵖ]。

绘制风险图

很明显,不同的损失金额 L 和损失概率 p 的组合可以关联到相同的风险成本 r。我们应在图表上绘制出所有这些表示相同风险成本 r 的点,其中损失金额 L 在 y 轴上表示,概率 p 在 x 轴上表示。为了找到所有点关联相同风险 r 的曲线方程(这就是这些曲线通常被称为“等风险曲线”的原因),我们只需将 r 设为常数,并推导出 L 关于 p 的表达式。对于风险中性实体,我们从 r=W₀-Wₑ = pL 开始,这意味着 L = r/p。对于风险厌恶的实体,我们从 r = W₀-Wₑ = W₀ - W₀^(1 - p) (W₀ - W)^p 开始,最终得到 L=W₀ [1 - (1-r/W₀)^(1/p)](对于 p = 0 这没有定义,但这个情况不有趣,除了看到与 Allais 悖论相关的不连续性:即使 p 极小,我们应该根据所选择的效用函数考虑风险,但实际上我们永远不会考虑如此小的风险——我们将忽略这一点,因为这与我们主要关注的问题无关)。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

等风险曲线 (G. de Longeaux)

为了绘制图表,我取了 W₀ = 5。我们可以看到不同风险成本 r 值下的风险中立和厌恶风险个体的等风险曲线。

探索地图

表示风险对象

这就是一切变得有趣的地方。企业拥有暴露于损失的工厂、仓库、商店、办公室……它们可以在我们创建的地图上标出这些地点:例如,一个暴露于 40%概率损失 2 的仓库将在地图上显示为点(x = 0.4, y = 2)。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

风险对象的等风险曲线(G. de Longeaux)

根据图表上绘制的等风险曲线,我们可以看到风险成本略低于 r = 1(假设 r = 0.9),如果我们假设公司是厌恶风险的。然而,并不是所有公司都厌恶风险:保险公司接近于风险中立(忽略其费用),因为它们可以通过大数法则来分摊风险。对于保险公司来说,风险成本 r 会更低(假设 r = 0.8),因此保险公司将能够提供一个交易:仓库的所有者将支付 0.85 的保险费,将风险转移给保险公司,从而节省 0.9 作为风险成本。最终,保险公司将获得 0.85 以承担 0.8 的成本,从而赚取 0.05,而另一家公司将支付 0.85 以摆脱 0.9 的成本,也赚取 0.05。

确定战略区域

进一步探索图表,我们可以识别出四个主要区域:

  • 位于风险地图左下角的风险对象(工厂、仓库、商店、办公室等)具有非常好的风险概况:保险费用便宜(保险费始终低于 1)。

  • 位于左上角的风险对象具有“严重损失”特征,这意味着可能会发生大损失。然而,由于损失概率较小(低于 40%),风险概况仍然良好,保险非常有用,因为相对于风险成本来说,它非常便宜:一个成本为 3.5 的风险可以以 1.5 的费用投保(-57%)。

  • 位于右下角的风险对象具有“概率损失”或“频率损失”特征,因为损失发生的概率很高(超过 40%)(注意在我们简单的表示中,彩票是一个简单的伯努利试验,损失概率和损失频率是相同的,但一般情况下并非如此)。在这种情况下,风险转移似乎并不十分有趣,因为企业和保险的风险成本基本相同。通常,对于很可能发生的小损失,几乎没有保险的兴趣。如果有的话,公司可以通过自保(通过自保公司)或使用免赔额来减少总体支付的保险费:保险公司不会支付发生在这些地点的低于 1.5 的损失。

  • 地图右上角的风险对象具有较差的风险状况:损失既大又高度可能。即使考虑到风险成本,保险费用也很高。例如,一个风险成本为 4.5 的对象可以以 4 的价格投保(-11%)。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

风险状况和保险策略(G. de Longeaux)

应对风险

预防和保护

在高风险情况下,应采取措施减少可能损失的严重性(保护措施)或降低损失发生的概率(预防措施)——或者两者兼顾。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

风险管理和保险策略原则(G. de Longeaux)

帮助公司进行预防和保护措施,以改善其风险状况、保护业务和拯救生命,这是风险工程师的工作——数据科学家可以支持他们的工作。

风险转移

采取适当的措施将有助于最大限度地利用风险转移策略,通过承担减少的风险成本或以有利价格投保风险对象(尽管短期内保险费用不直接取决于风险状况,因为市场条件的变化,但长期内是会有影响的)。

结论

通过一个非常基本的统计模型结合博弈论,我们能够轻松理解何时以及为何投保是有利的,并定义减少风险的主要策略及其对我们从保险中获得的收益的影响。所有结论都可以在一个图表中展示,其中可以同时识别支付的保险费和投保的财务收益。

R 代码

如果你想使用 R 重新制作或改进图表,代码如下:

#### Graphical options ####
background <- TRUE
drawArrows <- TRUE # Arrows are drawn only if background is displayed
drawPoints <- TRUE
nb_cases <- 10
sharpness <- 10000

#### Total wealth ####
x0 <- 5

#### Plot ####
r <- x0/10
p <- seq(from = 0, to = 1, by = 1/sharpness)
n <- length(p)
xNeutral <- r/p # Loss estimates for a risk-neutral entity
xAverse <- x0 * (1 - (1-r/x0)^(1/p)) # Loss estimates for a risk-averse entity

plot(x = p, y = xAverse, type = "l", col = "red",
    xlim = c(0, 1.01), ylim = c(0, 1.1 * x0),
    xlab = "", ylab = "",
    xaxs = "i", yaxs = "i",
    main = "Iso-risk curves for risk-neutral and risk-averse entities")
title(sub = "Risk cost r is defined as the difference between the current wealth \n and the certainty equivalent in a risky environment", cex.sub = 0.8)

for (r in seq(from = x0/nb_cases, to = x0, by = x0/nb_cases)){
  xNeutral <- r/p
  xAverse <- x0 * (1 - (1-r/x0)^(1/p))
  lines(x = p, y = xAverse, col = "red")
  lines(x = p, y = xNeutral, col = "blue")
  text(x = p[n] - 0.03, y = xAverse[n] - x0/125, label = paste("r =", round(r, digits = 2)), cex = 0.8)
}

if (background){
  rect(xleft = 0, ybottom = 0, xright = 0.4, ytop = 0.3*x0,
      col = rgb(red = 30.59/100, green = 89.41/100, blue = 30.59/100, alpha = 0.3), border = "transparent")
  text(x = 0.175, y = 1 * x0/5, label = "Low risk: inexpensive insurance", col = "darkgreen", cex = 0.8)

  rect(xleft = 0.4, ybottom = 0, xright = 1.01, ytop = 0.3*x0,
      col = rgb(red = 25.1/100, green = 72.55/100, blue = 100/100, alpha = 0.3), border = "transparent")
  text(x = 0.707, y = 0.22 * x0/5, label = "Frequency losses:", col = "blue", cex = 0.8)
  text(x = 0.8, y = 0.08 * x0/5, label = "self-insured or inexpensive insurance", col = "blue", cex = 0.8)

  rect(xleft = 0, ybottom = 0.3*x0, xright = 0.4, ytop = x0,
      col = rgb(red = 25.1/100, green = 72.55/100, blue = 100/100, alpha = 0.3), border = "transparent")
  text(x = 0.09, y = 1.79 * x0/5, label = "Severity losses:", col = "blue", cex = 0.8)
  text(x = 0.174, y = 1.65 * x0/5, label = "relatively inexpensive insurance", col = "blue", cex = 0.8)

  rect(xleft = 0.4, ybottom = 0.3*x0, xright = 1.01, ytop = x0,
      col = rgb(red = 100/100, green = 25.1/100, blue = 25.1/100, alpha = 0.3), border = "transparent")
  text(x = 0.76, y = 1.65 * x0/5, label = "High risk: expensive insurance", col = "darkred", cex = 0.8)

  title(xlab = "Loss probability", line = 2, cex.lab = 1)
  title(ylab = "Loss estimate", line = 2, cex.lab = 1)
}else{
  title(xlab = "Loss probability", line = -1, cex.lab = 1)
  title(ylab = "Loss estimate", line = -1, cex.lab = 1)
}

if ((background) && (drawArrows)){
  arrows(x0 = 0.55, y0 = 2.5 * x0/5, x1 = 0.3, y1 = 2.5 * x0/5, length = 0.1, col = "gray22")
  text(x = 0.46, y = 2.6 * x0/5, label = "Prevention", col = "gray22", cex = 0.8)
  arrows(x0 = 0.55, y0 = 2.5 * x0/5, x1 = 0.55, y1 = 1 * x0/5, length = 0.1, col = "gray22")
  text(x = 0.61, y = 2 * x0/5, label = "Protection", col = "gray22", cex = 0.8)
}

if(drawPoints){
  points(x = 0.4, y = 0.4 * x0, pch = 16, col = "gray22")
  arrows(x0 = 0.4, y0 = 0, x1 = 0.4, y1 = 0.4 * x0, length = 0, col = "gray22", lty = 3)
  arrows(x0 = 0, y0 = 0.4 * x0, x1 = 0.4, y1 = 0.4 * x0, length = 0, col = "gray22", lty = 3)
  text(x = 0.4, y = 0.425 * x0, label = "Warehouse", cex = 0.8, col = "gray22")
}

text(x = 0.083, y = 1.02 * x0, label = "Current wealth", cex = 0.8)
legend("bottomleft", legend = c("Risk-neutral", "Risk-averse"), col = c("blue", "red"), pch = c("_", "_")) 

RLHF: 来自人类反馈的强化学习

原文:towardsdatascience.com/rlhf-reinforcement-learning-from-human-feedback-faa5ff4761d1

ChatGPT 成功的关键:指令数据。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Ms Aerin

·发表于Towards Data Science ·24 分钟阅读·2023 年 10 月 11 日

ChatGPT 凭借其令人印象深刻的能力吸引了全世界的关注。但它是如何变得如此聪明的呢?

我最近和一位我非常尊敬的前同事——一位软件工程师进行了交谈,我注意到他认为 ChatGPT 是 AGI 的体现,并将其将复杂主题简化到六岁孩子理解水平的能力作为证据。虽然我对它的不合理智能并不完全不同意,但我觉得有必要表达一下我的想法。在这篇文章中,我想强调 ChatGPT 的魔力在于其训练数据。

精心策划的指令数据是 ChatGPT 类人能力的关键。诸如向 6 岁孩子解释概念、将简历转化为 LinkedIn 资料、与您头脑风暴等功能并不是偶然出现的——它们是被刻意编码到模型中的训练数据。

发过几次这样的推文后,也许是时候写一篇长文了…

和其他人一样,这是我第一次接触封闭研究。自大学以来,所有前沿研究都是开放和同行评审的,直到最近。我相信开放性最终比封闭性更能推动科学进步。

如果我们旨在通过开源来匹配 ChatGPT 的表现,我相信我们需要更加认真对待训练数据。ChatGPT 的有效性很大程度上可能并不是来自于特定的 ML 架构、微调技术或框架。而更可能的是来自于指令数据的广度、规模和质量。

直截了当地说,在平庸的指令数据上微调大型语言模型是一种浪费计算资源。让我们看看训练数据和学习范式中发生了什么变化——我们现在如何以不同的方式格式化训练数据,因此与过去的大规模预训练相比,学习也发生了不同的变化。

什么是 RLHF?

RLHF 代表来自人类反馈的强化学习。它有两个主要组成部分:

  1. 强化学习(RL)

  2. 人类反馈(HF)

到底训练的是什么?

历史上,当我们谈论 LLM 训练时,我们只意味着更新语言模型的参数。然而,当我们使用 RLHF 时,我们训练三个独立模型的参数。 这种方式提供了更多的自由,因为它不受限于最大似然框架(详细信息见[我们为何在 LLM 中尝试 RL?]部分),并且我们直接从数据本身学习目标函数。

这里有三个正在训练的模型:

  1. 语言模型(SFT 模型)

    是一个像 GPT-3 这样的预训练的大型语言模型。该模型已经经过训练,稍后将基于指令数据进行微调。

  2. 奖励模型

    训练以预测人类偏好并提供奖励信号以强化代理。它是基于人类反馈数据进行训练的。

  3. 策略模型 (代理)

    通过最大化预测奖励来训练生成令牌。为此,它使用了以奖励模型作为反馈来源的强化学习。策略模型是从 SFT 模型初始化的。

LLM 的预先存在的权重在 RL 阶段进行调整和微调,在这个阶段,模型优化其行为(生成令牌)以最大化奖励(良好的人类反馈)。

关于 RLHF 的开创性论文是InstructGPT,它是去年由 OpenAI 发布的。认识到 InstructGPT 模型的强大,OpenAI 将所有公共 API 从使用原始模型切换到使用指令模型。随后,他们减少了详细描述进一步进展的学术出版物,将研究转移到内部。我将在这个博客中主要使用 InstructGPT 的例子和方法。

RLHF 的关键创新:改变训练数据格式

在 RLHF / ChatGPT / InstructGPT 之前(我将这三个术语互换使用),像 GPT-3 这样的语言模型是使用交叉熵损失来预测下一个词的概率。

但预测下一个令牌的概率性是否是我们的最终目标?

绝对不是!ChatGPT 最令人印象深刻的方面是它能在自然语言中执行许多不同的任务,如释义、总结、分类等。这种广泛的能力使 ChatGPT 非常出色,并且与那些更专注于单一目的的机器学习模型相比,具有了‘惊叹’的因素。

那么,为了让语言模型执行各种任务而不仅仅是预测下一个词,我们需要做什么?

一般来说,如果你想改变模型的行为,你需要改变它的训练数据,无论是其内容、格式,还是两者都有。你也可以改变损失函数。ChatGPT 改变了这三个方面。

在深入 RLHF 的细节之前,我想展示 InstructGPT 团队如何不遗余力地创建了大量详尽的训练数据,使 ChatGPT 成为现实。

RLHF 中使用了两种类型的人类反馈。 一种是 指令数据,另一种是 人类偏好数据

1. 指令数据(即示范数据)

指令数据是输入和输出的配对,展示了给定输入时模型应该如何表现。

如果你想从头开始训练你的第一个 InstructGPT 模型,你不仅需要编写答案,还需要编写用户输入(用例)。因为直到去年,GPT-3 API 用户很少输入像向 6 岁孩子解释复杂概念这样的大胆提示。用户从未想过可以向模型提出这样的问题。这也是为什么指令数据也被称为“示范”数据。我们首先必须向语言模型展示用例。

让我们看看 InstructGPT 团队策划的各种用例(提示)。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

InstructGPT 提供的用例示例

这里有一些有趣的用例来强调:

  • 封闭式问答用例有明确的正确和错误答案,如:
When you drop a heavy stone from a tree, what happens? 

A. The stone falls to the ground.
B: The stone stays in the tree.
C: The stone floats.
D: Nothing happens.

Answer:
  • 开放式问答用例会有主观性的回答:
Who was the best human who ever lived?

Answer:
  • 重写用例将需要标注者的创造力。
Convert my resume into a profile overview. 

{resume}

Profile overview: 

创建指令数据时的勤奋

让我们看看生成高质量指令数据需要什么。这是来自 InstructGPT API 提示分发的标注说明摘录。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

InstructGPT 的 API 提示分发标注说明摘录

这一长段文字是为“标注者”准备的。这是一份长文件,似乎有很多含义需要弄清楚。我们需要这套长说明,因为为了让标注者创建我们想要的指令数据,他们必须首先理解我们希望他们做什么,并且 遵循这些规则。

似乎有三条规则你应该遵循:

有帮助、真实且无害。

让我们来看看成为有帮助的标准。

**“回答他们本来想问的问题,即使他们问得不准确。”

这是一个巨大的要求。它要求标注者真正尝试帮助用户,而不是用“我不理解你”这样的回答来回避用户提出的错误问题。这类似于母亲尝试理解她的宝宝想要什么,即使宝宝没有准确地说出来。

“对国际性敏感(例如,“football”不应指美式足球,而“the president”不一定指美国总统)”

标注员应具备扎实的语言能力和对不同文化运作方式的良好理解。

所以,谁是这些能够认真遵循这些复杂指南的标注员呢? 他们肯定不是那些只能每天投入 1-2 小时的众包平台的兼职工人。根据我创建大规模训练数据的经验,随意的众包工人无法充分提供自然、细腻的对话,进而促使 ChatGPT 的卓越表现。

我更倾向于使用**“数据编写者”**这个术语,而不是“标注员”,因为它更能体现其中的创造力和细致入微。为了确保这些数据编写者提供你所需的高质量工作,你需要培训他们,与他们过度沟通,保持一致,审查他们的提交,给予反馈,并保留最优秀的编写者,让其余的离开。你需要能够信任你的编写者,因为你的 LLMs 的表现(“wow”因素、ChatGPT 对你问题的回答质量等)将基于他们的工作。虽然你是他们的老板,但你也严重依赖他们。这是一种迷人的共生关系,本身就是一种艺术。

InstructGPT 团队值得大力称赞,他们将这门艺术提升到了一个新的水平。他们的工作告诉我们,如果我们希望开源的 LLMs 能达到 ChatGPT 的表现,数据方面需要无懈可击。

2. 偏好数据

指令数据用于监督性微调(SFT)阶段(详细信息见下一部分)。另一半关键的训练数据是“偏好数据”。偏好数据用于在 RL 阶段训练奖励模型。这涉及到人类根据他们的偏好对不同的 LLM 生成的输出进行排名。偏好数据为正确与错误的行为提供训练信号。

当我阅读标注指南时,像“有帮助的”或“真实的”这样的标准对我来说有点不清楚。此外,如果我是一名标注员,我可能不太会仔细阅读这些指南,因为它们太长了。为了应对这一点,InstructGPT 团队付出了巨大努力,通过提供清晰的示例来培训标注员。这是影响期望模型行为的关键步骤。

这里是提供给标注员的示例,帮助他们理解“有帮助的”、“真实的”和“无害的”是什么意思。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

以无害性优先为例。好的,安全第一。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

阅读上面的“推理”部分。我认为对训练数据“有用”方面的重视,是 ChatGPT 中最重要的变化。这种注释数据的新方法使得 InstructGPT 与之前的研究区别开来。然而,也值得注意的是,同样的“有用”因素可能会导致 幻觉(稍后会详细讲解)****。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

上面三个示例来自 InstructGPT 的公开文档,展示了指令数据编写者所需的训练水平及其对模型行为的重大影响。

指令数据的非凡有效性

让我们比较两个模型的输出——一个是用指令数据训练的,另一个则没有。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

模型比较:无指令训练 vs. 从openai.com/research/instruction-following进行的指令训练

在左侧,未经过指令数据训练的 DaVinci 原版(一个未经过指令数据训练的模型)未能理解“用几句话向 6 岁孩子解释登月”这个提示。它似乎无法理解用户的要求,而是提供了多个无关的回答,如解释进化。

另一方面,右侧的 instruct-DaVinci 模型能够正确回答用户的提示,虽然它比 gpt4 的回答简洁。 😃

我为什么要关心指令数据?

1. 理解指令数据的格式可以帮助你编写更好的提示。

你输入的提示与专有模型的指令数据越接近,输出效果就会越好。设计与模型训练数据相似的提示可以通过减少试错的时间来节省你的时间。

2. 它在一定程度上解释了幻觉倾向。

已经提出了各种原因来解释模型中的幻觉现象(对话模型幻觉的起源:是数据集还是模型?使大型语言模型生成带引用的文本通过总结评估大型语言模型的事实一致性等)。一些人认为,语言模型显示模式完成行为是因为它们被训练来最大化相邻文本的可能性。但这是否是 RLHF 中幻觉的唯一原因?

我认为我们不能忽视这样一个事实,即在偏好数据标注过程中,标注人员被指示优先考虑对用户的有用性而非真实性。但当我们进行最终评估时,我们会让标注人员把真实性放在首位。

再次参考示例 2,“优先考虑有用性而非真实性”。

这个例子展示了在人工偏好数据中对“有帮助”答案加权过重如何导致幻觉。为了减轻这种情况,我们可以生成更多优先考虑真实性和无害性的训练数据,而不是在某些情境下(如医学等高风险领域)只关注帮助性。平衡不同情况下的不同优先级可以帮助减少幻觉。

另一个可能导致幻觉的因素是模型不知道自己被允许表达不确定性。减少幻觉的一个重要步骤是激励模型用文字表达不确定性。这在 NLP 中一直是一个长期存在的问题,正如 SQUAD(斯坦福问答数据集)V2 通过在不确定时不回答的问题所体现的那样。因此,虽然 RLHF 是一个重要的进步,但一些 NLP 的重要问题,如如何处理不确定性,仍然没有完全解决。

好的,我们完成了数据部分。现在让我们看看 RLHF 的方法。

RLHF 的三步骤

OpenAI 总是分享这个简化的图示来解释 ChatGPT 是如何工作的。我希望现在你可以更好地理解在第 1 步中,次要子步骤“A 标注员展示了期望的输出行为”的意义。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

那个图示

第一步. 监督微调(SFT)初始化

RLHF 的第一步是监督微调(SFT),以初始化语言模型权重(图示中的第一列)。SFT 在指令数据上训练模型;克隆展示的对话行为。这一步为后续的强化学习做了准备。

你可以从预训练模型如 GPT-3 开始 SFT,就像 OpenAI 为 InstructGPT 做的那样。或者你也可以从头开始训练,然后继续前进。SFT 的输出为下一个强化学习阶段提供输入。

适当初始化的权重对于强大的下游任务表现至关重要,不仅仅在 RLHF 中如此,一般情况下也是如此。因此,SFT 模型的选择不是随意的。最佳的 SFT 模型将根据使用验证集的奖励模型得分来选择。

[InstructGPT 中的一些显著摘录]

最终的奖励模型是从一个 6B GPT-3 模型初始化的,该模型在各种**公共 NLP 数据集(ARC, BoolQ, CoQA, DROP, MultiNLI, OpenBookQA, QuAC, RACE, 和 Winogrande)**上进行了微调。这主要是出于历史原因;我们发现从 GPT-3 或 SFT 模型初始化 RM 时也会得到类似的结果。

我们发现我们的 SFT 模型在 1 个周期后会在验证损失上过拟合;然而,我们发现训练更多周期对 RM 得分和人工偏好评级都有帮助,尽管存在过拟合。

获取良好的指令数据可能很昂贵,特别是如果你没有成千上万的用户提交的种子提示。那么,如果你没有像商业企业那样的资源,你可以做什么呢?一个选择是使用公开的数据。上述提到的学术数据集、SQUAD V1、V2、StackOverflow、Quora 等都可能有帮助。你可以将这些数据转换以适应你的训练需求。

第 2 步:训练奖励模型

奖励模型的工作是返回一个表示人类偏好的标量,当给定一对(提示,答案)时。高分意味着被偏好,低分意味着不被偏好。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

奖励模型的损失函数

当你看到方程时,它可能看起来不直接,但这实际上是一个简单的公式。让我们用真实的数据来看看。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

WordMakeover.com 用于有效的电子邮件写作

x = 输入、问题或提示

y_w = 赢的输出

y_l = 输的输出

K = 输出数量(这里为 7,因为有 7 个 LLM 结果)

θ = 正在训练的奖励模型参数

r_θ = 来自模型的奖励分数(标量)

现在我们知道方程中的每个变量了,让我们理解为什么这个损失函数是这样的。假设最右侧的项,即 r_θ(赢的对比对)和 r_θ(输的对比对)之间的差值,持有一个特定值。sigmoid 将使这个差值落在 0 和 1 之间。

视觉化 sigmoid 函数后的对数图形在 0 和 1 之间。当输入接近零时,它骤降至负无穷,而当输入接近一时,它上升至零。从中可以看出,如果模型给输掉的对比对分配了比赢得的对比对更大的奖励值,那么模型将受到重大的惩罚。

对所有 7C2 对进行这种操作,然后取平均值。这就是你想要最小化的损失。

对于那些喜欢代码的人:

class RewardTrainer(Trainer):
    # Define how to compute the reward loss. We use the InstructGPT pairwise logloss: https://arxiv.org/abs/2203.02155
    def compute_loss(self, model, inputs, return_outputs=False):
        rewards_j = model(input_ids=inputs["input_ids_j"], attention_mask=inputs["attention_mask_j"])[0]
        rewards_k = model(input_ids=inputs["input_ids_k"], attention_mask=inputs["attention_mask_k"])[0]
        loss = -nn.functional.logsigmoid(rewards_j - rewards_k).mean()
        if return_outputs:
            return loss, {"rewards_j": rewards_j, "rewards_k": rewards_k}
        return loss

# https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama/scripts/reward_modeling.py

奖励模型从 SFT 模型初始化。然后我们移除最终的嵌入层,添加一个给出标量的线性层。

从大小上看,奖励模型通常比语言模型小。例如,InstructGPT 使用了一个 175B 参数的语言模型,但使用了一个 6B 参数的奖励模型。团队报告说,175B 奖励模型的训练不稳定,使其不太适合作为 RL 期间的价值函数。

排名的目的是什么?

排名使得比较两个输出变得简单。有了 n 个输出,排名可以通过一次标注轻松生成 nC2 对。

二元选项的一个缺点是缺乏细粒度。它们无法捕捉输出 A 相对于 B 的优越程度。 而且没有量化这种差异,错误无法根据严重程度精确地惩罚。 另一种选择是让标注员给出整数或浮点数,但这非常主观,并且很难在不同标注员之间进行校准。

有人能想到更好的方式来表述偏好问题吗?😃

步骤 3:使用 RL 优化针对奖励模型的策略

这一步骤可以用一句话概括:LLM 参数和策略是联合优化以最大化从奖励模型中获得的期望奖励。

我们为什么在 LLM 中尝试 RL?

过去,语言模型很少使用 RL 进行优化。相反,它们依赖于信息论损失函数,如交叉熵,使用最大似然进行优化。

尽管最大似然和 RL 都用于学习,但它们更新参数的方式基于不同的原理。最大似然是基于最小化与正确答案的误差使用固定损失函数,而 RL 则基于可学习的奖励函数,同时通过与环境的互动来最大化累积奖励。

人们(例如,John SchulmanYoavGo等)给出了大量关于使用 RL 训练 LLM 的理由,但如果我追求直观的答案,我相信我们尝试 RL是因为我们想要训练目标函数的灵活性

传统的语言模型训练仅优化一个方面:模型参数,同时保持损失函数固定。这种方法限制了灵活性,因为损失函数如交叉熵本身带来了强大的归纳偏差——最大似然。它奖励最可能的下一个标记预测,假设最高似然输出是最佳的。

如果我们使用 RL,我们不仅在训练模型参数,还在训练奖励函数和训练策略。奖励函数充当一个可学习的损失函数,量身定制于最终目标。这提供了更大的优化自由度,因为我们不再受限于最大似然框架。我们可以从数据中学习目标函数。在 RLHF 中,你的目标函数是奖励模型,你使用 RL 来优化该目标函数。

总结来说,我们尝试使用 RL 来参数化和学习目标函数。这仍然是一个进行中的实验。

我们如何将这定义为 RL 问题?

ChatGPT 的最终目标是生成人类更喜欢的文本。

然后我们可以将 RL 问题的组件定义如下:

代理语言模型充当 RL 代理。它学习生成被认为是基于奖励系统的最佳文本。

动作空间:在这种情况下,动作空间是 LLM 可以生成的所有可能语言输出的集合。鉴于语言的多样性,这个空间非常广泛。

策略:策略是模型在每个生成步骤上的可能输出的概率分布。它根据当前状态决定代理应该采取哪些行动。

环境:环境是代理互动的对象,并且是代理获取其行动反馈的地方。在 RLHF 案例中,环境通过基于人类偏好模型给予奖励的方式向代理提供反馈。

奖励:奖励是来自人类偏好模型的标量信号。RL 中的代理目标是最大化这个期望奖励,从而提高文本生成质量。

通过将语言生成框定为一个 RL 问题,模型可以与奖励模型互动,从而随着时间的推移改善其策略。

对于那些通过阅读代码更容易理解的人,这里有一份由我们的开源贡献者 Phil Wang 慷慨提供的 RLHF 训练器的直接实现

预期有人会抽象化这个版本,我在这里复制了训练脚本。这涵盖了大多数 PPO 训练组件和流程。

  1. generate 函数根据给定的提示生成文本序列。它使用演员-评论家模型生成序列,并使用奖励模型为每个序列打分。选择得分最高的序列作为最佳序列。

  2. learn 函数批量处理经验,计算 PPO 损失,并更新演员和评论家网络。实现核心 PPO 算法。

  3. train 循环收集演员经验,评估奖励并存储在内存中。定期调用 learn() 来更新策略。

class RLHFTrainer(nn.Module):
    def __init__(
        self,
        prompts: Optional[List[str]] = None,
        prompts_path: Optional[str] = None,
        prompt_token_ids: Optional[torch.Tensor] = None,
        tokenizer: Callable = None,
        palm: PaLM,
        reward_model: RewardModel,
        critic_palm: Optional[PaLM] = None,
        actor_critic: Optional[ActorCritic] = None,
        actor_lr = 1e-4,
        critic_lr = 1e-4,
        actor_wd = 0.,
        critic_wd = 0.,
        actor_adam_eps = 1e-7,
        critic_adam_eps = 1e-7,
        actor_lora = True,
        critic_lora = True,
        actor_lora_r = 8,
        critic_lora_r = 8,
        critic_pooled_values = True,
        actor_dropout = 0.,
        critic_dropout = 0.,
        betas = (0.9, 0.999),
        max_norm = None,
        eps_clip = 0.2,
        value_clip = 0.4,
        beta_s = .01,
        pad_value = 0.,
        minibatch_size = 16,
        epochs = 1,
        kl_div_loss_weight = 0.1, # between old action probs and new action probs - not sure what the right value is
        accelerate_kwargs: dict = {},
        use_lion = False
    ):
        super().__init__()
        self.accelerate = Accelerator(**accelerate_kwargs)

        # take care of prompts -> token ids
        assert (exists(prompts) + exists(prompts_path) + exists(prompt_token_ids)) == 1
        if exists(prompts_path):
            path = Path(prompts_path)
            prompts = path.read_text().split('\n')
        if exists(prompts):
            assert len(prompts) > 0, 'no prompts'
            assert exists(tokenizer), 'tokenizer must be passed in if raw text prompts are given'
            prompt_token_ids = tokenizer(prompts)
        self.pad_value = pad_value # token pad value
        self.num_prompts = prompt_token_ids.shape[0]
        self.register_buffer('prompt_token_ids', prompt_token_ids)

        # models
        self.palm = palm
        if not exists(actor_critic):
            actor_critic = ActorCritic(
                palm = palm,
                critic_palm = critic_palm,
                actor_lora = actor_lora,
                critic_lora = critic_lora,
                actor_lora_r = actor_lora_r,
                critic_lora_r = critic_lora_r,
                pooled_values = critic_pooled_values,
                actor_dropout = actor_dropout,
                critic_dropout = critic_dropout).to(palm.device)
        self.actor_critic = actor_critic
        self.reward_model = reward_model.eval()

        # train hyperparameters
        self.epochs = epochs
        self.minibatch_size = minibatch_size
        self.max_norm = max_norm
        self.kl_div_loss_weight = kl_div_loss_weight

        # optimizers
        self.actor_optim = get_optimizer(actor_critic.actor_parameters(), lr = actor_lr, wd = actor_wd, betas = betas, eps = actor_adam_eps, use_lion = use_lion)
        self.critic_optim = get_optimizer(actor_critic.critic_parameters(), lr = critic_lr, wd = critic_wd, betas = betas, eps = critic_adam_eps, use_lion = use_lion)

        # ppo hyperparams
        self.eps_clip = eps_clip
        self.value_clip = value_clip
        self.beta_s = beta_s

        # prepare with accelerator
        (
            self.actor_critic,
            self.reward_model,
            self.actor_optim,
            self.critic_optim
        ) = self.accelerate.prepare(
            self.actor_critic,
            self.reward_model,
            self.actor_optim,
            self.critic_optim
        )

    @property
    def device(self):
        return self.accelerate.device

    @torch.no_grad()
    def generate(
        self,
        max_seq_len,
        *args,
        prompt,
        num_samples = 4,  # sample 4 per prompt and select the one with highest reward
        **kwargs
    ):
        assert prompt.ndim == 1, 'only one prompt allowed at a time for now'
        prompt = repeat(prompt, 'n -> b n', b = num_samples)
        actor_critic = self.accelerate.unwrap_model(self.actor_critic)
        reward_model = self.accelerate.unwrap_model(self.reward_model)
        actor_critic.eval()
        (
            actions,
            sequences,
            mask,
            prompt_mask,
            action_logits,
            _
        ) = actor_critic.generate(
            prompt,
            *args,
            max_seq_len = max_seq_len,
            return_values = False,
            **kwargs
        )
        rewards = reward_model(
            sequences,
            prompt_mask = prompt_mask,
            mask = mask,
            sample = True
        )
        best_sequence_index = rewards.topk(1, dim = -1).indices
        best_sequence = sequences[best_sequence_index]
        best_sequence = rearrange(best_sequence, '1 ... -> ...')
        return best_sequence

    def learn(
        self,
        memories: Deque[Memory]
    ):
        # stack all data stored in the memories
        all_memories_stacked_and_padded = list(map(partial(pad_sequence_fixed, batch_first = True), zip(*memories)))

        # prepare dataloader for policy phase training
        dl = create_dataloader(all_memories_stacked_and_padded, self.minibatch_size, device = self.device)
        self.actor_critic.train()

        # PPO training
        for _ in range(self.epochs):
            for (sequences,
                prompt_masks,
                masks,
                old_action_probs,
                old_log_probs,
                rewards,
                old_values) in dl:
                action_masks = ~prompt_masks & masks
                action_logits, values = self.actor_critic(
                    sequences,
                    mask = action_masks
                )
                action_logits = shift(action_logits, shift = 1, dim = -2) # need to shift along sequence dimension by 1, since actions start from the last prompt (state) token
                action_len = old_log_probs.shape[-1]
                action_probs = action_logits.softmax(dim = -1)
                action_log_probs = log_prob(action_probs, sequences)
                action_log_probs = action_log_probs[:, -action_len:]

                # calculate entropies, taking into account which part of the sequence is actually an action
                entropies = masked_entropy(action_probs, mask = action_masks)

                # calculate kl div between old action probs and new ones, taking into account which part of the sequence is action or not
                kl_penalty = 0.
                if self.kl_div_loss_weight > 0:
                    kl_penalty = masked_kl_div(old_action_probs, action_probs, mask = action_masks) * self.kl_div_loss_weight

                # subtract the kl penalty from the rewards
                rewards = rewards - kl_penalty

                # handle non-pooled values
                normalize_kwargs = dict()
                if old_values.ndim == 2:
                    old_values, values = map(lambda t: shift(t, shift = 1, dim = -2), (old_values, values))

                    old_values = old_values[:, -action_len:]
                    values = values[:, -action_len:]
                    rewards = rearrange(rewards, 'b -> b 1')
                    normalize_kwargs = dict(dim = -1, mask = action_masks[:, -action_len:])
                if values.ndim < rewards.ndim:
                    values = rearrange(values, '... -> ... 1')

                # calculate clipped surrogate objective, classic PPO loss
                ratios = (action_log_probs - old_log_probs).exp()
                advantages = masked_normalize(rewards - old_values, **normalize_kwargs)
                if advantages.ndim == 1:
                    advantages = rearrange(advantages, 'b -> b 1')
                surr1 = ratios * advantages
                surr2 = ratios.clamp(1 - self.eps_clip, 1 + self.eps_clip) * advantages
                policy_loss = - torch.min(surr1, surr2) - self.beta_s * entropies

                # combine losses
                loss = policy_loss.mean()

                # update actor
                self.accelerate.backward(loss)
                self.print(f'policy_loss: {loss.item():.3f}')
                if exists(self.max_norm):
                    self.accelerator.clip_grad_norm_(self.actor_critic.actor_parameters(), self.max_norm)
                self.actor_optim.step()
                self.actor_optim.zero_grad()

                # calculate value loss and update value network separate from policy network
                value_loss = clipped_value_loss(values, rewards.detach(), old_values, self.value_clip)
                value_loss = value_loss.mean()
                self.print(f'critic_loss: {value_loss.item():.3f}')
                self.accelerate.backward(value_loss)
                if exists(self.max_norm):
                    self.accelerator.clip_grad_norm_(self.actor_critic.critic_parameters(), self.max_norm)
                self.critic_optim.step()
                self.critic_optim.zero_grad()

    def train(
        self,
        num_episodes = 50000,
        max_timesteps = 500,
        update_timesteps = 5000,
        max_batch_size = 16,
        max_seq_len = 2048,
        eos_token = None,
        temperature = 1.
    ):
        device = self.device
        time = 0
        memories = deque([])
        for eps in tqdm(range(num_episodes), desc = 'episodes'):
            for timestep in range(max_timesteps):
                time += 1

                # select a bunch of random states (prompts)
                # and get the action (sampled sequence from palm as well as the action probs)
                # also calculate the reward using reward model and store
                rand_prompt_index = randrange(0, self.num_prompts)
                state = self.prompt_token_ids[rand_prompt_index]

                # remove padding from state
                state_mask = state != self.pad_value
                state = state[state_mask]

                # get predicted sequence
                (
                    actions,
                    sequence,
                    mask,
                    prompt_mask,
                    action_logits,
                    value
                ) = self.actor_critic.generate(
                    rearrange(state, 'n -> 1 n'),
                    max_seq_len = max_seq_len,
                    eos_token = eos_token,
                    temperature = temperature,
                    return_values = True
                )
                action_logits = shift(action_logits, shift = 1, dim = -2) # need to shift along sequence dimension by 1, since actions start from the last prompt (state) token
                action_prob = action_logits.softmax(dim = -1)
                action_len = actions.shape[-1]
                action_log_prob = log_prob(action_prob, sequence)
                action_log_prob = action_log_prob[:, -action_len:]
                actions = rearrange(actions, '1 ... -> ...')

                # get reward as given by supervised trained reward model
                sequence = torch.cat((state, actions), dim = 0)
                prompt_length = len(state)
                prompt_mask = torch.arange(sequence.shape[-1], device = device) < prompt_length
                sequence = rearrange(sequence, 'n -> 1 n')
                prompt_mask = rearrange(prompt_mask, 'n -> 1 n')
                mask = default(mask, lambda: torch.ones(sequence.shape, dtype = torch.bool, device = device))
                reward = self.reward_model(
                    sequence,
                    prompt_mask = prompt_mask,
                    mask = mask,
                    sample = True
                )
                detach_to_cpu_ = lambda t: rearrange(t.detach().cpu(), '1 ... -> ...')

                # store memory for learning
                memories.append(Memory(*map(detach_to_cpu_, (
                    sequence,
                    prompt_mask,
                    mask,
                    action_prob,
                    action_log_prob,
                    reward,
                    value
                ))))

                # learn from the stored memories
                if time % update_timesteps == 0:
                    self.learn(memories)
                    memories.clear()

        print('rlhf training complete')

Proximal Policy Optimization (PPO)

我们如何在不冒着过度优化导致性能崩溃的风险的情况下,利用当前数据在策略上迈出最大的改进步伐?

Proximal Policy Optimization (PPO) 是一种强化学习算法,它在样本效率和实施简便性之间取得了平衡。为了防止策略变化过大,其目标函数使用了裁剪的替代目标。因此它的名字中有“proximal”一词。这一策略确保了稳定且一致的学习,同时避免了其他旨在实现相同结果的算法常常复杂的实现过程。

我不会详细讨论策略优化及其实施。

PPO 的工作原理值得另写一篇博客,所以我会在这里链接一些好的、深入的教程。

[## 在 RLHF 中指定目标

在 ICML 上,很明显很多人从 RLHF 中获得了价值。什么限制了科学理解…

www.interconnects.ai](https://www.interconnects.ai/p/specifying-objectives-in-rlhf?source=post_page-----faa5ff4761d1--------------------------------) [## Proximal Policy Optimization - Spinning Up 文档

(之前:TRPO 背景)PPO 的动机与 TRPO 相同:我们如何在策略上迈出最大的改进步伐…

spinningup.openai.com

数据规模比较

用于 InstructGPT 的数据量比用于预训练基础模型的数据量小得多。

预训练数据如 GPT-3 使用了 3000 亿个标记。相比之下,InstructGPT 使用了约 O(10M) 个标记。

  • 监督微调(SFT)使用了约 15,000 个提示用于训练,1,500 个用于验证。

  • 奖励模型使用了最多的训练和验证提示,分别约为150,00080,000

  • 强化学习阶段仅使用了约 32,000 个提示用于训练,约 16,000 个用于验证,以优化代理。

因此,总体来说,RLHF 数据约为 1000 万个标记——远远小于用于一般预训练的数百亿个标记。

我将通过突出 InstructGPT 的美妙和有前途的结果来结束这篇博客文章。

结果:使用正确类型的数据进行训练比将模型扩大 100 倍更为有效。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

来自 InstructGPT

看一下图表。红色和黄色线条代表 Instruct-PPO 变体,这些是 RLHF 方法。

ELO 评分在左侧,数字越高表示偏好越强。

PPO 模型仅有 13 亿个参数,而 SFT 和 GPT 模型(由绿色和蓝色线条表示)有 1750 亿个参数。尽管参数远小于 GPT-3,人类显著偏好 InstructGPT 的输出。

这表明,使用正确类型的数据进行训练比仅仅将模型扩大一百倍更为有效。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

InstructGPT 在几个其他具体指标上表现更佳。

我们的圣杯:涌现泛化

尽管我通过提醒我的同事所有提示都在训练数据中而否定了他关于“涌现泛化”的说法,InstructGPT 团队确实观察到了泛化的出现。他们报告了遵循指令时扩展到新领域的泛化程度

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

尽管 99% 的训练数据是英语,InstructGPT 模型偶尔也显示出跟随法语指令的能力。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

此外,尽管训练集中没有针对编程的具体指令,但在代码问答场景中显示出了一些泛化能力。GPT-3 并未真正回答提示问题,但 InstructGPT 表现得相当不错。

这些泛化的迹象表明了 AI 中渴望的涌现现象。尽管 InstructGPT 的技能主要基于其训练数据,但我相信超越它的迹象指向了学习推理的开端

我对随着 RLHF 研究的扩展而取得进一步突破持乐观态度。如果基础的强化学习可以解锁一些泛化能力,那么进步到更大更好的模型可能会帮助我们获得更广泛的新兴智能。

改进开源 RLHF 训练数据:行动事项

最后,我想谈谈我们可以采取哪些行动来改善开源 RLHF 数据。

我们现在陷入了一个恶性循环。因为没有像 ChatGPT 这样的优秀开源 LLM,所以使用它们的人并不多。这导致用于训练和改进的数据较少,结果就是我们得到的是平庸的模型。与此同时,商业 LLM 获得了更多的用户并不断改进。

这里有几种方法可以打破这个循环:

  1. 一个集中的中心,汇总开源用户(已选择参与)的提示、结果和反馈:目前,我知道的唯一可以尝试 LLama 2 的平台是POE。然而,开源维护者无法访问用户输入(提示)和模型的输出,这对改善开源模型至关重要。我们需要让那些从事开源模型工作的人能够获得这些数据。 这一点本身将使开源 LLM 变得更好。我们还需要提升这个平台的用户体验,以吸引更多用户,这将带来更多数据和更好的模型。

  2. 一个统一的数据准备代码库: 一个集中平台,让所有开源 LLM 爱好者可以分享他们的数据工作,如清理、转换、准备和自动标注,将是非常有益的。例如,包括将网页内容转换为可训练格式的代码,以及将一些未标记的数据(如教科书中的文本)自动重新格式化为提示-响应对的代码。目前,开源 RLHF 中的所有数据工作都是分散且未被追踪的。这是有道理的,因为这些核心且艰难的数据工作是区分不同 LLM 的关键。然而,为了利用社区的力量,我们需要建立一个单一的、集中化的中心。

  3. 激励数据共享。 这是最困难的部分,说起来容易做起来难。我目前没有一个好的答案。为了让开源取得进展,人们需要对他们的训练数据保持透明。我们需要找到一种激励数据工作和共享的方法。我们还需要弄清楚开源数据负责人和训练 LLM(大语言模型)之间的密切合作。

如果我们能够解决数据和反馈循环的问题,我确实相信社区有潜力创造出比目前商业上可用的 LLM 更好的模型。这是一个雄心勃勃的目标,但通过集体社区的努力,我相信这是可以实现的。我希望在读完这篇博客文章后,你会更有动力去贡献开源数据。

非常感谢我的审阅者们在他们紧张的日程中挤出时间,分享他们的想法给博客。没有他们,这个博客不会好到现在的一半。

特别感谢(按姓氏字母顺序排列):Nathan Lambert(前 Huggingface)、Rosanne Liu(Deepmind, ML Collective)、Erin LeDell(AutoML)、Joshua Moore(Snap)、Abhinav Srivastava(Breez)、Susan Zhang(OPT-175B)

道路网络边缘匹配与三角形

原文:towardsdatascience.com/road-network-edge-matching-with-triangles-5dc989a77edf?source=collection_archive---------15-----------------------#2023-01-03

三角形在地理空间查询中具有强大的属性

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 João Paulo Figueira

·

关注 发布于 Towards Data Science ·13 min read·2023 年 1 月 3 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

照片由 Pawel Czerwinski 提供,来自 Unsplash

三角形是具有许多实际几何属性的形状。在这篇文章中,我将阐述在解决特定地理空间问题时如何利用这些属性进行机会优化:恢复缺失的地图匹配信息。

我开始探索扩展车辆能量数据集¹(EVED)[1],以寻找城市道路网络背景下有趣的地理空间数据分析机会。该数据集源自之前的出版物,车辆能量数据集 [2],并包含了许多增强功能,即车辆的地图匹配 GPS 位置。地图匹配过程将原始 GPS 位置快照到最可能的基础道路网络边缘上。

下图图 1(取自之前的文章)展示了地图匹配过程如何将采样的 GPS 位置快照到最可能的道路网络边缘上。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 1 — 地图匹配过程将嘈杂的 GPS 测量值快照到最可能的道路网络边缘。这里展示了这一过程的图示,其中以“n”表示的圆圈代表道路网络节点,以“e”命名的箭头表示有向边缘。绿色的采样 GPS 位置与沿弧线的另一个位置匹配并记录在数据库中。然而,匹配边的信息并未提供。(图片来源:作者)

不幸的是,EVED 数据集没有保留基础的匹配边信息;仅保留了位置快照。在缺少边信息的情况下,我们可以从数据中做出更多推断,例如,创建一个目的地预测模型。我们可以从匹配的 GPS 位置中恢复这些信息吗?

文章作者使用了Valhalla工具集,通过使用Open Street Map数据对基础道路网络进行地图匹配操作。掌握这些信息后,我们可以利用地理空间查询恢复缺失的映射边信息。我们从使用一个非常著名的现成工具开始:OSMnx。我们的第一项任务是下载道路网络(即图)。

下载道路网络

要下载和准备道路网络数据,我们需要使用 OSMnx 的功能,如以下代码片段²所示。

def download_road_network(place_name, network_type='drive'):
    graph = ox.graph_from_place(place_name, network_type=network_type, 
                                simplify=False)
    graph = ox.add_edge_speeds(graph)
    graph = ox.add_edge_travel_times(graph)
    graph = ox.bearing.add_edge_bearings(graph)
    return graph

我们从下载一个未简化的图开始,以保留大部分节点细节。接下来,我们向网络中添加缺失的属性,如边缘速度、旅行时间和方位角(从真北开始按顺时针方向测量的角度)。该函数返回道路网络作为一个NetworkX [3] 有向图对象,允许多个边缘存在于节点之间。

road_network = download_road_network("Ann Arbor, Michigan, USA")

寻找边缘

正如我提到的,EVED 只包含地图匹配位置,而不是边本身,我们的任务是重建这些信息。地图匹配过程涉及找到最大化观察路线与已知道路网络之间匹配概率的网络边。更具体地说,该操作将每个 GPS 样本映射到最有可能代表实际行驶路线的道路网络边。地图匹配过程投影采样的 GPS 位置,提供额外的上下文信息。匹配的位置属于边界定义的大圆线段,我们将看到如何利用这一点。

OSMnx 方法

现在让我们转到 OSMnx,发现一种搜索地图匹配位置所属道路网络边缘的方法。幸运的是,该软件包实现了查找最近节点和边的函数,我们将从这里开始。

第一步是将道路网络坐标投影到UTM [4]。这种转换将球面 GPS 坐标投影到一个局部平面空间,在这里我们可以使用常规几何,测量单位为米。

network_utm = ox.projection.project_graph(road_network)

上面的函数调用将道路网络坐标投影到与区域中心对应的 UTM 区域。我们现在可以使用数据库中的坐标对调用 OSMnx 的边检测函数。

easting, northing, zone_num, zone_ltr = utm.from_latlon(42.287702, -83.707775)
edge_id = ox.distance.nearest_edges(network_utm, easting, northing)

该函数支持纬度和经度集合,而不是单个位置,返回相应的边列表。至于上述调用,我们可以使用以下代码检查其结果:

network_utm[edge_id[0]][edge_id[1]][0]

结果是一个包含最近边属性的 Python 字典,如下所示。

{'osmid': 8723817,
 'oneway': False,
 'lanes': '2',
 'highway': 'tertiary',
 'reversed': True,
 'length': 116.428,
 'speed_kph': 48.3,
 'travel_time': 8.7,
 'bearing': 267.3,
 'name': 'Glazier Way',
 'maxspeed': '30 mph'}

不幸的是,这个函数很慢。如果我们想将整个 EVED 数据库转换为为每个点分配最近的边,我们应该尝试另一种方法。

三角形方法

我在本节中提出的解决方案是我首先想到的。正如上文所述,地图匹配位置位于连接端节点的边界大圆线段上。这使我们能够使用三角形性质来找到特定点的最佳网络边。

在进一步解释之前,我邀请你阅读一篇较早的文章,探讨三角形性质以执行高速地理空间查询。

[## 使用三角形不等式查询地理数据

一种快速且简单的方法来查询大量位置。

medium.com](https://medium.com/tblx-insider/using-the-triangle-inequality-to-query-geographic-data-7148a1b103a0?source=post_page-----5dc989a77edf--------------------------------)

在这里,我使用了该文章代码的更新版本来执行道路网络上的基本搜索查询:K 最近邻半径查询。更新的代码版本使用了 Numba 基于优化以提高执行性能。

除了使用三角不等式来加速地理空间查询外,我们还将使用它来选择给定地图匹配 GPS 样本的最佳边缘。这个想法非常简单,我在下面的 图 2 中进行了说明。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 2 — 当匹配的 GPS 点与给定道路网络边缘地理测地线不对齐时,三个点定义了一个三角形(上图),并且距离验证 b + c > a。当点对齐(下图)时,我们得到一个退化三角形,b + c = a。(图片来源:作者)

为了将给定的道路网络边缘与 GPS 点匹配,我们需要计算该点到节点(bc图 2 中)的距离。边缘长度(a)是下载的道路网络数据的一个属性。我们计算以下比率作为拟合的度量。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 3 — 上述比率为 1 当匹配的 GPS 位置位于道路网络边缘地理测地线上;否则,它将更大。最适合的边缘将具有最低可能值。(图片来源:作者)

最适合的道路网络边缘将具有此度量的最低值。但这不是我们必须使用的唯一标准,因为段的方向也很重要。

我们使用端节点的标识符来查询网络边缘,其顺序是重要的。通过反转网络查询中的节点标识符,我们可以获得反向方向的不同属性(如果存在),即计算出的方位角或方向。下面的 图 4 显示了这些属性可能是什么样的。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 4 — 通过反转端节点标识符,我们可以获得道路段的不同属性,即方位角。(图片来源:作者;数据:© OpenStreetMap 贡献者)

为了正确匹配道路网络边缘,我们还必须知道 GPS 方位角,或者如情况所示,推断出的方位角。您可以阅读下面的文章,了解如何从匹配的 GPS 位置计算 EVED 方位角。

## 使用 Quadkeys 进行旅行时间估计

本文解释了如何使用已知速度向量并通过 quadkeys 索引来估计旅行时间。

towardsdatascience.com

我们现在准备寻找最佳适配的边缘,但如何在一个任意大的道路网络中搜索它呢?一种暴力方法是搜索所有可用的道路段,但这不是有效利用计算能力的好方法,因为我们可以做得更好。我们可以选择一小部分附近的候选节点,然后只在这些节点中搜索。

选择这个候选集的标准很简单——我们将使用来自输入 GPS 位置的半径查询。半径由两部分组成:从查询点到网络的最小距离和最大道路段长度。通过将这两个距离相加,我们获得一个半径,我们可以确定最近的边缘节点将位于该半径内。图 5 下面展示了这一概念。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 5 —— 上面的示意图展示了如何确定搜索半径:将查询位置(红色)到道路网络(蓝色)的最短距离与最大段大小(绿色)相加。所有在绿色圆圈内的节点都是候选节点。请注意,查询圆圈的中心在查询位置。(图片来源:作者)

一旦确定了候选节点集,我们只考虑搜索半径内的现有链接。

让我们看看代码是什么样的。我们从处理道路网络的类声明开始:

class RoadNetwork(object):

    def __init__(self, graph, projected=False):
        self.graph = graph
        self.projected = projected
        self.max_edge_length = max([graph[e[0]][e[1]][0]["length"] \
                                    for e in graph.edges])
        self.ids, self.locations = self.get_locations()
        self.geo_spoke = GeoSpoke(self.locations)
# more...

要初始化这个类,我们调用下载和准备 OSM 道路网络的函数,并将其结果作为构造函数参数。构造函数随后收集所有位置并将它们传递给前述文章中描述的索引器对象。请注意,我们不需要为此方法投影任何坐标。

收集地理空间坐标的函数非常简单:

 def get_locations(self):
        latitudes = []
        longitudes = []
        ids = []
        for n in self.graph.nodes:
            ids.append(n)
            node = self.graph.nodes[n]
            longitudes.append(node['x'])
            latitudes.append(node['y'])

        locations = np.array(list(zip(latitudes, longitudes)))
        return np.array(ids), locations
# more...

现在,我们可以进入算法的核心——查询过程本身。对于每个查询点,我们希望选择最有可能限定边缘大地测量段的道路网络节点。下面的函数接收位置坐标,并找到具有最小适配度指标值的道路网络边缘(图 3)。

def get_matching_edge(self, latitude, longitude, bearing=None):
    loc = np.array([latitude, longitude])
    _, r = self.geo_spoke.query_knn(loc, 1)
    radius = self.max_edge_length + r[0]
    node_idx, dists = self.geo_spoke.query_radius(loc, radius)
    nodes = self.ids[node_idx]
    distances = dict(zip(nodes, dists))
    adjacent_set = set()
    graph = self.graph

    best_edge = None
    for node in nodes:
        if node not in adjacent_set:
            adjacent_nodes = np.intersect1d(np.array(graph.adj[node]),
                                            nodes, assume_unique=True)

            adjacent_set.update(adjacent_nodes)
            for adjacent in adjacent_nodes:
                edge_length = graph[node][adjacent][0]['length']
                ratio = (distances[node] + distances[adjacent]) / \
                        edge_length
                if best_edge is None or ratio < best_edge[2]:
                    best_edge = (node, adjacent, ratio)

        if bearing is not None:
            best_edge = fix_edge_bearing(best_edge, bearing, graph)
    return best_edge

代码首先找到最近的道路网络节点及其距离。然后通过将这个距离加上最大的道路网络边缘长度来计算搜索半径。随后的半径查询返回候选节点集合及其到查询位置的距离。我们现在使用节点标识符作为字典中距离的键,以便更快地检索。

主循环遍历候选节点,找到查询半径内需要继续遍历的邻近节点。最后,代码计算适配比率并保留最佳的道路网络边缘。

但在返回的道路网络边缘中还有一个最终测试:其方向。如果我们有样本 GPS 方位角,我们可以解决这个问题。正如我之前解释的,我们有可以使用的推断方位角值。你可以在代码的最后部分看到这一点,只有在你提供了航向角并且反向边存在时,代码才会有效。修正边缘航向角的函数如下所示。

def fix_edge_bearing(best_edge, bearing, graph):
    if (best_edge[1], best_edge[0], 0) in graph.edges:
        bearing0 = radians(graph[best_edge[0]][best_edge[1]][0]['bearing'])
        bearing1 = radians(graph[best_edge[1]][best_edge[0]][0]['bearing'])
        gps_bearing = radians(bearing)
        if cos(bearing1 - gps_bearing) > cos(bearing0 - gps_bearing):
            best_edge = (best_edge[1], best_edge[0], best_edge[2])
    return best_edge

你可以使用附带的Jupyter notebook来测试这段代码,代码存放在GitHub 仓库中。在我的 MacBook Pro 上,这段代码的性能比 OSMnx 方法提高了三倍以上。

距离方式

有人可能会争辩说,在严格假设下,前面的安排表现更好,即查询位置已经在道路段的测地线上。如果情况并非如此呢?我们能否基于相同的搜索原理开发一种更通用的方法?可以!但我们必须假设距离很小³,因此我们不必进行坐标投影,幸运的是,这种情况是符合的。

与使用上述三角形比率度量不同,我们可以在不需要任何地理空间投影(如上文提到的 UTM)的情况下计算 GPS 位置与任何附近道路段之间的距离。我们再次依赖三角形的属性,使用两种不同的方法计算三角形的面积和其他三角形不等式[5]。

在计算给定点到线段的距离时,我们需要考虑两种情况:可以将点正交投影到线段上,以及不能投影的情况。让我们在下面的图 6中可视化第一种情况。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 6 — 查询点(红色)正交投影到道路段(蓝色)。两者之间的距离(黑色)是未知三角形的高度,而我们知道所有长度。(图片来源:作者)

对于这种情况,我们的未知量是三角形的高度,即从点到道路段的最短距离。那么我们如何计算它呢?其中一个最著名的三角形面积公式使用了这个量,见下面的图 7

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 7 — 三角形的面积等于其底边和高度的乘积除以二。(图片来源:作者)

如果已知面积,我们可以通过简单的代数迅速推导出高度。我们可以仅使用边长计算三角形的面积吗?

另一个可能不太为人所知的三角形面积公式得名于亚历山大里的赫伦 [6],他是第一个证明这个公式的人。有趣的是,这个公式仅依赖于我们已经知道的东西——三角形的边长。这个公式有几种形式,其中最著名的可能是下面的图 8中的形式。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 8 — 海伦公式仅使用边长来计算三角形的面积。量“s”是半周长。(图片来源:作者)

使用这个公式,我们可以计算三角形的面积,并将其用于前面的公式中,以获得从样本点到段的距离。不幸的是,这种公式已知在数值稳定性方面存在问题,特别是当应用于具有非常锐角的“平坦”三角形时。我们将使用图 9中所示的一个已知稳定的替代方案。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 9 — 数值稳定的海伦公式要求a ≥ b ≥ c。(图片来源:作者)

当我们无法将查询点正交投影到道路段上时会发生什么?我们可以通过下面的图 10来可视化这种情况。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 10 — 我们无法将查询点正交投影到道路段上。在这种情况下,两者之间的距离是“a”。(图片来源:作者)

在这种情况下,我们很容易,因为距离已经计算出来。但我们如何仅使用边长来了解几何形状呢?我们可以通过一个观察来区分图 6图 10中的三角形。在图 6中,段“a”和“c”与“b”形成的角度都是锐角,而在图 10中,其中一个角是钝角(大于 90 度)。

幸运的是,几何学帮助我们通过另一组三角形不等式来确定内部三角形角度是锐角、钝角还是直角。在图 9的情况下,我们有c² > a² + b²。在对称情况下,即对角是钝角时,我们会有a² > b² + c²。这两个测试可以区分这两种情况,并且执行速度非常快。

下面的代码演示了使用距离而不是简单适应度比率的查询。

def get_nearest_edge(self, latitude, longitude, bearing=None):
    best_edge = None
    adjacent_set = set()
    graph = self.graph

    loc = np.array([latitude, longitude])
    _, r = self.geo_spoke.query_knn(loc, 1)
    radius = self.max_edge_length + r[0]
    node_idx, dists = self.geo_spoke.query_radius(loc, radius)
    nodes = self.ids[node_idx]
    distances = dict(zip(nodes, dists))

    for node in nodes:
        if node not in adjacent_set:
            adjacent_nodes = np.intersect1d(np.array(graph.adj[node]),
                                            nodes, assume_unique=True)

            adjacent_set.update(adjacent_nodes)
            for adjacent in adjacent_nodes:
                a = distances[node]
                b = graph[node][adjacent][0]['length']
                c = distances[adjacent]

                a2, b2, c2 = a * a, b * b, c * c

                if c2 > a2 + b2 or a2 > b2 + c2:
                    distance = min(a, c)
                else:
                    area = heron_area(a, b, c)
                    distance = area * 2.0 / b

                if best_edge is None or distance < best_edge[2]:
                    best_edge = (node, adjacent, distance)

    if bearing is not None:
        best_edge = fix_edge_bearing(best_edge, bearing, graph)
    return best_edge

最后,以下函数根据三个任意的三角形边长计算海伦公式。注意代码如何通过适当地排序边长来开始。

@njit()
def heron_area(a, b, c):
    c, b, a = np.sort(np.array([a, b, c]))
    return sqrt((a + (b + c)) *
                (c - (a - b)) *
                (c + (a - b)) *
                (a + (b - c))) / 4.0

让我们看看所有这些努力是否是值得的。

性能

我使用 2019 年 16 英寸 MacBook Pro,配备 2.6 GHz 6 核 Intel Core i7 CPU,32 GB RAM,和 Ventura 13.0 获取了下面的性能结果。所有三种方法都查询了相同的 868 点轨迹,该轨迹来自 EVED。

在下面的图 11中,你可以看到这篇文章中介绍的三种算法的基准结果。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 11 — 上述性能测量反映了每种算法处理 868 个地点(含重复项)的路径所需的平均时间。(图片来源:作者)

如您所见,我使用了缓存来处理重复项并避免不必要的处理。这可能会对 OSMnx 算法提供不公平的优势,为了澄清,我决定使用相同路径中的 203 个唯一位置运行相同的基准测试。结果显示在图 12下方。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 12 — 当我们将输入轨迹减少到仅有 203 个唯一位置时,性能似乎有所改善。然而,性能曲线没有显著变化。(图片来源:作者)

请注意,对于1.3.0之前的 OSMnx 版本,性能差异显著更差。

我们利用了三角形属性并找到了一种快速的边匹配算法。然而,我应该进行更多测试,以调查边缘情况和更大的道路网络,以确保这是一个可靠的算法。

结论

在本文中,我开发了一种快速算法,用于搜索 EVED 位置的缺失地图匹配边。通过假设这些位置位于道路网络边缘的测地线上,我开发了一种快速的拟合度量,使用了三角形不等式属性。接着,我丰富了算法,使用了点到线段的几何概念。我使用了更多的三角形属性和不等式,仅考虑了边长。最后,我对解决方案进行了基准测试,并确认了新算法在性能上的提升,超过了 OSMnx 算法。

最后,我要强调的是,性能提升源于我对问题定义所能做出的强假设。该算法的性能会随着搜索半径的增加而下降,这高度依赖于道路网络结构和节点密度。

请从GitHub 存储库获取代码。

注释

  1. 原作者将数据集授权为 Apache 2.0 许可证(参见VEDEVED GitHub 存储库)。请注意,这也适用于衍生作品。

  2. 我将本文及附带的 GitHub 存储库中的所有代码授权为 MIT 许可证。

  3. 我们处理的数据集涉及相对较小的距离。下载数据的最大道路段长度小于 600 米(0.37 英里或 1968 英尺)。您可能可以安全地使用更大的距离而不会产生显著误差,但我建议检查所产生的误差是否在可接受范围内。

参考文献

[1] 张松,法提赫,阿卜杜勒卡迪尔,施瓦茨,马晓。 (2022). 扩展车辆能量数据集 (eVED): 一个增强的大规模数据集,用于深度学习车辆旅行能量消耗。arXivdoi.org/10.48550/arXiv.2203.08630

[2] Oh, G. S., Leblanc, D. J., & Peng, H. (2019). 车辆能源数据集 (VED),用于车辆能源消耗研究的大规模数据集。arXiv. doi.org/10.48550/arXiv.1905.02081

[3] Aric A. Hagberg, Daniel A. Schult 和 Pieter J. Swart, “使用 NetworkX 探索网络结构、动态和功能”, 见于 第七届科学会议(SciPy2008)论文集, Gäel Varoquaux, Travis Vaught 和 Jarrod Millman (编辑), (美国加州帕萨迪纳), 第 11–15 页, 2008 年 8 月

[4] 通用横坐标系统。 (2022 年 6 月 16 日). 见于维基百科. en.wikipedia.org/wiki/Universal_Transverse_Mercator_coordinate_system

[5] 三角不等式列表。 (2022 年 12 月 17 日). 见于维基百科. en.wikipedia.org/wiki/List_of_triangle_inequalities

[6] 赫伦公式。 (2022 年 12 月 17 日). 见于维基百科. en.wikipedia.org/wiki/Heron%27s_formula

João Paulo Figueira 在tb.lx by Daimler Trucks and Buses担任数据科学家,工作地点在葡萄牙里斯本。

大型语言模型:RoBERTa——一种强健优化的 BERT 方法

原文:towardsdatascience.com/roberta-1ef07226c8d8?source=collection_archive---------1-----------------------#2023-09-24

了解用于 BERT 优化的关键技术

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Vyacheslav Efimov

·

关注 发表在 Towards Data Science ·5 分钟阅读·2023 年 9 月 24 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

介绍

BERT模型的出现带来了 NLP 领域的重大进展。BERT 源自Transformer架构,在各种下游任务上实现了最先进的结果:语言建模、下一句预测、问答、命名实体识别标记等。

## 大型语言模型:BERT——来自 Transformer 的双向编码表示

了解BERT如何构建最先进的嵌入表示

towardsdatascience.com

尽管 BERT 的性能优秀,研究人员仍然继续尝试调整其配置,希望实现更好的指标。幸运的是,他们成功了,并提出了一种新的模型,称为 RoBERTa——鲁棒优化 BERT 方法。

在本文中,我们将引用官方的 RoBERTa 论文,其中包含有关该模型的深入信息。简单来说,RoBERTa 包含了对原始 BERT 模型的几个独立改进——所有其他原则包括架构保持不变。所有的进展将会在本文中涵盖和解释。

1. 动态掩蔽

从 BERT 的架构中,我们记得在预训练期间,BERT 通过尝试预测一定百分比的掩蔽标记来执行语言建模。原始实现的问题在于,对于给定的文本序列,在不同的批次中选择掩蔽的标记有时是相同的。

更准确地说,训练数据集重复了 10 次,因此每个序列仅以 10 种不同的方式进行掩蔽。考虑到 BERT 运行了 40 个训练周期,每个具有相同掩蔽的序列会传递给 BERT 四次。研究人员发现,使用动态掩蔽稍微更好,即每次将序列传递给 BERT 时掩蔽都是唯一生成的。总的来说,这在训练过程中减少了重复的数据,给模型提供了处理更多不同数据和掩蔽模式的机会。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

静态掩蔽与动态掩蔽

2. 下一个句子预测

论文的作者进行了研究,以寻找建模下一个句子预测任务的最佳方法。因此,他们发现了几个有价值的见解:

  • 去除下一个句子预测损失会导致性能略有提升。

  • 将单个自然语言句子传入 BERT 输入会降低性能,相较于传入由多个句子组成的序列。解释这一现象的一个可能假设是模型仅依靠单个句子难以学习长距离依赖关系。

  • 从单个文档中采样连续的 句子来构造输入序列比从多个文档中采样更有利。 通常,序列总是从单个文档的连续完整句子中构造,这样总长度最多为 512 个标记。问题在于当我们到达文档末尾时。研究人员在这方面比较了是否值得停止采样句子,还是额外采样下一个文档的前几个句子(并在文档之间添加相应的分隔标记)。结果表明,第一个选项更好。

最终,对于 RoBERTa 的最终实现,作者选择保留前两个方面而省略第三个方面。尽管第三个见解带来了观察到的改进,但研究人员没有继续采用它,因为这会使之前实现之间的比较变得更加困难。这是因为达到文档边界并在此停止意味着输入序列将包含少于 512 个标记。为了在所有批次中保持类似的标记数量,在这种情况下需要增加批量大小。这会导致批量大小的变化和更复杂的比较,而研究人员希望避免这种情况。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

3. 增加批量大小

最近在自然语言处理(NLP)领域的进展表明,增加批量大小并适当减少学习率和训练步数通常会改善模型的性能。

提醒一下,BERT base 模型在 256 个序列的批量大小下训练了 100 万步。作者尝试了 2K 和 8K 的批量大小,并选择了后者来训练 RoBERTa。相应的训练步数和学习率值分别变为 31K 和 1e-3。

同时,需要注意的是,批量大小的增加通过一种叫做“梯度累积”的特殊技术可以更容易地进行并行化。

4. 字节文本编码

在 NLP 中,存在三种主要的文本标记化类型:

  • 字符级别的标记化

  • 子词级别的标记化

  • 单词级别的标记化

原始的 BERT 使用了子词级别的标记化,词汇表大小为 30K,该词汇表在输入预处理后使用多个启发式方法进行学习。RoBERTa 使用字节而非 Unicode 字符作为子词的基础,并将词汇表大小扩展到 50K,而无需任何预处理或输入标记化。这导致 BERT base 和 BERT large 模型分别增加了 15M 和 20M 的额外参数。RoBERTa 中引入的编码版本表现稍逊于之前的版本。

尽管如此,RoBERTa 中词汇表大小的增长使其能够编码几乎任何单词或子词,而无需使用未知标记,这相比于 BERT 是一个显著的优势。这使得 RoBERTa 可以更全面地理解包含稀有词汇的复杂文本。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

预训练

除此之外,RoBERTa 采用与 BERT large 相同的架构参数,应用了上述四个方面。RoBERTa 的总参数数量为 355M。

RoBERTa 在五个大规模数据集的组合上进行了预训练,总共达到了 160 GB 的文本数据。相比之下,BERT large 只在 13 GB 的数据上进行了预训练。最后,作者将训练步数从 100K 增加到 500K。

结果是,RoBERTa 在最受欢迎的基准测试中超越了 BERT large 和 XLNet large。

RoBERTa 版本

类似于 BERT,研究人员开发了两个版本的 RoBERTa。基础版和大型版中的大多数超参数是相同的。下图展示了主要的区别:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

RoBERTa 的微调过程类似于 BERT 的过程。

结论

在本文中,我们探讨了 BERT 的改进版本,该版本通过引入以下几个方面修改了原始训练过程:

  • 动态掩蔽

  • 省略下一句预测目标

  • 在更长的句子上进行训练

  • 增加词汇表的大小

  • 在数据上使用更大的批量进行更长时间的训练

结果显示,RoBERTa 模型在顶级基准测试中优于其前身。尽管配置更复杂,但 RoBERTa 仅增加了 1500 万个参数,同时保持了与 BERT 相当的推理速度。

资源

除非另有说明,否则所有图片均为作者提供

石头剪刀布:量子计算的妙趣

原文:towardsdatascience.com/rock-paper-scissors-a-quantum-computing-twist-bcf66b88d781

教程

以先进计算的新方式来玩

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Kory Becker

·发表于Towards Data Science ·14 分钟阅读·2023 年 5 月 16 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

来源:Stable Diffusion

享受量子计算游戏的乐趣

我喜欢展示量子计算的效果。特别是,通过使用量子叠加态和纠缠的游戏。

量子计算是一项极其激动人心的技术,它将影响几乎所有的行业和科学。了解量子计算的工作原理,特别是它如何不同于经典计算,能让你成为更好的程序员和更具逻辑思维能力的人!

所以,我认为设计一个可以在量子计算机上实现的游戏,进一步展示量子计算和传统计算之间的差异,会很有趣。

不乏令人惊叹的量子奇迹

与经典计算机相比,量子计算机具备许多令人惊讶(甚至令人困惑!)的强大特性。

通过指数处理能力提高性能,评估多个场景同时,甚至生成高级随机数,这些都可以在量子层面实现。

我们将重点关注叠加态——这种技术允许量子计算机同时评估多种不同的场景。为了增加趣味性——我们将使用游戏!

量子游戏的创意

我四处搜寻,希望找到一个足够简单、易于玩耍,同时又能利用量子处理的游戏创意。

如果这以前从未做过,那就更好了!

这使我想到了一些经典游戏,比如井字游戏扑克以及其他游戏。

虽然许多将量子计算与游戏结合的研究论文通常关注各种算法方法和数学复杂性,但我希望创建一些更容易理解的东西。

编写量子计算程序并不一定很难。

一旦你了解了各种量子门和量子比特的行为,你就可以创建大量的量子应用程序。

还有什么比用石头、剪刀、布的游戏更好地展示这一点呢!

石头、剪刀、布

石头、剪刀、布是一个两人游戏。游戏中每个玩家秘密选择一个石头、剪刀或布的物品。玩家通常数到三,然后同时展示他们的选择。

游戏规则规定,石头击败剪刀,剪刀击败布,布击败石头。

足够简单!或者说是吗?

一个数学悖论

石头、剪刀、布的基本原则实际上是一种权重和价值的衡量。

我们可以认为石头的价值大于剪刀。同样,剪刀的价值大于布。到目前为止,一切都很好。

石头 > 剪刀 > 布

现在,如果石头大于剪刀,剪刀大于布,那么石头也一定大于布。然而,根据游戏规则,布的价值大于石头!

石头 > 剪刀 > 布 > 石头?

这确实是一个悖论!

让我们从数学角度思考一下这个问题。

让我们退一步来考虑一下使石头、剪刀、布游戏如此独特的悖论。

设想我们有三个变量:A、B 和 C(分别代表石头、布和剪刀)。每个变量被赋予一个权重,使得 A > B 和 B > C。不等式的传递性规定,根据这种排列,A > C。

这会使我们相信,如果石头 > 剪刀,剪刀 > 布,那么石头 > 布。显然,这不是游戏的实际玩法!

实际上,这就是哈代悖论的前提。

石头、剪刀、布的哈代悖论

哈代悖论,由卢西安·哈代于 1992 年提出,考察了三种变量表面上在每个变量的权重上都大于下一个的情况,仍然可以产生最后一个变量大于第一个变量的情况——违反了不等式的传递性。

实际上,这种类型的违反在量子世界中是显然可能的,在量子世界中,粒子可能会纠缠在一起,事实上,它还可以在石头、剪子、布的游戏中找到!

我认为我们刚刚找到了一款完美的游戏,来展示量子计算背后的部分力量。

建立游戏规则

我们将创建一个量子计算程序,可以在石头、纸、剪刀的游戏中找到所有获胜手。

我们将创建经典版和量子版程序,以展示复杂性的差异。

我们需要做的第一件事是对游戏中的项进行编码,以便我们可以在算法中使用它们。由于每个玩家在每轮游戏中可以选择三项中的一个(石头、纸或剪刀),我们可以有九种不同的游戏手可能性。

让我们使用每个项的首字母来代表游戏中每个玩家的可能选择。因此,可能的手牌如下所示。

[RR, RP, RS, PR, PS, PP, SR, SP, SS]

上面的列表显示了所有可能的手牌,从石头对石头(RR),石头对纸(RP),石头对剪刀(RS)等开始。

在上述九种可能的手牌中,只有三种是获胜选择:石头对剪刀(RS),剪刀对纸(SP),和纸对石头(PR)。

[RS, SP, PR]

进入数字世界

现在我们已经定义了游戏选择,我们需要将这些选择从字母(R,S,P)转换为二进制数字零或一。这是必要的,以便我们最终可以将选择表示为量子比特。

由于我们有三项,我们将从零到二(00,01,10)表示它们。

# The Items
00 = Rock
01 = Paper
10 = Scissors

重要的是要注意我们为每个项分配的二进制值,因为我们在查看程序的输入和输出时会参考这些值。

接下来,让我们从玩家一的角度定义游戏规则。

# Rock
00 vs 00 = Tie
00 vs 01 = Loss
00 vs 10 = Win
# Paper
01 vs 00 = Win
01 vs 01 = Tie
01 vs 10 = Loss
# Scissors
10 vs 00 = Loss
10 vs 01 = Win
10 vs 10 = Tie

现在我们已经为每种可能的游戏手定义了简单的数字定义,让我们开始编写一些代码。

将游戏编码为比特

由于我们的游戏项被定义为二进制值,我们现在可以将这些值存储在量子比特中。让我们创建一个 Python 对象,将每个项的对应值定义为上面部分中列出的值。

# Encode the choices as qubits.
choices = {
 ‘rock’: [0,0],
 ‘paper’: [0,1],
 ‘scissors’: [1,0]
}

现在,让我们看看能否找到所有可能的获胜动作。

创建获胜的逻辑表达式

我们已经为每种选择定义了一个表示(石头 00,纸 01,剪刀 10)。由于我们有两个玩家,每轮将有四个比特。

一轮游戏可能如下面所示。

*玩家 1 选择石头。

玩家 2 选择纸。

石头 = 00 和 纸 = 01*

输入将是 0001。

为了确定这是玩家一的获胜动作,我们需要检查一些逻辑来决定游戏规则。

游戏规则规定石头击败剪刀,剪刀击败纸,纸击败石头。

我们可以使用布尔逻辑来编码这些规则。

bool isWin = (rock and scissors) or (scissors and paper) or (paper and rock)

发现所有获胜手的慢速方法

让我们开始编写一个经典计算机程序来找出所有获胜手。

我们可以创建一个名为 check_all_games() 的方法,该方法遍历所有可能的项组合,并仅返回对玩家一有利的手牌。

def check_all_games():
    # Generate a list of all possible game choices for player1 and player2.
    result = []
    count = 0

    games = list(itertools.product([0, 1], repeat=4))
    for game in games:
        # Example: (1, 0, 0, 1) => scissors vs paper
        player1 = list(game[0:2])
        player2 = list(game[2:4])

        # A quick check to make sure both player moves are valid.
        if player1 in list(choices.values()) and player2 in list(choices.values()):
            # ...
            is_win = isWin(player1, player2)
            if is_win:
                result += [game]

        count += 1

    return (result, count)

([(0, 0, 1, 0), (0, 1, 0, 0), (1, 0, 0, 1)], 16)

该方法返回包括石头对剪刀、纸对石头和剪刀对纸的获胜手牌列表。它还返回搜索所有组合所需的迭代次数。

(0, 0, 1, 0) = 石头 (0, 0) 对比 剪刀 (1, 0)

(0, 1, 0, 0) = 纸 (0, 1) 对比 石头 (0, 0)

(1, 0, 0, 1) = 剪刀 (1, 0) 对比 纸 (0, 1)

你注意到找到所有获胜游戏需要 16 次迭代吗?更不用说,这些迭代包括无效的比特组合,例如 [1, 1, 1, 1] — 这些甚至不对应有效的项!

量子计算能做得更好吗?

让我们再试一次。不过,这次我们将创建一个量子计算程序来找到所有获胜手牌。

以经典程序的相同方式,我们将定义一个 isWin() 函数,编码游戏规则。

一个编码了某些特定逻辑规则(例如我们游戏中的获胜规则)的黑箱量子电路称为 oracle

由于我们的 oracle 将处理二进制值 0 和 1,而不是变量名称,让我们用这些值重写我们的逻辑表达式。

bool isWin = (00 and 10) or (01 and 00) or (10 and 01)

此外,由于我们将使用量子计算库 Qiskit,我们需要将量子比特按相反的顺序表示。因此,我们将通过交换右边和左边位的位置来调整我们的逻辑。

[(0, 0, 1, 0), (0, 1, 0, 0), (1, 0, 0, 1)]
[(q1 q0 q3 q2),(q1 q0 q3 q2),(q1 q0 q3 q2)]

第一行是我们经典程序返回的获胜手牌结果。我们只是用量子比特(标记为 q0、q1、q2、q3)表示每个位。量子比特按相反的顺序排列,使得前两个位是玩家一,最后两个位是玩家二。每个玩家的量子比特按最不重要的位在右,最重要的位在左排列(对应于 [q1, q0] 和 [q3, q2])。我们对所有三种获胜手牌组合都重复这一过程。

创建一个 oracle

让我们为我们的量子计算解决方案创建 oracle。

就像我们在经典程序中做的那样,我们将使用布尔逻辑编码游戏规则。然而,这次的不同之处在于,我们引用 q0、q1、q2 和 q3 来表示石头、纸和剪刀。

例如,编码在我们 oracle 中的第一种获胜手牌是石头对剪刀。我们可以如下所示地编码这一点。

石头对剪刀

(0, 0, 1, 0)

(q1 q0 q3 q2)

首次赢牌条件

(not q0 and not q1 and not q2 and q3)

反转量子比特顺序

(not q1 and not q0 and q3 and not q2)

转换为二进制

(00 对比 10)

转换为游戏轮次

(石头对剪刀)

# Define a classical logical circuit with 4 variables (qubits).
isWin = 'def isWin(q0: Int1, q1: Int1, q2: Int1, q3: Int1) -> Int1:\n  return (not q0 and not q1 and not q2 and q3) or (q0 and not q1 and not q2 and not q3) or (not q0 and q1 and q2 and not q3)'

# Convert the logic to a quantum circuit.
formula = ClassicalFunction(isWin)
fc = formula.synth()

# Convert the quantum circuit to a quantum program.
qc = QuantumCircuit(4+1)
qc.compose(fc, inplace=True)

所有获胜手牌都在 oracle 中用一行布尔逻辑进行编码。这创建了一个量子计算电路,可以在我们的程序中用于找到所有获胜手牌!

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

针对游戏石头、剪刀、布的所有获胜手牌的量子计算神谕。来源:作者。

将这些内容整合在一起,我们可以使用这个神谕创建一个量子计算程序。

# Get the number of qubits needed.
n = len(choices['rock']) * 2

qc = QuantumCircuit(n + 1, 1)

# Paper vs Rock.
qc = encode('paper', 'rock', qc)

# Append the rock, paper, scissors oracle.
qc.append(oracle, range(5))

# Measure the result!
qc.measure(4, 0)

在这个例子中,我们玩的是纸对石头的单轮游戏。在得到的量子计算程序中,请注意第一个量子比特(q0)使用 X-门反转为一,而第二个量子比特(q1)保持为零。这对应于(01),表示纸。同样,第三个和第四个量子比特(q2 和 q3)保持为零(00),对应于石头。

这是一场纸对石头的游戏。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

一个用于纸对石头的量子计算电路,其中玩家一选择纸(01),玩家二选择石头(00)。来源:作者。

我们的量子计算程序的结果返回了一个输出,指示这是否是一个获胜的手牌。由于量子比特的输出是反向的(记住,我们是从右到左读取的!),我在下面高亮了一个示例并附上了解释每个量子比特值的注释。

1 00 01
^- win
  ^^----- rock
     ^^--------paper

运行量子计算程序

让我们运行程序并查看结果。由于纸总是战胜石头,我们期望我们的程序在量子程序的所有测量中都输出一个值为一的结果。

simulator = Aer.get_backend('aer_simulator')
job = execute(qc, simulator)
result = job.result()
counts = result.get_counts()

key = max(counts, key=counts.get)

print(counts)
plot_histogram(counts)

{‘1’: 1024}

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

玩家一的石头、剪刀、布获胜手牌。来源:作者。

确实,我们可以看到所有测量结果都是强值一。这表明纸对石头对玩家一是胜利的!

同样,我们可以将相同的神谕应用于纸对剪刀的游戏轮次。在这一轮中,我们预计结果为零,因为纸总是被剪刀战胜。

# Rock vs Scissors.
qc = encode('paper', 'scissors', qc)

{‘0’: 1024}

我们再次得到了正确的答案,表明这是玩家一的失利。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

玩家一的输牌。来源:作者。

到目前为止,我们只是确定单轮游戏是否对玩家一有利。这并不令人印象深刻。毕竟,我们的经典程序找到了所有获胜手牌(尽管计算需要 16 次迭代!)。

我们能找到所有的获胜手牌吗?

量子处理的力量

结果是,我们已经创建了一个编码了游戏获胜手牌的量子神谕,我们实际上可以计算所有的获胜手牌。更棒的是,我们可以在一个 CPU 周期内完成这个计算!

我们将量子比特置于叠加态,而不是将其硬编码为特定的零或一,这些值对应于每个玩家选择的石头、纸或剪刀项。这将量子比特的值从 0 1 改变为 0 1 同时存在!

通过使用叠加态,我们可以在一次执行中评估 所有 可能的游戏手牌,并仅返回那些满足神谕布尔逻辑的获胜手牌。

这是一个如何实现的例子。

qc = QuantumCircuit(n + 1)

qc.h(range(n))

# Append the sudoku oracle.
qc.append(oracle, range(n+1))

# Measure the result!
qc.measure_all()
qc.draw(output='mpl')

请注意,我们没有为玩家一和玩家二硬编码特定项目。相反,我们使用Hadamard 门将所有四个量子比特放入超位置,以便它们同时持有 0 和 1 的值。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

通过将玩家一和玩家二的量子比特放入超位置来找到石头、剪刀、布的所有可能获胜组合。来源:作者。

这将产生如上所示的量子计算电路。如果我们运行这个程序,我们应该看到所有满足 oracle 布尔逻辑的获胜组合的指示。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

使用没有放大的 oracle 执行量子电路。来源:作者。

在结果中,最重要的量子比特(最左侧或最下方)是 0(失败)或 1(获胜)。所以,我们关注的是图表最右侧的 3 个获胜结果。

然而,这似乎不太对劲!

实际上,所有可能的量子比特值组合似乎是完全随机的。

Grover 搜索救援

结果表明,当在超位置的量子比特空间中与一个 oracle 进行搜索时,我们需要放大满足 oracle 的获胜结果,同时最小化不满足 oracle 的失败结果。

我们可以使用Grover 搜索量子算法来实现这一点。

Grover 的搜索算法利用扩散器和放大过程,使正确的结果“漂浮”得更高,而错误的结果保持较低。它可以用于在无序项目的数据库中搜索密钥,并且比任何经典算法的搜索速度平方级更快。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

使用 Grover 搜索算法找到石头、剪刀、布中所有可能的获胜动作。来源:作者。

在使用 Grover 搜索算法运行这个新电路后,加上我们相同的石头、剪刀、布的 oracle,我们可以看到输出的变化。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

石头、剪刀和布中的获胜组合。从左到右:布 (01) 对 石头 (00),剪刀 (10) 对 布 (01),石头 (00) 对 剪刀 (10)。来源:作者。

检查上述结果,我们确实有三个结果远高于其余组合。实际上,这些结果直接对应于游戏中的获胜动作!

如果我们解码每一个结果,从图表最左侧开始,并反转 Qiskit 返回的比特,我们可以确定获胜的组合。请记住,最上面的比特是最低有效比特,对应于玩家一。

0001 = 布 (01) 对 石头 (00) = 获胜

0110 = 剪刀 (10) 对 布 (01) = 获胜

1000 = 石头 (00) 对 剪刀 (10) = 获胜

0001 => 01 versus 00 => paper versus rock => WIN
   ^-  q0
  ^--- q1
 ^---- q2
^----- q3

最令人惊讶的是,经典程序需要 16 次迭代才能找到这三种获胜组合。量子计算程序只需要一次!

还有一点乐趣

我们刚刚研究了量子计算程序如何通过在 CPU 上进行一次执行即可找到石头、剪刀、布游戏中的所有获胜招数。我们通过将量子位置于叠加态来实现这一点。

然而,通过调整量子位的叠加状态,我们实际上可以创造出程序的不同行为。

例如,假设我们想要找出在玩家二给定特定选择时,玩家一的最佳行动。我们可以通过将玩家一的量子位置于叠加态,而将玩家二的量子位固定为特定值来做到这一点。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

在玩家二选择石头时,为玩家一找到一个获胜的动作。来源:作者。

如果我们现在运行量子程序,期望看到的结果是一个高测量值,这将对应于玩家一应选择的获胜手牌,以击败玩家二,而不是看到三个高测量值(对应于所有获胜手牌)。

在上述场景中,我们为玩家二分配了石头(00)的选择。让我们看看量子程序选择的行动是什么!

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

当玩家二选择石头(00)时,玩家一的制胜招数是纸(01)。来源:作者。

结果表明(0001)。从最低位到最高位读取,这评估为玩家一在玩家二选择石头(00)时选择纸(01)。实际上,这一举动确实是玩家一的制胜招数!

纸胜石头!

你可以在 这里 下载完整的石头、剪刀、布程序代码示例。

轮到你了

现在我们已经完成了一个量子计算程序来找到石头、剪刀、布游戏中的所有获胜招数,让我们思考一下我们所取得的成就。

一个经典程序需要 16 次迭代才能找到所有获胜的手牌。相比之下,使用 Grover 搜索的量子版本只需 1 次迭代。这只是经典计算机和量子计算机工作方式之间的一个惊人差异。

Grover 搜索可以应用于许多不同的可搜索性问题,包括算法、文件系统和数据库,仅举几例。此外,由于量子计算领域仍然如此年轻,你有真正的机会产生影响。

我希望你对学习更多关于这项惊人技术的兴趣被激发。现在轮到你了!

关于作者

如果你喜欢这篇文章,请考虑在 MediumTwitter 和我的 网站 上关注我,以便接收我未来的帖子和研究工作通知。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值