使用 Python 的 GANs 实践生成式 AI:DCGAN
原文:
towardsdatascience.com/hands-on-generative-ai-with-gans-using-python-dcgan-6077f0067ac3
Vinicius “amnx” Amano 拍摄的照片,来源于 Unsplash
使用卷积层在 PyTorch 中改进合成图像生成
·发表在 Towards Data Science ·5 分钟阅读·2023 年 4 月 4 日
–
介绍
在我的上一篇文章中,我们已经看到了如何使用 GANs 生成 MNIST 数据集类型的图像。我们取得了不错的结果,并成功实现了我们的目标。然而,这两个网络 G(生成器)和 D(判别器)主要由密集层组成。此时,你应该知道,通常在处理图像时我们使用 CNNs(卷积神经网络),因为它们使用卷积层。所以,让我们看看如何通过使用这些类型的层来改进我们的 GANs。使用卷积层的 GANs 称为 DCGANs。
什么是转置反卷积?
通常当我们使用 CNNs 时,我们习惯于使用卷积层。不过,在这种情况下,我们还需要“反向”操作,即转置反卷积,有时也称为反卷积。
这个操作使我们能够对特征空间进行上采样。例如,如果我们有一个 5x5 的网格表示的图像,我们可以将这个网格“放大”到 28x28。
原则上,你做的事情很简单,你在初始特征图的元素内部填充零以扩大它,然后应用正常的卷积操作,使用特定的核大小、步幅和填充。
例如,假设我们想将 5x5 的特征空间转换为 8x8。首先,通过插入零,我们创建一个 9x9 的特征空间,然后通过应用 2x2 的滤波器将其再次缩小到 8x8。让我们看一个图形示例。
转置卷积(图像由作者提供)
在这个网络中,我们还将使用批量归一化层,它们有助于解决内部协方差偏移问题。简而言之,它们的作用是在每一层之前对每个批次进行归一化,以便在训练过程中数据的分布没有变化。
生成器架构
生成器将由一系列转置卷积层构成,这些层将初始的随机向量 z 转换为我们想要生成的图像的正确尺寸,这里是 28x28。另一方面,特征图的深度将变得越来越小,不像卷积层那样。
生成器架构(作者提供的图像)
判别器架构
另一方面,判别器是一个经典的 CNN 网络,负责对图像进行分类。因此,我们将有一系列卷积层,直到我们得到一个单一的数字,即输入是实际的还是假的概率。
判别器架构(作者提供的图像)
让我们开始编码吧!
我将使用Deepnote,但如果你愿意,你可以使用 Google Colab。
首先,检查你的硬件是否有可用的 GPU。
如果你在使用 Google Colab,你需要挂载你的驱动器。让我们也导入必要的库。
from google.colab import drive
drive.mount('/content/drive/')
现在我们定义创建生成器 G 网络的函数,如我们之前所描述的。
要定义判别器 D,我们使用 Python 类,因为我们需要 forward 方法的输出。
现在我们终于可以实例化我们的 G 和 D 网络了。让我们还打印模型以查看层的摘要。
像往常一样,如果我们想进行网络训练,我们需要定义成本函数和优化器。
输入向量 z 是一个随机向量,取自某种分布,在我们的情况下可以是均匀分布或正态分布。
现在让我们定义判别器 D 的训练函数。正如我们在前一篇文章中所做的那样,D 必须在真实图像和假图像上进行训练。真实图像直接取自 MNIST 数据集,而对于假图像,我们会动态生成一个输入 z,将其传递给生成器 G,并获取 G 的输出。我们可以自己创建标签,知道真实图像的标签全为一,而假图像的标签全为零。最终损失将是真实图像损失和假图像损失的总和。
生成器以判别器的输出作为输入,因为它必须查看 D 是否识别出图像是假的还是实际的。并根据此计算其损失。
我们已经准备好导入数据集,这将使我们能够进行网络训练。使用 PyTorch 导入 MNIST 数据集非常容易,因为它已经实现了相关方法。
现在我们有了数据集,我们可以实例化数据加载器。
由于在训练结束时,我们希望了解图像生成如何随着时间的推移而改进,我们创建了一个函数,允许我们在每个周期生成和保存这些图像。
最后,我们准备开始训练。选择周期数,为了获得良好的结果,应该设定在大约 100 个周期。我只进行了 10 个周期,所以我将得到一个*“更丑”*的输出。
训练需要 100 个周期,大约需要一个小时,当然,这很大程度上取决于你所拥有的硬件。
让我们绘制结果,看看网络是否学会了如何生成这些合成图像。
合成图像(作者提供的图像)
最终思考
在本文中,我们不仅使用了简单的 GAN 网络,还包括了在处理图像时非常有效的卷积操作,从而创建了所谓的 DCGAN。为了生成这些合成图像,我们构建了两个网络,一个生成器 G 和一个判别器 D,它们进行对抗游戏。如果这篇文章对你有帮助,关注我以获取我即将发布的关于生成网络的文章! 😉
结尾
马切洛·波利提
使用 Python 和 GANs 进行生成式 AI 实践:图像生成
原文:
towardsdatascience.com/hands-on-generative-ai-with-gans-using-python-image-generation-9a62e591c7c6
作者提供的图像
学习如何使用 PyTorch 实现 GANs 以生成合成图像
·发表于 Towards Data Science ·7 分钟阅读·2023 年 3 月 27 日
–
介绍
在我的 上一篇文章中,我们了解了自编码器,现在让我们继续讨论生成式 AI。目前每个人都在谈论它,并且对已经开发出的实际应用感到兴奋。但我们将一步步地继续探讨这些 AI 的基础。
有几个机器学习模型可以构建生成式 AI,例如变分自编码器(VAE)、自回归模型甚至正则化流模型。然而,本文将重点讨论 GANs。
自编码器与 GANs
在上一篇文章中,我们讨论了自编码器,并了解了它们的架构、用途和在 PyTorch 中的实现。
简而言之,自编码器接收输入 x,将其压缩为一个较小的向量 z,称为潜在向量,然后从 z 以或多或少的近似方式重构 x。
在自编码器中,我们没有数据生成,而只是对输入的近似重构。现在假设我们将自编码器分成两部分,并仅考虑第二部分,即从潜在向量 z 重构图像的部分。
输出生成(作者提供的图像)
在这种情况下,我们可以说架构是生成式的。实际上,给定一个数字向量作为输入,这会创建一幅图像!本质上,这就是生成式 AI 的作用。与自编码器的主要区别在于,我们清楚地知道我们从中获取潜在向量 z 的概率分布。例如,一个Gaussian(0,1)。
因此,我们有了一种从高斯分布中的随机数生成图像的方法,改变这些随机数将改变我们输出的图像。
生成模型(图片由作者提供)
GANs 架构
前一张图片中显示的橙色网络可以定义为一个 G 函数,给定输入 z 生成合成输出 x_cap,因此 x_cap = G(z)。
网络将以随机权重初始化,因此最初它无法生成看起来真实的输出,只能生成包含噪声的图像。因此,我们需要进行一些训练来提高网络的性能。
所以让我们设想一下,每当我们得到输出时,有一个人工标注员告诉我们这些输出是否良好,是否看起来真实。
朝向 GANs(图片由作者提供)
显然,我们不能进行网络训练,期望一个人对输出进行持续判断。那么我们可以做什么呢?
如果你考虑一下标注员所做的工作,在这种情况下就是二元分类!而我们在机器学习中非常擅长开发分类器。因此,我们可以简单地训练一个分类器,我们称之为鉴别器,并用函数 D()表示,它必须被训练来识别合成(虚假)图像与真实图像。因此我们将同时输入虚假图像和真实图像。
所以这就是我们的架构如何变化的。
GANs 架构(图片由作者提供)
简而言之,架构并不复杂。困难在于训练这两个网络 G 和 D 时。
很明显,如果在训练过程中,这两个网络必须一起改进,它们需要找到某种平衡。因为例如,如果 D 在区分虚假图像与真实图像方面变得过于优秀,而 G 在生成图像方面尚未提升,那么 G 永远不会变得更好,我们的生成器也永远无法准备好使用。
所以这两个网络被称为在玩一个对抗游戏,其中 G 必须欺骗 D,而 D 不能被 G 欺骗。
GANs 目标函数
如果我们想更精确一点,可以说 D 和 G 有两个互补的目标。假设我们想生成图像。
我们用 D(x)定义 x 是真实图像的概率。显然,鉴别器想要最大化其识别真实输入与虚假输入的概率。因此,当 x 从我们的真实图像分布中抽取时,我们想要最大化 D(x)。
相比之下,生成器 G 的目的是欺骗鉴别器。因此,如果 G(z) 是由 G 生成的虚假图像, D(G(z)) 是 D 将虚假图像识别为真实图像的概率。那么 1-D(G(z)) 是 D 正确识别虚假图像为虚假的概率。因此 G 的目标是最小化 1-D(G(z)),因为它确实想要欺骗 D。
所以最终我们可以将这种最大化和最小化的游戏总结到原始论文中的公式里(公式看起来更具概念性,但我们已经看过这个概念):
目标函数(来源:arxiv.org/pdf/1406.2661.pdf
)
GANs 实现
我们现在实现一个能够自动生成 MNIST 图像的 GAN。
和往常一样,我将我的代码运行在基于云的环境 Deepnote 中,但你也可以使用 Google Colab,这样即使没有 GPU 的用户也可以运行这段代码。
我们首先检查一下我们的硬件是否确实有 GPU。
如果你在使用 Colab,你可以连接到 Google Drive。
from google.colab import drive
drive.mount('/content/drive/')
让我们导入所需的库。
现在我们需要创建定义我们网络的函数,即生成器和判别器。
MNIST 图像有 784 个像素(因为图像是 28x28)。因此,生成器输入一个长度为 20 的随机 z 向量,将输出一个 784 的向量,这就是我们的伪造图像。
相反,判别器将接收一个 28x28 = 784 像素的图像作为输入,它将有一个输出神经元来将图像分类为真实或伪造。
生成器(图片由作者提供)
这个函数用于实例化生成器。每一层将使用 LeakyReLU(ReLU 的一种变体,在 GANs 中表现最佳)作为其激活函数,除了输出层后接一个双曲正切(Tanh)函数,使得输出范围为[-1,1]。
判别器(图片由作者提供)
相反,这个函数定义了判别器网络,其特殊功能是在隐藏层之后使用 dropout(在基本情况下只有一个隐藏层)。输出通过一个 sigmoid 函数,因为它必须给出图像是真实的还是伪造的概率。
现在我们还下载了我们要使用的 MNIST 数据集。MNIST 数据集的范围是[0,255],但我们希望它在范围[-1,1],这样它会类似于生成器网络生成的数据。因此,我们还对数据进行了预处理。
现在我们来到最重要的部分。我们需要创建定义我们网络训练的函数。我们已经提到过,我们应该将判别器与生成器分开,因此我们将有两个函数。
判别器将同时在假数据和真实数据上进行训练。当我们在真实数据上训练时,标签将始终是*“real” = 1*。因此,我们创建一个全为 1 的向量,即d_labels_real = torch.ones(batch_size, 1, device = device)。然后,我们将输入 x 送入模型,并使用Binary Cross Entropy计算损失。
我们通过输入伪数据做同样的事情。这里的标签将全部为零,d_labels_fake = torch.zeros(batch_size, 1, device = device)。输入则是伪数据,即生成器的输出 g_output = gen_model(input_z)。我们以相同的方式计算损失。
最终损失将是两个损失的总和。
对于生成器训练函数,实施略有不同。生成器的输入是判别器的输出,因为它需要查看 D 是否已识别出图像是真实还是伪造的。基于此,它计算其损失。
现在我们可以初始化我们的两个网络了。
让我们定义一个函数来创建网络生成的样本,这样随着训练周期的增加,我们可以看到伪图像的改进。
现在我们终于可以训练网络了!我们每次将损失保存到一个列表中,以便后续绘图。
训练应该大约需要一个小时,具体取决于你使用的硬件。但最后,你可以打印出你的伪数据,得到类似这样的结果。
在我的情况下,我训练了几个周期,所以结果并不理想,但你开始可以看到网络正在学习生成类似 MNIST 的图像。
伪数据(图片由作者提供)
最后的想法
在这篇文章中,我们更详细地探讨了 GAN 的架构。我们研究了它们的目标函数,并实现了一个能够生成 MNIST 数据集图像的网络!这些网络的操作并不复杂,但它们的训练确实很困难。因为我们需要找到一个平衡点,让两个网络都能学习。如果你喜欢这篇文章,请关注我,阅读下一篇关于 DCGAN 的文章。😉
结束
Marcello Politi
实操介绍 Delta Lake 和 (py)Spark
现代数据存储框架的概念、理论和功能
·
关注 发表在 Towards Data Science ·10 分钟阅读·2023 年 2 月 16 日
–
照片由 Nick Fewings 提供,来源于 Unsplash
介绍
我认为现在大家对数据的价值已经非常清楚。以一个热门的例子来说,像 ChatGPT 这样的模型只能基于多年来产生和收集的大量数据构建。
我想强调“可以”这个词,因为编程领域有一句话仍然适用,并且可能永远适用:垃圾进,垃圾出。数据本身没有价值,它需要组织、标准化和清理。治理是必要的。在这种情况下,组织中的数据管理是其涉及数据的项目成功的关键点。
正确的数据管理的一个主要方面是数据架构的定义。数据架构是满足特定组织数据需求的一套实践、技术和服务,包括技术需求(速度、容量、频率、可用性)和非技术需求(业务规则、数据立法遵从)。
如今,几乎默认情况下,组织必须处理不同格式的数据(CSV、pdf、视频、parquet 等),这就是像亚马逊的 S3 这样的 blob 存储取得成功的原因。然而,这种方法可能会带来一些问题,因为缺乏对原始文件(特别是表格数据)的管理工具,如模式强制、版本控制和数据血缘。
考虑到这一点(以及其他一些因素),开发了 Delta Lake,一个开源数据存储框架,实施/体现了 Lakehouse 架构,并且是今天文章的主题。
什么是 Delta Lake?
在深入讨论 Delta Lake 之前,我们需要记住数据湖的概念,所以让我们回顾一些历史。
数据湖架构是在数据量大幅增长的时期提出的,特别是非结构化和半结构化数据,当传统的数据仓库系统开始无法处理这种需求时。
提案很简单——“把你拥有的一切都扔进来,稍后再担心”。在第一个数据湖的背景下,主要的参与者是 Hadoop,一个分布式文件系统,配合 MapReduce,这是一个建立在最小数据移动和高并行性的理念上的处理范式。理论上,只需将所有数据扔进 Hadoop,然后编写作业处理数据以得到预期结果,从而摆脱复杂的数据仓库系统。
传说中,这一过程并不顺利。文件被随意丢弃,没有质量担忧,没有版本控制,也没有管理。数据变得毫无用处。问题如此严重,以至于创造了“数据沼泽”这一术语,调侃非常混乱的数据湖,以及“WORN paradigm”(写一次读永不再读)。实际上,传统数据仓库系统,特别是关系数据库管理系统(RDBMS)所施加的保障仍然需要,以确保数据质量。(那时我还很小,我最近才从现代文献中了解到这一切历史)
时间流逝,基于过去的成功与失败,提出了新的架构。其中之一就是 Lakehouse 架构。简而言之,它试图结合数据湖(灵活性)和数据仓库(保障)的优点。
Delta Lake 只是一个具有 Lakehouse 视角的存储框架/解决方案的实际实现。
让我们开始吧:
Delta Lake(即 Delta 表)中的表实际上就是一个包含 JSON 事务日志的 parquet 文件,该日志记录了文件上的所有变更历史。通过这种方式,即使数据存储在文件中,也可以完全控制所有发生的事件,包括读取以前的版本和恢复操作。Delta Lake 还使用 ACID 事务的概念,即防止因作业失败或读取不一致导致的部分写入。Delta Lake 还会拒绝格式不正确的数据写入(模式强制)并允许模式演变。最后,它还提供了通常在原始文件中不可用的 CRUD 功能(插入、更新、合并和删除)。
本文将以实际操作的方式使用 pyspark 讨论这些功能。
数据
本文使用的数据是发生在巴西高速公路上的交通事故列表,由 PRF(巴西公路警察)收集,并在巴西开放数据门户网站上公开提供 [Link][License — CC BY-ND 3.0].
数据涵盖了 2007 年到 2021 年期间的各种事故信息:地点、高速公路、公里数、纬度和经度、涉及人员数量、事故类型等。
实现
0. 设置环境
一如既往,该项目使用 docker 容器进行开发:
version: '3'
services:
spark:
image: bitnami/spark:3.3.1
environment:
- SPARK_MODE=master
ports:
- '8080:8080'
- '7077:7077'
volumes:
- ./data:/data
- ./src:/src
spark-worker:
image: bitnami/spark:3.3.1
environment:
- SPARK_MODE=worker
- SPARK_MASTER_URL=spark://spark:7077
- SPARK_WORKER_MEMORY=4G
- SPARK_EXECUTOR_MEMORY=4G
- SPARK_WORKER_CORES=4
ports:
- '8081:8081'
volumes:
- ./data:/data
- ./src:/src
jupyter:
image: jupyter/pyspark-notebook:spark-3.3.1
ports:
- '8890:8888'
volumes:
- ./data:/data
所有代码均可在此 GitHub 仓库中获取。
1. 创建 Delta 表
首先需要实例化一个 Spark 会话,并将其配置为使用 Delta-Lake 依赖项。
# Install the delta-spark package.
!pip install delta-spark
from pyspark.sql import SparkSession
from pyspark.sql.types import StructField, StructType, StringType, IntegerType, DoubleType
import pyspark.sql.functions as F
from delta.pip_utils import configure_spark_with_delta_pip
spark = (
SparkSession
.builder.master("spark://spark:7077")
.appName("DeltaLakeFundamentals")
.config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension")
.config("spark.sql.catalog.spark_catalog", "org.apache.spark.sql.delta.catalog.DeltaCatalog")
)
spark = configure_spark_with_delta_pip(spark).getOrCreate()
创建 Delta 表非常简单,就像以特定格式编写一个新文件一样。以下代码读取 2020 年的事故数据,并将数据写入 delta 表。
SCHEMA = StructType(
[
StructField('id', StringType(), True), # ACCIDENT ID
StructField('data_inversa', StringType(), True),# DATE
StructField('dia_semana', StringType(), True), # DAY OF WEEK
StructField('horario', StringType(), True), # HOUR
StructField('uf', StringType(), True), # BRAZILIAN STATE
StructField('br', StringType(), True), # HIGHWAY
# AND OTHER FIELDS OMITTED TO MAKE THIS CODE BLOCK SMALL
]
)
df_acidentes = (
spark
.read.format("csv")
.option("delimiter", ";")
.option("header", "true")
.option("encoding", "ISO-8859-1")
.schema(SCHEMA)
.load("/data/acidentes/datatran2020.csv")
)
df_acidentes.show(5)
2020 年前 5 行。
编写 delta 表。
df_acidentes\
.write.format("delta")\
.mode("overwrite")\
.save("/data/delta/acidentes/")
就这些了。
如前所述,Delta-Lake 表(就文件而言)仅仅是传统的 parquet 文件,附带一个记录所有变更的 JSON 事务日志。
带有 JSON 事务日志的 Delta 表。
2. 从 Delta 表中读取数据
再次强调,读取 Delta 表并没有特别之处。
df_acidentes_delta = (
spark
.read.format("delta")
.load("/data/delta/acidentes/")
)
df_acidentes_delta.select(["id", "data_inversa", "dia_semana", "horario", "uf"]).show(5)
让我们计算行数
df_acidentes_delta.count()
>> Output: 63576
3. 向 Delta 表中添加新数据
Delta 表支持“追加”写入模式,因此可以将新数据添加到已存在的表中。让我们添加 2019 年的数据。
# READING THE 2019 DATA
df_acidentes_2019 = (
spark
.read.format("csv")
.option("delimiter", ";")
.option("header", "true")
.schema(SCHEMA)
.load("/data/acidentes/datatran2019.csv")
)
向 Delta 表追加数据
df_acidentes_2019\
.write.format("delta")\
.mode("append")\
.save("/data/delta/acidentes/")
重要的是要强调:Delta 表会执行模式强制,因此只能写入与已存在表具有相同模式的数据,否则,Spark 会抛出错误。
让我们检查 Delta 表中的行数。
df_acidentes_delta.count()
>> Output: 131132
4. 查看 Delta 表的历史记录(日志)
Delta 表的日志记录了所有对表执行的操作。它包含了对每个操作的详细描述,包括关于操作的所有元数据。
要读取日志,我们需要使用一个名为DeltaTable
的特殊 Python 对象。
from delta.tables import DeltaTable
delta_table = DeltaTable.forPath(spark, "/data/delta/acidentes/")
delta_table.history().show()
历史对象是一个 Spark 数据框。
delta_table.history().select("version", "timestamp", "operation", "operationParameters").show(10, False)
正如我们所见,目前有两个表版本,每个操作都有一个版本:表创建时的覆盖写入和之前进行的追加写入。
5. 读取 Delta 表的特定版本
如果没有指定任何内容,Spark 将读取 Delta 表的最新版本。
df_acidentes_latest = (
spark
.read.format("delta")
.load("/data/delta/acidentes/")
)
df_acidentes_latest.count()
>> Output: 131132
但也可以通过仅添加一行代码从特定版本中读取数据:
df_acidentes_version_0 = (
spark
.read.format("delta")
.option("versionAsOf", 0)
.load("/data/delta/acidentes/")
)
df_acidentes_version_0.count()
>> Output: 63576
计数下降了,因为我们正在从版本 0 读取,在 2019 年的数据插入之前。
6. 恢复到先前版本
可以恢复到表的先前版本。这对于快速解决管道中出现的错误非常有用。此操作也是通过之前创建的 DeltaTable 对象执行的。
让我们将表恢复到版本 0:
delta_table.restoreToVersion(0)
现在,最新的计数将再次为=63576,因为我们恢复到了数据尚未包括 2019 年的版本。
# Counting the number of rows in the latest version
df_acidentes_latest.count()
RESTORE操作也记录在日志中。因此,实际上没有信息丢失:
delta_table.history().select("version", "timestamp", "operation", "operationParameters").show(10, False)
让我们恢复到版本 1。
delta_table.restoreToVersion(1)
7. 更新
更新操作也可以通过DeltaTable
对象完成,但我们将使用 SQL 语法来尝试一种新方法。
首先,让我们将 2016 年的数据写入增量表。这些数据中的“data_inversa”(日期)列格式错误:dd/MM/yy 而不是 yyyy-MM-dd。
df_acidentes_2016 = (
spark
.read.format("csv")
.option("delimiter", ";")
.option("header", "true")
.option("encoding", "ISO-8859-1")
.schema(SCHEMA)
.load("/data/acidentes/datatran2016.csv")
)
df_acidentes_2016.select("data_inversa").show(5)
让我们保存数据:
df_acidentes_2016\
.write.format("delta")\
.mode("append")\
.save("/data/delta/acidentes/")
df_acidentes_latest.count()
>> Output: 227495
但由于我们的 data_inversa 字段是字符串类型,因此不会发生错误。现在,我们的表中插入了错误的数据,需要进行修复。当然,我们可以只需恢复这次操作,并再次正确插入数据,但让我们改用 UPDATE 操作。
以下 SQL 代码仅修复年份=2016 的数据格式。
df_acidentes_latest.createOrReplaceTempView("acidentes_latest")
spark.sql(
"""
UPDATE acidentes_latest
SET data_inversa = CAST( TO_DATE(data_inversa, 'dd/MM/yy') AS STRING)
WHERE data_inversa LIKE '%/16'
"""
)
错误格式化数据的行数为 0:
df_acidentes_latest.filter( F.col("data_inversa").like("%/16") ).count()
>> Output: 0
8. 合并
最后将介绍的操作是 MERGE(也称为 UPSERT)操作。它是 INSERT 和 UPDATE 的混合。
它会尝试将新行插入目标表格,将某些列视为关键列。如果要插入的行已经存在于目标表中(即行键已经在目标表中存在),它将仅更新该行(按照指定的一些逻辑),否则,它将插入新行。
总结来说:如果存在,则更新;如果不存在,则插入。
合并示例。图片由作者提供。
为了演示这种方法,让我们插入一些 2018 年的数据,所有行的人 = 0(pessoas — 参与事故的人数),模拟一个包含不完整数据的部分报告。
# FULL DATA FROM 2018
df_acidentes_2018 = (
spark
.read.format("csv")
.option("delimiter", ";")
.option("header", "true")
.option("encoding", "ISO-8859-1")
.schema(SCHEMA)
.load("/data/acidentes/datatran2018.csv")
)
# SAMPLE WITH pessoas=0
df_acidentes_2018_zero = (
df_acidentes_2018
.withColumn("pessoas", F.lit(0))
.limit(1000)
)
df_acidentes_2018_zero\
.write.format("delta")\
.mode("append")\
.save("/data/delta/acidentes/")
如果我们现在想用 2018 年的完整数据更新表格,我们必须确保已经插入的行仅更新人列,并插入所有新的行。
这可以通过以下 MERGE 操作来完成,该操作将事故的 id 和日期视为关键:
df_acidentes_latest.createOrReplaceTempView("acidentes_latest")
df_acidentes_2018.createOrReplaceTempView("acidentes_2018_new_counts")
spark.sql(
"""
MERGE INTO acidentes_latest
USING acidentes_2018_new_counts
ON acidentes_latest.id = acidentes_2018_new_counts.id
AND acidentes_latest.data_inversa = acidentes_2018_new_counts.data_inversa
WHEN MATCHED THEN
UPDATE SET pessoas = acidentes_latest.pessoas + acidentes_2018_new_counts.pessoas
WHEN NOT MATCHED THEN
INSERT *
"""
)
结论
定义数据架构对所有旨在创建数据驱动产品的组织(如 BI 报告和机器学习应用)至关重要。数据架构定义了将确保组织的技术和非技术数据需求得到满足的工具、技术和实践。
在私营公司中,这可以帮助加快此类产品的开发,提升其质量和效率,并带来转化为利润的商业优势。在公共组织中,数据架构的好处转化为更好的公共政策,更好地了解特定领域的现状,如交通、安全、预算,以及提高透明度和管理水平。
在过去几十年中,提出了许多架构,每种架构在不同背景下都有其自身的优势。Lakehouse 范式试图将数据湖和数据仓库的优势结合起来。Delta Lake 是基于 Lakehouse 范式的存储框架。简而言之,它将通常仅在经典 RDBMS 中可用的许多保证(ACID 事务、日志、撤销操作、CRUD 操作)带到基于parquet的文件存储之上。
在这篇文章中,我们使用巴西高速公路交通事故的数据探索了这些功能中的一些。我希望我能有所帮助,我对讨论的任何主题都不是专家,我强烈建议进一步阅读(见下方一些参考文献)和讨论。
感谢您的阅读!😉
参考文献
所有代码都可以在 这个 GitHub 仓库找到。
[1] Chambers, B., & Zaharia, M. (2018). Spark: The definitive guide: Big data processing made simple. “O’Reilly Media, Inc.”
[2] Databricks. (2020 年 3 月 26 日). Tech Talk | Diving into Delta Lake Part 1: Unpacking the Transaction Log [视频]. YouTube.
[2] 如何使用恢复功能将 Delta Lake 表回滚到先前版本。 (2022 年 10 月 3 日)。Delta Lake。 链接
[3] Delta Lake 官方页面。 (无日期)。Delta Lake。 delta.io/
[4] Databricks. (2020 年 3 月 12 日). 简化和扩展使用 Delta Lake 的数据工程管道 [视频]。YouTube。
[5] Databricks. (2020 年 9 月 15 日). 利用 Delta Lake 改进 Apache SparkTM [视频]。YouTube。
[6] Reis, J., & Housley, M. (2022 年). 数据工程基础:规划和构建稳健的数据系统 (第 1 版)。O’Reilly Media。
使用 Python 进行实践性的多目标分类
原文:
towardsdatascience.com/hands-on-multitarget-classification-using-python-1ac439aac708
Christin Hume 摄影,图片来源于 Unsplash
方法概述、评估指标和最佳实践
·发表于 Towards Data Science ·阅读时间 9 分钟·2023 年 1 月 2 日
–
介绍
最近我在开发一个能够识别建筑物中不同类型损坏的机器学习算法。这些损坏并不相同,每种损坏都有不同的原因和风险,因此我们识别出了大约 4 种不同类型的裂缝。然后,该算法将部署在无人机上,无人机会自动拍摄建筑物的照片,并能够判断建筑物中存在的损坏及其严重程度。
显然,在无人机拍摄的照片中,可能会出现不同类型的损坏,因此,无人机拍摄的照片必须能够识别出照片中存在的所有不同类型的损坏,而不仅仅是其中之一。这就是我开始研究所谓的多目标分类任务的原因。我在这里写这篇文章,希望它对你也有帮助。
什么是多目标分类?
多目标分类是一种机器学习任务,涉及对单个样本预测多个标签。与传统的二分类或多分类不同,其中每个样本被分配到一个单独的类别,多目标分类允许一个样本同时属于多个类别。这在单个样本可能有多个相关标签的情况下非常有用,例如一个新闻文章可以同时被分类为政治、体育和娱乐。
让我们看一个例子来理解不同类型的分类。
分类类型(图片来源于作者)
解决多目标分类问题有几种方法,包括使用二分类器、多类别分类器和多任务学习。本文将探讨不同类型的多目标分类,并讨论它们的优缺点。我们还将查看多目标分类的评估指标。最后,我将提供一些个人见解和结论,关于多目标分类的优势和局限性,以及成功的最佳实践。
多目标分类的类型
针对多目标分类有几种方法,每种方法都有其优点和局限性。
二分类器
多二分类(图像来源于作者)
多目标分类的一种方法是使用多个二分类器,每个分类器被训练以预测一个单独的标签。例如,如果我们有一个包含三个标签(A、B 和 C)的多目标分类问题,我们可以训练三个独立的二分类器,一个用于预测标签 A,一个用于预测标签 B,另一个用于预测标签 C,然后运行这三个模型来对实例进行分类。这种方法简单易实现,但如果标签数量较多可能效率低下。此外,分类器的性能可能受到训练数据中标签分布不平衡的影响。
多类别分类器
另一种多目标分类的方法是使用多类别分类器,这种分类器旨在同时预测多个标签。有几种类型的多类别分类器,包括一对其余和一对一。
-
一对其余(OvR) 分类器被训练以对每个标签做出二分类决策,将所有其他标签视为负类。例如,在 三个标签(A、B 和 C)的情况下,OvR 分类器将被训练以预测 标签 A 对 非-A、标签 B 对 非-B 和 标签 C 对 非-C。所以你将遇到如前所述的多二分类情况。这种方法简单高效,但可能受到标签分布不平衡的影响,并且可能未考虑标签之间的依赖关系。
-
一对一(OvO) 分类器被训练以对每对标签做出二分类决策。例如,在三个标签(A、B 和 C)的情况下,OvO 分类器将被训练以预测 A 对 B、A 对 C 和 B 对 C。这种方法比 OvR 计算量更大,但可以处理不平衡的标签分布,并捕捉标签之间的依赖关系。
多任务学习
多任务学习是一种机器学习方法,它涉及训练一个模型以同时预测多个任务。在多目标分类的背景下,多任务学习包括训练一个模型来预测样本的所有标签。
如果任务相似,例如,你需要分类不同类型的缺陷,或者你需要分类是否有汽车、自行车和卡车,这种方法比使用多个二分类或多分类分类器更高效,但它需要大量标记数据和强假设任务相关。
让我们编码吧!
让我们看看如何在计算机视觉领域实现多任务学习算法,同时采用迁移学习方法。
我想做的是拿一个在图像识别上预训练的网络,如 Resnet(当然,你也可以使用其他网络),并将其修改为能够同时解决多个任务。
在多分类情况下,通常的做法是使用Resnet,并在其顶部附加一个具有与数据集中类别数量相等的输出神经元的分类器,从而为每个实例获得一个类别。
然而,我们希望不仅仅识别一个类别,而是同时识别多个类别,因此我们只是简单地在 Resnet 顶部附加多个分类器(线性层)。
但在我们的情况下,每个分类器是一个二分类器,它只需要告诉我们是否存在汽车、自行车或卡车。
创建一个二分类器,我们只需要一个输出神经元来回答“否/是”。所以网络结构非常简单。
使用 Resnet 进行多目标分类(图片由作者提供)
在需要 3 个多分类分类器的更一般情况下,架构应该是这样的。
使用 Resnet 进行多目标分类(图片由作者提供)
让我们看看如何使用 PyTorch 实际实现这样的网络。
class ResnetBasedModel(nn.Module):
def __init__(self, pretrained, clf_in_features, labels_nr:int, freeze:bool = True):
super().__init__()
self.pretrained_model = pretrained
#model without last layer
self.model_wo_fc = nn.Sequential(*(list(self.pretrained_model.children())[:-1]))
if freeze:
for param in self.model_wo_fc.parameters():
param.requires_grad = False
self.classifiers = nn.ModuleDict()
for i in range(labels_nr):
self.classifiers[f'clf_{i}'] = nn.Sequential(
nn.Dropout(p=0.2),
nn.Linear(in_features = clf_in_features, out_features = 1)
)
def forward(self, x):
x = self.model_wo_fc(x)
x = torch.flatten(x, 1)
return {name: classifier(x) for name, classifier in self.classifiers.items()}
之前的代码实现了一个继承 nn.Module的 Python 类,这是创建基于神经网络的模型的经典方式。
模型以预训练网络(如 Resnet)(pretrained)作为输入,接受预训练网络倒数第二层的输出神经元数量(clf_in_features),例如,在 Resnet34 中这个数字是 512,而在 Resnet50–101 中是 2048。
该类还接受输出二分类器的数量(labels_nr),以及我们是否希望解冻预训练网络的参数(freeze)。
让我们更详细地看看这个类是如何工作的。
以下循环用于解冻预训练网络的参数,以便我们只保留输出分类器,这将大大加快我们的训练速度。
if freeze:
for param in self.model_wo_fc.parameters():
param.requires_grad = False
之后,我会创建与参数中指定数量相同的分类器,并将它们保存在字典中。
for i in range(labels_nr):
self.classifiers[f'clf_{i}'] = nn.Sequential(
nn.Dropout(p=0.2),
nn.Linear(in_features = clf_in_features, out_features = 1)
)
最终,在前向方法中,我将每个输入 x 传递给预训练网络和字典中的每个分类器,并返回字典中每个元素(分类器)的输出。
def forward(self, x):
x = self.model_wo_fc(x)
x = torch.flatten(x, 1)
return {name: classifier(x) for name, classifier in self.classifiers.items()}
现在你可以使用你的网络进行多目标分类。请记住,要实例化一个预训练网络并将其作为输入传递给类,你只需要使用 PyTorch Torchvision 模块。
from torchvision import models
resnet34 = models.resnet34(weights=models.ResNet34_Weights.DEFAULT)
要训练模型,你需要为每个分类器实例化一个损失函数,并将这些损失加总。得到结果后,你可以更新整个模型。
def criterion(y, yhat):
'''y : is a dict with keys 'labels' and 'path'''
losses = 0
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
for _, key in enumerate(yhat):
losses += loss_func(yhat[key], y[f'label_{key}'].float().unsqueeze(1).to(device))
return losses
多目标分类的评估指标
评估多目标分类模型的性能可能具有挑战性,因为需要考虑多个标签以及不同的组合方式。以下是一些用于多目标分类的常见评估指标:
汉明损失:该指标衡量错误预测标签的比例。它的计算方法是将错误分类的标签数量除以标签总数。
这是我倾向于用来评估模型的指标,这里是汉明损失的实现。
def hamming_error(yhat:list, y:list) -> float:
loss = sum([yhat_i != y_i for yhat_i, y_i in zip(yhat, y)])
avg_loss = loss/len(yhat)
return avg_loss
在这个片段中,两个数组的元素成对比较,返回发现的平均不匹配数。可以轻松地将此代码推广到一次比较整个批次。
让我们快速了解一下该领域使用的其他指标。
-
排名损失:该指标衡量样本中错误排序的标签的平均数量。例如,如果一个样本的标签是 A、B 和 C,而模型预测的顺序是 C、B、A,则排名损失为 2(因为 B 和 C 被错误排序)。
-
杰卡德指数:该指标衡量预测标签和真实标签之间的重叠程度。它的计算方法是将交集的大小除以并集的大小。
-
F1 分数:该指标是在精度和召回率之间的平衡,其中精度是正确预测标签的比例,召回率是被预测的真实标签的比例。它是精度和召回率的调和均值。要使用 F1 分数进行多目标分类,你需要分别计算每个标签的精度和召回率,然后对所有标签的分数取平均以获得总体 F1 分数。该指标相对于之前的指标对不平衡标签分布更为敏感。
F1 = 2 * (precision * recall) / (precision + recall)
- 平均精度:该指标衡量数据集中每个召回值的精度。它的计算方法是对第一个假阳性、第二个假阳性等的精度取平均。
Average precision = (1/n) * Σ(precision at each recall value)
要计算多个标签的平均精度,你可以简单地对每个标签的平均精度取平均。例如,如果你有三个标签(A、B 和 C),你可以使用上述公式计算每个标签的平均精度,然后将这些分数平均以获得总体平均精度。这也对不平衡分布非常敏感。
评估多目标分类模型的挑战
评估多目标分类模型可能具有挑战性,因为有多个标签需要考虑以及不同的组合方式。以下是一些用于多目标分类的常见评估指标:
-
不平衡的标签分布:某些标签可能比其他标签更常见,这会影响模型的性能。例如,如果某个标签稀少,模型可能没有足够的示例进行学习,导致性能较差。
-
标签之间的依赖关系:某些标签可能比其他标签更可能一起出现,这会影响模型的性能。例如,如果一个模型被训练来预测标签 A 和 B,但标签 A 总是由标签 B 在前面,模型可能很难在不预测标签 B 的情况下预测标签 A。
-
多标签评估指标:多目标分类有多种评估指标,每种指标都有其优点和局限性。选择正确的指标可能很困难,因为这取决于问题的具体要求和数据的特点。
最终思考
多目标分类是解决涉及对单个样本预测多个标签的问题的强大工具。它可以应用于广泛的现实世界应用,例如文本分类、图像标注和推荐系统。
多目标分类有多种方法,包括二分类器、多分类器和多任务学习。选择方法取决于问题的具体要求和数据的特点。
评估多目标分类模型的性能可能很具挑战性,原因包括标签分布不平衡、标签之间的依赖关系以及可用的多种评估指标。选择正确的评估指标并将模型的性能与基线进行比较非常重要。
结束
马塞洛·波利提
亲自动手实现 Otsu 阈值分割算法,用于图像背景分割,使用 Python
从理论到实践,使用 Otsu 阈值分割算法
·
关注 发表在 Towards Data Science ·8 分钟阅读·2023 年 3 月 20 日
–
图片由 Luke Porter 提供,来源于 Unsplash
让我从一个非常技术性的概念开始:
图像将被视作、处理、分析和处理为二维信号。
以及一些恰当的定义:
-
信号是一个在空间或时间上变化的量,可用于传递某种信息。
-
图像无非是光线照射到光学系统上,即相机或你正在绘画的画布。
从这个意义上说,图像不过是一个 2D 信号,一个携带某些信息的电磁信号,这些信息由物理系统获取。
因此,既然我们已经确定图像确实是一个信号,我们可以考虑将信号处理技术应用于图像处理任务。这样我们可以停止讨论哲学,开始进入技术部分。
说到哲学。我们来看这张图像:
照片由 Tingey Injury Law Firm 在 Unsplash 提供
图片中有位哲学家在思考他的工作。然后是这个非常白的背景,我们其实并不在乎。我们能去掉它吗? 我们能得到这样的东西吗?
图片由作者提供
如果我问你,这意味着我们可以。 😅
每个人只要懂一点Photoshop就可以做到这一点,但如何用 Python 自动完成呢?再说一遍,是的。
让我给你展示一下 🚀
0. 想法
所以让我们拿一个简单的例子。
没错。一个小方块在一个大方块里面。这是一个极其简单的情况。我们想做的是将小方块中的所有值设置为 1,而外面的所有值设置为 0。
我们可以用以下代码提取两个值:
然后做一些类似于:
这将图像从两个值转换为 1 和 0。
这非常简单,对吧?让我们把它变得有些复杂。
现在我们将做的是在大方块内的小方块但两个方块都有一些噪声。
我的意思是,我们不仅仅有 2 个值,而是理论上可以有从 0 到 255 之间的所有值,这就是编码中的整个值范围。
我们如何处理这个问题?
好的,我们首先要做的是扁平化图像(2D 信号),并将其转换为 1D 图像。
图片由作者提供
图像是 50x50 的,我们得到了一个“展开”的 50x50=2500 长度的 1D 信号。
现在如果我们研究 1D 信号的分布,我们得到这样的东西:
正如我们所见,我们有 两个正态分布。这正是 Otsu 算法表现最佳的地方。其基本思想是 背景 和 图像中的主体 具有两种不同的性质和两个不同的 领域。例如,在这种情况下,第一个高斯钟形曲线与背景相关(从 0 到 50),而第二个高斯钟形曲线则与较小的方块相关(从 150 到 250)。
假设我们决定将所有大于 100 的值设为 1,将所有小于 100 的值设为 0:
结果是以下 背景 和 主体 之间的掩码:
就是这样。这就是 Otsu 算法的整个思路:
-
导入/读取图像作为 2D 信号
-
将图像展平为 1D 向量
-
选择一个 阈值
-
将低于该阈值的所有内容设为 0,将高于该阈值的所有内容设为 1
很简单,对吧?
那么我们如何选择合适的阈值呢?最佳阈值是什么?让我们谈谈 数学。
1. 理论介绍
让我们稍微形式化一下这个概念。
我们有一个图像的 领域。整个领域从 0 到 255(白到黑),但它不必那么宽(例如可以是 20 到 200)。
当然,多个点可以具有相同的像素强度(我们可以在同一图像中有两个黑色像素)。假设我们有 3 个强度为 255 的像素,在一个有 100 个像素的图像中。那么在该图像中强度为 255 的概率是 3/100。
一般来说,我们可以说图像中像素 i 的概率是:
图片由作者提供
现在假设我们进行分割的像素是像素 k(在我们之前的示例中,k 是 100)。这 分类 了数据点。所有 k 之前的点属于类别 0,所有 k 之后的点属于类别 1。
这意味着从类别 0 中选择一个点的概率如下:
图片由作者提供
而从类别 1 中选择一个点的概率如下:
图片由作者提供
正如我们所见,这两个概率显然依赖于 k。
现在,我们还可以计算每个类别的 方差:
图片由作者提供
图片由作者提供
其中:
图片由作者提供
和
图片由作者提供
sigma 值是每个类别的 方差,即该类别在均值 mu_0 和 mu_1 周围的 扩散 程度。
从理论上讲,目的是找到一个值,这个值形成了我们之前在图片中看到的小 凹谷:
但我们使用的方法稍有不同且更为严格。通过使用线性判别分析(LDA)的相同思路。在(Fisher)LDA 中,我们希望找到一个超平面,将两个分布分开,使得类别之间的方差尽可能大(这样两个均值之间的距离尽可能远),而类别内部的方差尽可能小(这样两个类别数据点之间的重叠最少)。
在这种情况下,我们没有任何超平面,我们设置的阈值(我们的 k)甚至不是一条线,而更像是一个概率值,用于区分数据点并对其进行分类。
可以证明(完整证明见原始论文),最佳分割在背景和主体之间(假设背景的域与主体的域不同)是通过最小化该数量来获得的:
作者提供的图片
这意味着我们可以尝试所有不同的 k,并选择其中最小的 k。
2. 实践操作
理论可能看起来复杂且难以理解,但实现起来非常简单,由三个块组成:
2.1 导入库
我们首先需要导入 4 个基本库。
2.2 阈值函数
一旦找到完美的阈值,这就是如何将其应用到您的图像中:
2.3 Otsu 准则
将计算该数量的函数:
作者提供的图片
是以下内容:
2.4 最佳阈值计算
另一个函数会遍历所有可能的 k,并根据上述准则找到最佳值。
2.5 整个过程
所以我们使用的图像是以下这个:
由Ben Dumond拍摄的照片,来源于Unsplash
如果我们将该图像保存在路径中并应用 Otsu 算法,我们得到:
如果我们比较 im(原始图像)和 im_otsu(算法处理后的图像),我们得到:
正如我们所见,图像右上部分的黑色区域被误解为主体,因为它与一些主体有相同的色调。人们不完美,算法也一样🙃
3. 总结
感谢您在整个 Otsu 算法教程中与我同行。
在这篇简短的文章中,我们看到:
-
图像可以被视为二维信号,然后可以使用信号处理技术进行分析
-
Otsu 算法的假设是图像的背景和主体具有两个连续的、互不重叠的、明确区分的领域。
-
如何在给定 Otsu 算法的情况下找到图像的最佳背景和主体的区分。我们如何将 Otsu 算法解释为 Fisher 线性判别。
-
如何使用Python实现 Otsu 算法。
-
如何在实际图像中应用此算法。
4. 结论
如果你喜欢这篇文章,想了解更多关于机器学习的内容,或只是想问我一些问题,你可以:
A. 在Linkedin上关注我,我会在上面发布我的所有故事。
B. 订阅我的新闻通讯。它将让你了解最新故事,并给你机会通过短信与我联系,获取所有的更正或解答你可能有的疑问。
C. 成为会员,这样你就不会有“每月故事数量上限”的限制,可以阅读我(以及成千上万的其他机器学习和数据科学顶级作家)关于最新技术的文章。
实践中的采样技术和比较,使用 Python
原文:
towardsdatascience.com/hands-on-sampling-techniques-and-comparison-in-python-6342c70f2099
这是一个关于如何使用 Python 高效采样数据集的逐步教程
·发布于 Towards Data Science ·8 min read·2023 年 12 月 1 日
–
图片由作者提供
我和我的妻子一起布置圣诞树。我们去了地下室,拿了树,把它搬到楼上,然后从下往上开始搭建。这总是一个神奇的时刻🎄
然后是把球放到树上的环节。我立刻想到:至少有三种方式可以把球放到树上。
- 均匀地:将球均匀地放在树上,就像这样
图片由作者提供,使用 Freepik 制作
- 随机地:将球随机地放在树上,闭上眼睛随便放哪里(我开始这样做,我的妻子疯了)
图片由作者提供,使用 Freepik 制作
- 拉丁超立方体:将树分成 N 个部分,并在每个部分中随机抽取。很难在不运行任何代码的情况下绘制,但可能的拉丁超立方体如下所示:
图片由作者提供,使用 Freepik 制作
我尝试并把这个展示给了我的妻子。她微笑着说“随便”,于是我回到电脑前,希望你的反应能更令人满意 😤
说笑归说笑,当处理机器学习问题时,有两种不同的情况:
- 你无法控制数据集。你有一个客户或公司,会给你一个数据集。这是你必须处理的,直到安排必要的(最终的)重新训练。
例如,在纽约市,你想根据一些给定的特征预测房屋的价格。他们只给你数据集,并希望你建立一个模型,以便当新客户到来时,你有一个 AI 软件可以根据感兴趣的房屋特征预测价格。
2. 你可以建立你的实验设计。这指的是你有一个前向模型或一个可以随时设置运行的真实世界实验。
例如,在实验室中,你想在给定实验设置的情况下预测一个物理信号。你可以随时去实验室生成新的数据。
你在这两种情况下做出的考虑是完全不同的。
在第一个情况下,你可以期待一个在特征上不平衡的数据集,可能有缺失的输入值和目标值的偏斜分布。处理这些问题是数据科学家的乐趣与诅咒。你进行数据增强、数据过滤、填补缺失值,如果可能的话进行一些 ANOVA 测试等等。在第二种情况下,你对数据集中的情况有完全的控制权,特别是从输入的角度来看。这意味着如果你有一个 NaN 值,你可以重复实验;如果你有几个 NaN 值,你可以调查数据集中的那个奇怪区域;如果你有一些特征的值非常大,你可以重复实验以确保这不是你设置的幻觉。
由于我们有这种控制权,我们想确保高效地覆盖输入参数空间。例如,如果你有 3 个参数,并且你知道它们的边界
作者提供的图片
其中 i 从 1 到 3(或者从 0 到 2,如果你喜欢 Python 😁)。在这种情况下,x_i 是第 i 个变量,它总是大于 x_i^L(下边界),但总是小于 x_i^U(上边界)。
我们有我们的三维立方体。
作者提供的图片
现在,记住我们对数据集有完全的控制权。我们如何进行采样?换句话说,我们如何确定xs?我们想选择哪些点,以便运行前向模型(实验或模拟)并获得目标值?
正如你所期待的,有多种方法可以做到这一点。每种方法都有其优点和缺点。在本研究中,我们将讨论它们,展示它们背后的理论,并提供代码供大家使用,进一步了解美妙的采样世界。🙂
让我们从均匀采样开始:
1. 均匀采样
均匀采样方法可以说是最简单和最著名的一种。
这只是将每个参数(或维度)分成若干步骤。假设我们有每个维度 3 个步骤,2 个维度。每个维度从 0 到 1(我们稍后会扩展)。这将是采样:
-
(0,0)
-
(0,0.5)
-
(0,1)
-
(0.5,0)
-
(0.5,0.5)
-
(0.5,1)
-
(1,0)
-
(1,0.5)
-
(1,1)
这意味着我们一次固定一个变量,并逐步增加。相当简单。让我们编码:
1.1 均匀采样代码
我们怎么做呢?让我们避免这种结构:
-
对于维度 1 中的 a
-
在二维中进行 b
-
….
-
对于字母表中最后一个字母在字母表数量维度中:X.append([a,b,…,字母表中的最后一个字母])
我们不想这样做,因为这不是很高效,你需要为每个维度定义一个变量,这很麻烦。让我们使用神奇的numpy。
*np.meshgrid(points) 这个问题做的就是你用 for 循环做的事情,但以优化的方式进行。你的参数字典是用来告诉你每个参数的最小值和最大值的。
使用这段代码,你将生成一个 0/1 立方体和一个具有三种不同维度的立方体(例如,第一个维度从-5 到 1,第二个维度从 0 到 10,第三个维度从-1 到 1):
我们有三个维度,让我们绘制前两个:
1.2 优缺点
优点:这种方法因两个原因非常知名。第一个原因是它非常容易实现。它真的只是变量之间的 for 循环。第二个原因是你当然是均匀覆盖了参数空间,如果你想确保不丢失参数空间的重要部分,这种方法是理想的。
缺点:这种方法的一个巨大问题是指数级依赖。如果我们假设维度的数量是固定的(假设为 6),对于 steps = 10 的设计,你已经在处理百万级的点数。而问题再次在于这种指数性。因此,如果你想通过将步数加倍(20 步而不是 10 步),你现在谈论的是一个 6400 万点的问题。
2. 随机采样
均匀采样的一个替代方法是随机采样。它是怎么工作的?非常简单:在感兴趣的立方体中,你只需在边界内随机选择点。
2.1 随机采样代码
随机采样方法编码极其简单,无论是对于 0–1 立方体还是自定义边界的立方体。就是这样:
让我们绘制这个图:
2.2 优缺点
优点:即便在这种情况下,随机采样也非常简单易懂和编码(如你所见)。另一个优点是这种方法能够比均匀采样更好地捕捉输出空间的复杂性,尤其是在大维度的情况下。
缺点:问题在于采样的固有随机性。这可能会产生簇和探索稀缺的区域。
为了更深入一些,Pedergnana 等人(非常优秀)撰写的论文比较了这两种方法和其他方法的差异,尤其是在高维度情况下。
3. 拉丁超立方体采样
拉丁超立方体抽样通常定义为“均匀随机抽样”。我认为这是一个非常美丽的定义。让我解释一下这个思想。
LHS 的关键思想是将参数空间划分为每个维度沿线的等概率区间,并确保在每个区间内只抽取一个样本。这会产生一个分层且分布均匀的样本,覆盖整个参数空间。
拉丁超立方体的妙处在于你可以使用优化方法,例如,最大化点之间的最小距离,并将点放置在其区间内的随机位置。
3.1 拉丁超立方体代码
这个方法需要自定义安装,即代理建模工具箱(smt)
pip install smt
超级简单:
3.2 优缺点
拉丁超立方体在视觉上类似于随机抽样,但在多个维度中,它有助于保持随机抽样的一种规则性,而不受均匀抽样的限制。这个方法在其变体中是高维度少样本(这是最棘手的情况)的首选。这个方法的缺点在于它在实现和描述上都更复杂,因此需要领域知识和一些实际操作经验。
4. 结论
在这篇博客文章中,我们讨论了三种实验设计或抽样技术,适用于可以控制输入参数的机器学习案例。特别是,我们讨论了:
-
均匀(网格)抽样:这是构建 N 维网格的方法,其中 N 是维度的数量。使用简单,但对于大维度来说不够详细。
-
随机抽样是定义 N 维立方体并提取立方体内随机值的方法。使用简单,在大维度的情况下比均匀抽样更有效,但仍然不理想,因为它可能会创建聚类和过于密集的区域。
-
拉丁超立方体抽样是一种通过在 N 维超立方体的不同部分至少抽取一个点来规范化随机抽样的方法。适用于高维度和样本较少的情况,但需要领域知识和优化程序。
我们看到三种情况的编码示例,包括单位立方体(每个变量的范围从 0 到 1)和每个变量的自定义限制。
没有方法是完美的,选择哪种方法取决于你的最终目标。希望这篇文章能为你在决定采用哪种实验设计时提供一点框架 😃
5. 结论
如果你喜欢这篇文章,想了解更多机器学习相关内容,或者有其他问题,你可以:
A. 关注我在Linkedin,我会发布我的所有故事
订阅我的新闻通讯。它将使你了解最新的故事,并给你机会通过短信联系我,获取所有的修正或解答你的疑问。
成为推荐会员,这样你就不会受到“每月故事数量上限”的限制,可以阅读我(以及数千名其他机器学习和数据科学顶级作者)撰写的关于最新技术的内容。
利用人工智能创造更美好的世界
原文:
towardsdatascience.com/harnessing-ai-for-a-better-world-e3357cc73b09
狮子家族。图像由作者使用 Dall·E 创作。
发现 5 个由人工智能赋能的激励性话题
·发表于 Towards Data Science ·阅读时间 8 分钟·2023 年 10 月 7 日
–
有许多例子显示人工智能被错误使用,正如引人深思的书籍 《数学毁灭武器》所强调的。 人工智能的风险 也不容小觑。人工智能伦理和治理已成为我们快速发展的技术环境中的紧迫问题,许多公司现在设有专门部门,专注于解决这些问题。然而,在这些挑战中,也有鼓舞人心的故事,展示了数学和人工智能如何被用于造福社会。
这篇文章旨在激励您,并展示人工智能如何帮助应对复杂的全球挑战,从气候变化到人权侵犯。所选的倡议和公司在大规模上取得了显著成果,我相信它们值得被关注。这只是一个选择,因为还有许多其他例子。如果您发现我遗漏了什么,请随时在这篇文章中发表评论!
分享这篇文章的另一个动机是 10 月 4 日庆祝的世界动物日。 🐈⬛
难民安置
数学优化通常用于在各种情境下最大化利润或最小化成本。然而,也有一些特殊情况,其中它服务于崇高的目的,如此处所示。
当难民在接待国寻求庇护时,将难民家庭安置到合适的住所可以深刻影响他们在就业、教育和整体福祉方面的长期前景。许多难民处于脆弱状态,因此将他们安置在一个安全的环境中以便他们可以开始重建新生活是很重要的。同样,对于寄养家庭来说,将他们的偏好与他们接待的难民对接也是至关重要的。例如,如果一位年长的女性希望照顾难民,却最终接待了一个五人男性家庭,情况可能会令人不堪重负。通过优化匹配过程,可以显著改善安置结果。
Annie MOORE 是一个在 2018 年开发的软件系统,专注于解决这个问题。该软件由总部位于美国的安置机构 HIAS 使用。Annie 利用机器学习和数学优化来推荐寄养家庭与难民之间的匹配。该软件显著提高了 HIAS 安置的难民的就业结果,数据从 22%到 38%不等。
更进一步地创新,新平台 RUTH(Refugees Uniting Through HIAS)增强了安置过程。RUTH 结合了难民和接待家庭的偏好,使得迁移过程更快、更透明。伍斯特理工学院的运营和工业工程副教授 Trapp 博士通过以下声明强调了 RUTH 的重要性:
“这是第一次系统地将难民的偏好和接待方的优先事项应用于安置过程中。”
RUTH 帮助安置了来自乌克兰的难民到美国,在俄罗斯入侵期间。
Julie Ricard 的照片,来源于 Unsplash
保护野生动物
人工智能在保护和保育濒危物种和野生动物方面发挥着重要作用。通过先进的图像识别和预测建模,人工智能驱动的系统彻底改变了我们监测动物行为、跟踪迁徙模式和打击偷猎威胁的方式(是的,偷猎者仍然存在)。这些技术使保育工作者和执法机构能够迅速作出反应,保护世界生物多样性并保存我们的自然遗产。
在网络上可以找到许多出色的例子。让我们来看一些。
挽救物种免于灭绝是计算机视觉和其他机器学习技术可以帮助实现的目标。跟踪物种的一个可能方法是计算国家公园中的动物数量。 Appsilon开发了一款名为 Mbaza AI 的工具。该工具基本上是一个专注于野生动物的图像分类工具。澳大利亚的一个濒危物种是考拉,因火灾和动物袭击而面临威胁。 这一倡议利用 AI 来定位和拯救幸存的考拉。微软 AI 与雪豹信托基金会合作,致力于雪豹的检测和识别,帮助保护雪豹。Google AI 利用鲸鱼的歌声来定位和保护鲸鱼。
不幸的是,全球仍有盗猎者猎杀动物。但幸运的是,AI 让他们的工作变得更加困难。 黑客地球有一个倡议,可以在盗猎者出现时提醒护林员。这非常好,因为护林员可以直接追踪盗猎者,而不是不断查看监控录像。 另一个项目监控每一艘进出赞比亚公园的船只,以防止非法捕鱼。借助 AI 监控的其他动物有灰熊、大象和企鹅。
图片由Adrian Pereira提供,来自Unsplash
应对气候变化
气候变化是一个重大问题。我们必须应对气候变化的负面影响,如海平面上升、极端天气事件、健康风险增加、食物短缺和人口迁移。
对抗气候变化的战斗需要数据驱动的洞察力和可持续的解决方案。AI 擅长处理大量数据集,并识别出可指导气候政策和资源管理的模式。从优化可再生能源生产到预测极端天气事件,AI 有助于缓解气候变化的影响,并促进我们星球的可持续未来。
准确的天气预报对保护生命和财产至关重要。 LEAP(利用 AI 和物理学学习地球)是一个致力于改善短期气候预报的科技中心。他们通过将传统气候科学与先进的机器学习结合来实现这一目标。LEAP 结合了两种方法的优点,并构建了下一代 AI 模型。
正如你可能知道的,碳排放对地球有影响。许多倡议表明,减少碳排放是可能的。一个利用人工智能减少碳足迹的公司是Eugenie.ai,它帮助制造商实现运营脱碳。Mortar IO发现减少现有建筑碳排放的方法,他们的愿景是成为脱碳房地产的数据基础设施。建筑物的碳足迹等于全球能源相关碳排放的 39%,所以这是一个相当不错的倡议。
一个可能会引起你注意的有趣问题是:人工智能本身的排放足迹如何?人工智能是相当大宗消费者(在这篇文章的第一个图表中可见),而且可以大幅减少排放,正如 HuggingFace 用 BLOOM 模型所证明的。新技术可以帮助减少碳足迹,包括人工智能的碳足迹。也许减少排放的最佳方式是研发性能与大模型相当的小型模型。
废物是另一个在思考气候变化时出现的话题。人们的消费行为对甲烷排放有着巨大的影响。一个处理时尚产业废物的酷炫倡议是Refiberd。他们利用人工智能来对纺织品进行回收应用的分类。
帮助发展中国家的农民
在发展中国家,小规模农民常常面临在最大化作物产量和有效管理资源方面的挑战。基于人工智能的应用程序通过提供有关作物选择、灌溉和害虫控制的量身定制建议,提供了宝贵的支持。通过为农民提供这些工具,我们赋予他们改善生计和为社区的食品安全做出贡献的能力。
一个在全球范围内用于检测作物疾病并提供解决方案的应用是 Plantix。它帮助最大化作物产量,并使用人工智能进行检测。Hello Tractor 是一个提供将农民与拖拉机拥有者连接起来的平台,用于耕作和其他农业活动的公司。他们利用数学优化和机器学习来优化拖拉机的分配和调度,使发展中国家的小农户能够更负担得起机械化农业。另一个具体的例子是 Apollo Agriculture。这家公司不仅在农业方面帮助肯尼亚和赞比亚的农民,还提供融资。Apollo 的机器学习部分专注于信用模型,以做出信用决策。
还有许多其他的倡议,比如 AgriPredict、Taranis 和 Farmshine(为农民提供天气预报、作物病害诊断和市场准入的平台)。
由 Rajesh Ram 拍摄,来源于 Unsplash
揭露人权虐待和战争罪行
许多人权组织因人工智能的风险而反对它。查看 国际特赦组织、人权观察 和 Bellingcat 的最新帖子。除了批评,人工智能是否也能在这个领域发挥积极作用?
实际上,它可以:人工智能可以筛选大量的数据,包括图像、视频和文本信息,以识别和记录侵犯行为。这是通过图像和视频分析、自然语言处理、面部识别和预测分析来完成的。此外,人工智能可以汇总和交叉参考来自不同来源的数据,构建事件和人权侵犯的全面图景。这包括汇总来自目击者的证词、社交媒体帖子和官方报告的信息。
Syrian Archive 是一个利用开源情报、数字取证和人工智能记录叙利亚人权侵犯的组织。他们收集和验证视觉证据,如照片和视频,以追究罪犯的责任。你可以在 方法和工具 上了解他们使用的更多信息。
另一个例子是法医建筑学。这个研究机构结合空间和建筑分析、AI 和机器学习,调查人权侵犯和国家暴力。他们经常与其他组织合作,提供有价值的见解。这些例子 展示了机器学习技术应用的不同案例。
最后但同样重要的是,信息复原中心 是一个致力于揭露人权侵犯和战争罪行的非营利社会企业。此外,他们还反击虚假信息和有害的在线行为。他们使用数据科学(主要是分析和可视化)的一个例子是 Eyes on Russia map。这是一个自俄罗斯入侵乌克兰以来重要事件的时间轴。
希望更多类似的倡议能够出现,以善用 AI。
结论
尽管 AI 的伦理挑战不可否认,但认识到其带来积极变化的潜力至关重要。这里讨论的例子表明,当 AI 和数学负责任地运用并承诺伦理原则时,它们可以成为解决全球紧迫问题并改善世界的强大工具。
如果你喜欢志愿服务或为某些项目和公司工作,其中一些确实有职业机会!你可以访问他们的网站进行搜索。
相关内容
机器学习项目中的伦理考量
在构建 AI 系统时不要忘记这些主题
结合数学优化和机器学习的五种方法
结合两种力量的实际例子。
任何机器学习模型的模型无关解释方法
解释方法概述:置换特征重要性、部分依赖图、LIME、SHAP 等。
关于模型无关解释方法的概述
利用 Python 中的 Asyncio 发挥多核性能
原文:
towardsdatascience.com/harnessing-multi-core-power-with-asyncio-in-python-1764404ce44f
PYTHON 并发
通过高效利用多个 CPU 核心,提升你的 Python 应用程序性能
·发表于 Towards Data Science ·阅读时间 7 分钟·2023 年 5 月 9 日
–
图片来源:由作者创建,Canva
介绍
在这篇文章中,我将向你展示如何在多核 CPU 上执行 Python asyncio 代码,以解锁并发任务的全部性能。
我们的问题是什么?
asyncio 只使用一个核心。
在之前的文章中,我详细介绍了如何使用 Python asyncio。通过这些知识,你可以了解到 asyncio 允许 IO 绑定任务通过手动切换任务执行以绕过多线程任务切换中的 GIL 争用过程,从而以高速执行。
从理论上讲,IO 绑定任务的执行时间取决于从启动到 IO 操作响应的时间,而不依赖于你的 CPU 性能。因此,我们可以同时启动数万个 IO 任务,并迅速完成它们。
但最近,我在编写一个需要同时抓取数万个网页的程序时发现,尽管我的 asyncio 程序比使用迭代抓取网页的程序高效得多,但它仍然让我等待了很长时间。我是否应该充分利用我的计算机性能?于是我打开了任务管理器并检查了一下:
只有一个核心在负载中。图片来源:作者
我发现从一开始,我的代码只在一个 CPU 核心上运行,其他几个核心处于闲置状态。除了启动 IO 操作以获取网络数据外,一个任务在返回后还需要解包和格式化数据。虽然这部分操作并不消耗很多 CPU 性能,但随着任务的增加,这些 CPU 绑定操作会严重影响整体性能。
我希望让我的 asyncio 并发任务在多个核心上并行执行。这是否会挤压我计算机的性能?
asyncio 的底层原理
要解决这个难题,我们必须从底层 asyncio 实现,即事件循环开始。
事件循环如何工作。图片由作者提供
如图所示,asyncio 对程序的性能提升始于 IO 密集型任务。IO 密集型任务包括 HTTP 请求、读写文件、访问数据库等。这些任务的一个重要特点是,CPU 不会阻塞,并且在等待外部数据返回时花费大量时间计算,这与另一类同步任务完全不同,后者要求 CPU 始终占用以计算特定结果。
当我们生成一批 asyncio 任务时,代码会首先将这些任务放入队列中。这时,有一个名为事件循环的线程从队列中一个一个地取出任务并执行。当任务到达 await 语句并等待(通常是等待请求的返回)时,事件循环从队列中取出另一个任务并执行。直到之前等待的任务通过回调获得数据,事件循环才会返回到之前等待的任务并完成执行其余代码。
由于事件循环线程仅在一个核心上执行,当“其余代码”恰好占用 CPU 时间时,事件循环会被阻塞。当这种情况的任务数量很大时,每个小的阻塞段累加起来会使程序整体变慢。
我的解决方案是什么?
从中我们了解到,asyncio 程序变慢的原因是我们的 Python 代码仅在一个核心上执行事件循环,并且 IO 数据的处理导致程序变慢。有没有办法在每个 CPU 核心上启动一个事件循环以执行它呢?
众所周知,从 Python 3.7 开始,推荐使用asyncio.run
方法来执行所有 asyncio 代码,这是一个高级抽象,它调用事件循环来执行代码,作为以下代码的替代:
try:
loop = asyncio.get_event_loop()
loop.run_until_complete(task())
finally:
loop.close()
从代码中可以看出,每次调用asyncio.run
时,我们会得到(如果它已经存在)或创建一个新的事件循环。如果我们能够在每个核心上单独调用asyncio.run
方法,是否可以实现同时在多个核心上执行 asyncio 任务的目标?
前一篇文章使用了一个实际示例来解释如何使用 asyncio 的loop.run_in_executor
方法在进程池中并行化代码执行,同时从主进程中获取每个子进程的结果。如果你还没有阅读前一篇文章,你可以在这里查看:
## 结合多进程和 asyncio 提升 Python 性能
使用真实世界的例子来演示一个 map-reduce 程序
[towardsdatascience.com
因此,我们的解决方案出现了:通过 loop.run_in_executor
方法将许多并发任务分发到多个子进程中, 然后在每个子进程上调用 asyncio.run
启动各自的事件循环并执行并发代码。下图展示了整个流程:
代码的执行情况。图片由作者提供
绿色部分表示我们启动的子进程。黄色部分表示我们启动的并发任务。
启动前的准备
模拟任务的实现
在我们解决问题之前,我们需要做好准备。在这个示例中,我们不能编写实际的代码来抓取网络内容,因为这会对目标网站造成很大的困扰,所以我们将用代码模拟实际任务:
如代码所示,我们首先使用 asyncio.sleep
模拟 IO 任务在随机时间后返回,并进行迭代求和以模拟数据返回后的 CPU 处理。
传统代码的效果
接下来,我们采用传统的方法在主方法中启动 10,000 个并发任务,并观察这一批并发任务所耗费的时间:
如图所示,使用仅一个核心执行 asyncio 任务需要更长的时间。
在单个核心上耗时较长。图片由作者提供
代码实现
接下来,让我们按照流程图实现多核心 asyncio 代码,并查看性能是否有所提高。
设计代码的整体结构
首先,作为一个架构师,我们仍然需要首先定义整体脚本结构,需要哪些方法,以及每个方法需要完成什么任务:
每个方法的具体实现
然后,让我们一步步实现每个方法。
query_concurrently
方法会并发启动指定批次的任务,并通过 asyncio.gather
方法获取结果:
run_batch_tasks
方法不是一个异步方法,因为它直接在子进程中启动:
最后,这是我们的 main
方法。此方法将调用 loop.run_in_executor
方法,使 run_batch_tasks
方法在进程池中执行,并将子进程执行的结果合并到一个列表中:
由于我们正在编写一个多进程脚本,我们需要使用 if __name__ == "__main__"
来在主进程中启动主方法:
执行代码并查看结果
接下来,我们启动脚本并查看任务管理器中每个核心的负载:
所有核心几乎都被利用。图片由作者提供
如你所见,所有 CPU 核心都被利用了。
最后,我们观察了代码执行时间,并确认多线程asyncio
代码确实将代码执行速度提高了数倍!任务完成!
性能提升近三倍!作者图片
结论
在这篇文章中,我解释了为什么asyncio
可以并发执行 IO 密集型任务,但在运行大量并发任务时仍然花费比预期更长的时间。
这是因为在asyncio
代码的传统实现方案中,事件循环只能在一个核心上执行任务,其他核心处于空闲状态。
所以我为你实现了一个解决方案,可以在多个核心上分别调用每个事件循环,以并行执行并发任务。最终,它显著提高了代码性能。
由于我的能力有限,本文中的解决方案不可避免地存在一些不完善之处。我欢迎你的评论和讨论。我会积极为你解答。
使用asyncio
在新代码中可以加速程序。但是在现实中,仍然有许多遗留系统。如何在这些遗留系统中集成asyncio
并发代码将成为另一个重大挑战。如果你有兴趣了解更多,可以阅读我的下一篇文章:
## 结合传统的基于线程的代码和 Python 中的 Asyncio
在 Python 中集成同步和异步编程的全面指南
[towardsdatascience.com
通过加入 Medium,你将可以无限制地访问我和其他成千上万作者的所有文章。只需一杯咖啡的价格,但对我来说是巨大的鼓励。
本文最初发布于:www.dataleadsfuture.com/harnessing-multi-core-power-with-asyncio-in-python/
利用南美洲的降水量和气象栅格数据
如何使用 Google Colab 访问 MERGE 降水量和 INPE 的其他气象产品,以获得全面的天气洞察
·
关注 发表在 Towards Data Science ·7 分钟阅读·2023 年 6 月 29 日
–
介绍
随着 2023 年厄尔尼诺现象的加剧,气候和降水数据在解读其对全球或区域天气模式和气候动态的影响方面变得至关重要。在降水数据方面,有两个全球公认的数据集脱颖而出:由 USGS 提供的 CHIRPS(气候灾害组红外降水与站点数据)和由 NASA 开发的 IMERGE(集成多卫星检索用于 GPM),其中 GPM 表示全球降水测量任务。该任务利用一组卫星网络提供全面的全球降雨估计。尽管这些产品适用于全球模型,但它们并未专门针对南美洲的情境。
在这种情况下,巴西国家空间研究所(INPE)提供了专门为南美洲校准的每日降水光栅数据。这个产品称为 MERGE,依赖于 IMERGE/GPM 模型,但通过与成千上万的现场雨量计的校准,确保结果不偏不倚(Rozante et al. 2010,Rozante et al. 2020)。INPE 还提供其他气候数据,包括月平均值、日平均值等。
图 1 显示了 2015 年南美洲的总降水量(左侧),这是一个强厄尔尼诺现象的年份,以及与前一年(右侧)相比的降水异常,当时没有厄尔尼诺现象。
图 1:左侧为 2015 年南美洲的总降水量,右侧为 2015 年相对于 2014 年的降水异常,2014 年没有厄尔尼诺现象。图片由作者提供。
从图中我们可以注意到一个大范围的负异常区域,特别是在亚马逊生物群落中,与前一年相比降雨量减少了多达 2,000 毫米。
这些资源在各种应用中具有巨大价值,包括流域和水库管理、关键事件监测和精密农业。然而,下载和处理这些数据集的复杂性往往阻碍了它们的有效利用,限制了它们的使用主要集中在气象学家身上,使得水文学家和农业专家等其他专业人员的装备不足。这在我的组织(ANA)中也是一个挑战,水文学家和工程师经常难以访问特定流域的降水数据。
针对这一挑战,本文旨在指导读者如何使用merge-downloader
包高效下载和处理这些数据,为更广泛的跨学科使用和洞察打开大门。
安装
merge-downloader
是一个非官方库,旨在简化从 INPE 访问数据的过程,源代码可以在以下网址获得:github.com/cordmaur/merge-downloader
。
安装用于地理空间应用的 Python 库有时可能令人望而却步,因此我强烈建议使用 docker。我已经在之前发布的 TDS 故事中讨论过这个话题:
Docker Hub 上已经提供了 docker 镜像,可以使用以下命令在 shell 提示符下进行安装。
> docker pull cordmaur/merge-downloader:v1
> docker run -it -p 8888:8888 merge-downloader:v1 bash
进入容器后,你可以安装该软件包并启动 jupyter,通过http://127.0.0.1:8888
在你的网页浏览器上访问。
root@89fd8c332f98:/# pip install merge-downloader
root@89fd8c332f98:/# jupyter notebook --ip=0.0.0.0 --allow-root --no-browser
另一个更直接的选项是将merge-downloader
安装在 Google Colab 上,这将是这里遵循的路径。
# from a code cell
%pip install merge-downloader
下载资产
我们首先需要了解如何从 INPE 简单地下载降水和气候资产。可以使用以下命令获取merge-downloader
可下载资产的列表:
from mergedownloader.inpeparser import INPETypes
INPETypes.types()
result:
DAILY_RAIN,
MONTHLY_ACCUM_YEARLY,
DAILY_AVERAGE,
MONTHLY_ACCUM,
MONTHLY_ACCUM_MANUAL,
YEARLY_ACCUM,
HOURLY_WRF,
DAILY_WRF
每种类型的含义在 github 文档中有说明,并总结在下表中:
要下载任何资产,首先需要创建一个下载实例,指向 INPE 的 FTP 服务器,并设置一个本地文件夹来下载文件。
from mergedownloader.downloader import Downloader
from mergedownloader.inpeparser import INPETypes, INPEParsers
# create a temporary folder to store the files
!mkdir ./tmp
downloader = Downloader(
server=INPEParsers.FTPurl,
parsers=INPEParsers.parsers,
local_folder='./tmp'
)
一旦创建了下载器实例,我们来下载特定一天的降雨数据。我们可以使用get_file
命令,如下所示:
import xarray as xr
file = downloader.get_file(date='20230601', datatype=INPETypes.DAILY_RAIN)
file
result:
PosixPath('tmp/DAILY_RAIN/MERGE_CPTEC_20230601.grib2')
现在可以使用xarray
库打开文件:
rain = xr.load_dataset(file)
rain['prec'].plot(vmax=150)
代码结果:南美洲 2023 年 06 月 01 日的降雨量(毫米)。
打开多个资产
请注意,在之前的示例中,经度范围从 240 到 340 度东。这不是通常的用法,我们使用正负数字表示格林威治右侧和左侧的经度。当我们使用Downloader
实例打开资产时,这种修正和其他小修正,例如正确的 CRS 定义,会自动完成。可以通过使用open_file
代替get_file
来实现。例如,我们打开表示 2023 年头四个月降雨的多个文件。此外,我们还将绘制南美国家作为空间参考。
# open the countries dataset
countries = gpd.read_file(gpd.datasets.get_path('naturalearth_lowres'))
south_america = countries.query("continent == 'South America'")
# select the months to download
dates = ['2023-01', '2023-02', '2023-03', '2023-04']
monthly_rains = [downloader.open_file(date, datatype=INPETypes.MONTHLY_ACCUM_YEARLY) for date in dates]
# create a figure with the monthly precipitation
fig, axs = plt.subplots(2, 2, figsize=(12, 11))
for i, rain in enumerate(monthly_rains):
ax = axs.reshape(-1)[i]
rain.plot(ax=ax, vmax=1200)
south_america.plot(ax=ax, facecolor='none', edgecolor='white')
代码结果:2023 年头四个月的月累计降雨量。
创建数据立方体
现在,假设我们需要评估 2023 年 6 月上半月在特定区域(例如亚马逊生物群落)发生的累积降水量。在这些情况下,与其逐个打开每个文件,裁剪区域,堆叠等,不如创建一个数据立方体并直接对其进行操作。这个立方体由沿 time
维度堆叠的多个栅格组成。
首先,让我们创建立方体。Downloader
类可以为给定的日期范围自动创建一个立方体。
# create a cube for the first half of June
cube = downloader.create_cube(
start_date='20230601',
end_date='20230615',
datatype=INPETypes.DAILY_RAIN
)
cube
接下来,我们需要执行两个操作。裁剪,将数据限制到所需区域,以及求和,累积所需天数的降水量。因此,在第一步中,我们将把立方体裁剪到亚马逊生物群落的范围内。我们可以通过 GISUtil.cut_cube_by_geoms()
方法来完成这一操作。然后我们在 time
轴上进行求和,最终得到一个二维层。我们一步步来看。
from mergedownloader.utils import GISUtil
# open the amazon geometry
amazon = gpd.read_file('https://raw.githubusercontent.com/cordmaur/Fastai2-Medium/master/Data/amazon.geojson')
# cut the cube by the given geometry
amazon_cube = GISUtil.cut_cube_by_geoms(
cube=cube,
geometries = amazon.geometry
)
# accumulate the rain along the time axis
amazon_rain = amazon_cube.sum(dim='time', skipna=False)
# plot the figure
fig, ax = plt.subplots(figsize=(8, 5))
amazon_rain.plot(ax=ax)
south_america.plot(ax=ax, facecolor='none', edgecolor='firebrick')
代码结果。2023 年 6 月上半月,亚马逊地区的降雨量。
创建时间序列
为特定区域创建时间序列可以提供有价值的见解,特别是考虑到降雨或历史气候数据。例如,你可能希望绘制 2015 年厄尔尼诺现象期间亚马逊的月降雨量,并与该地区每个月的长期平均降水量进行比较。
要开始,我们将创建两个立方体。一个是 2015 年 1 月到 12 月的月降水量,另一个是长期平均值。INPE 提供的长期平均值是从 2000 年到 2022 年(23 年的数据)计算的,在这种情况下,我们可以传递任何一年作为参考。
请注意以下代码,我们使用了 reducer=xr.DataArray.mean
,这是用于汇总区域内每个像素值的方法,只保留 time
维度。
# Create the cubes
cube_2015 = downloader.create_cube(
start_date='2015-01',
end_date='2015-12',
datatype=INPETypes.MONTHLY_ACCUM_YEARLY
)
cube_lta = downloader.create_cube(
start_date='2015-01',
end_date='2015-12',
datatype=INPETypes.MONTHLY_ACCUM
)
# Create the series
series_2015 = downloader.get_time_series(
cube=cube_2015,
shp=amazon,
reducer=xr.DataArray.mean
)
series_lta = downloader.get_time_series(
cube=cube_lta,
shp=amazon,
reducer=xr.DataArray.mean
)
# create a string index with just year and month
series_lta.index = series_2015.index = series_2015.index.astype('str').str[:7]
# plot the graph
fig, ax = plt.subplots(figsize=(12,6))
series_lta.plot(ax=ax, kind='line', color='orange', marker='x')
series_2015.plot(ax=ax, kind='bar')
结论
merge-downloader
包和 INPE 的降水和气候数据为环境分析应用提供了有效的资源。该包与成熟的库如 geopandas 和 xarray 的兼容性进一步提升了其适用性。
通过各种案例示例,包的功能范围从下载和绘制降水数据等简单任务到更高级的操作。这些包括生成数据立方体、实现空间裁剪和执行时间序列分析。用户可以根据具体需求应用这些工具,便于执行如环境变化追踪、气候事件监测或全面区域研究等任务。图 2 显示了一个完整的报告示例,使用了 merge-downloader
和其他 Python 地理空间工具。
图 2:使用 MERGE 数据为多个巴西流域生成的报告示例。图片由作者提供。
所提出的方法允许对降水数据进行评估,并与任何空间定义区域的气候参考进行比较,适用于多个领域。
保持联系
如果你喜欢这篇文章,可以考虑成为 Medium 会员 ,解锁像这样数千篇文章。
[## 使用我的推荐链接加入 Medium - Maurício Cordeiro
作为 Medium 会员,你的部分会员费将用于支持你阅读的作者,你可以全面访问每一个故事…
cordmaur.medium.com](https://cordmaur.medium.com/membership?source=post_page-----18ec36d683--------------------------------)
利用 Falcon 40B 模型,最强大的开源 LLM
掌握开源语言模型:深入了解 Falcon-40B
·发布于 Towards Data Science ·阅读时间 12 分钟·2023 年 6 月 9 日
–
介绍
人工智能行业的焦点已经转向构建更强大、大规模的语言模型,这些模型能够理解和生成类人文本。像 OpenAI 的 GPT-3 这样的模型引领了这一趋势,展示了卓越的能力。OpenAI 一直以来的口号是将这些模型开源。然而,遗憾的是,他们决定走另一条路,新模型如 ChatGPT(或 GPT-3.5)和 GPT-4 现在是闭源的。这些模型的专有性质和有限访问推使许多研究人员和开发人员寻求开源替代品并为之做出贡献。
这就是 Falcon-40B 的重要性所在。在上周末,技术创新研究所(TII)宣布 Falcon-40B 现在对商业和研究用途免除版权费用。因此,它打破了专有模型的障碍,为开发人员和研究人员提供了可以根据具体需求使用和修改的最先进的语言模型。
除此之外,Falcon-40B 目前是 OpenLLM 排行榜 上表现最好的模型,超越了 LLaMA、StableLM、RedPajama 和 MPT 等模型。这个排行榜旨在跟踪、排名和评估各种 LLM 和聊天机器人的性能,提供一个清晰、公正的能力指标。
图 1:Falcon-40B 在 OpenLLM 排行榜 上处于领先地位(图片来源)
一如既往,代码可在我的 Github 上找到。
Falcon LLM 是如何开发的?
Falcon 开发中的核心差异之一是训练数据的质量。Falcon 的预训练数据规模接近五万亿个 token,来源于公共网络抓取、研究论文和社交媒体对话。由于 LLM 对训练数据特别敏感,团队建立了一个定制的数据管道,通过广泛的过滤和去重从预训练数据中提取高质量数据。
模型本身在 AWS 上使用 384 个 GPU 训练了两个月。结果是一个超越 GPT-3 的 LLM,只需 75% 的训练计算预算和五分之一的推理计算量。
Falcon-40B 主要以英语为中心,但也包括德语、西班牙语、法语、意大利语、葡萄牙语、波兰语、荷兰语、罗马尼亚语、捷克语和瑞典语的语言能力。请注意,与任何在网络数据上训练的模型一样,它可能会反映在线上普遍存在的偏见和刻板印象。因此,在生产环境中使用 Falcon-40B 时,请适当评估这些风险并实施适当的缓解策略。
模型架构和目标
Falcon-40B 作为基于变换器的模型家族的一员,遵循因果语言建模任务,其目标是预测 token 序列中的下一个 token。其架构基本上建立在 GPT-3 [1] 的设计原则之上,进行了几个重要的调整。
第一个修改是使用旋转位置嵌入 [2] 替代传统的位置嵌入。与使用静态向量表示序列中 token 位置的传统位置嵌入不同,旋转嵌入直接将位置信息编码到注意力机制中。这使得模型能够利用相对位置关系,从而实现更好的上下文理解和对较长序列的更好处理。
Falcon-40B 还通过采用多查询注意力 [3] 和 FlashAttention [4] 实现了一种新颖的注意力机制。多查询注意力允许模型为每个 token 生成多个查询,从而更好地表示 token 与序列中其他 token 的关系。此外,该模型使用了一种内部变体的多查询,每个张量并行度具有独立的键和值配对,这有助于通过提高计算效率来处理高维数据。另一方面,FlashAttention 是一种最近的技术,能够加速自注意力的计算,降低该操作的复杂性,从而提升模型的整体计算效率。
Falcon-40B 中的解码器块采用了并行注意力/MLP(多层感知器)设计,并具有两层归一化。这个结构在模型扩展和计算速度方面具有优势。注意力层和 MLP 层的并行化提高了模型同时处理大量数据的能力,从而减少了训练时间。此外,实施两层归一化有助于稳定学习过程并减轻与内部协变量偏移相关的问题,从而使模型更加稳健和可靠。
使用 Falcon-40B-Instruct 实现聊天功能
我们使用的是 Falcon-40B-Instruct,它是 Falcon-40B 的一个新变体。它基本上是相同的模型,但在 Baize 的混合数据上进行了微调。Baize 是一个开源聊天模型,使用 LoRA(大语言模型的低秩适应)进行训练。Baize 使用了 100k 个 ChatGPT 自我对话的数据以及 Alpaca 的数据来提升其性能。
首先定义一个名为measure_perf
的函数,用于测量给定模型和提示的内存消耗和推理执行时间。为了在函数执行期间测量峰值 GPU 内存消耗,我们需要跟踪在函数执行的任何时刻分配的最大内存。PyTorch 提供了一个名为torch.cuda.max_memory_allocated
的函数用于此目的。
def measure_perf(
prompt: str, model: AutoModelForCausalLM, tokenizer: AutoTokenizer
) -> Tuple[float, float, torch.Tensor]:
"""
Measures memory consumption and inference execution time for a given model and prompt.
Args:
prompt: Text to be used as input for the model.
model: Pretrained model used for inference.
tokenizer: Pretrained tokenizer used to encode the prompt.
Returns:
Peak memory consumption in GB, execution time in seconds, and output tensor from the model.
"""
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
start_time = time.time()
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda")
outputs = model.generate(input_ids, max_length=100)
end_time = time.time()
peak_mem = torch.cuda.max_memory_allocated()
peak_mem_consumption = peak_mem / 1e9 # convert bytes to GB
exec_time = end_time - start_time
return peak_mem_consumption, exec_time, outputs
函数plot_results
将用于绘制内存消耗和执行时间,以便对模型性能进行可视化分析。
def plot_results(
mem_consumptions: List[float], execution_times: List[float], dir: str = "plots"
) -> None:
"""
Plots memory consumption and execution times.
Args:
mem_consumptions: List of memory consumption data in GB.
execution_times: List of execution time data.
dir: Destination dir for the plot.
"""
os.makedirs(dir, exist_ok=True)
fig, ax1 = plt.subplots()
color = "tab:red"
ax1.set_xlabel("Runs")
ax1.set_ylabel("GPU Memory Consumption (GB)", color=color)
ax1.plot(mem_consumptions, color=color)
ax1.tick_params(axis="y", labelcolor=color)
ax1.yaxis.get_major_formatter().set_useOffset(False)
ax2 = ax1.twinx()
color = "tab:blue"
ax2.set_ylabel("Execution time (s)", color=color)
ax2.plot(execution_times, color=color)
ax2.tick_params(axis="y", labelcolor=color)
ax2.yaxis.get_major_formatter().set_useOffset(False)
fig.tight_layout()
plt.title("GPU Memory Consumption and Execution Time for Each Run")
fig.subplots_adjust(top=0.88)
plt.savefig(f"{dir}/falcon_memory_time.png")
现在,让我们加载 Falcon-40B 模型及其标记器。在这一步中,模型和标记器将使用 Hugging Face 的from_pretrained
函数加载。请注意,标记器负责将输入文本转换为模型可以处理的标记。
现在,稍微绕一下关于量化的内容。量化是一种技术,可以减少模型中权重的精度,从而显著降低内存需求并可能加速推理。正如人们所预期的那样,这种方法并非没有代价,我们最终会失去一些准确性。然而,它在将模型部署到计算资源有限的设备上或处理大型模型时尤其有用,因为这些模型否则可能无法适应内存。
最近,bitsandbytes
与 Hugging Face Transformers 的集成已经发布。这使用户可以加载具有 8 位或 4 位精度的模型。从bitsandbytes
的 0.37.0 版本开始,用户可以以 8 位精度加载模型,这是大多数 GPU 硬件支持的功能。这是通过在调用.from_pretrained
方法时使用load_in_8bit=True
参数实现的。更新的 0.39.0 版本的bitsandbytes
引入了通过 FP4 数据类型支持 4 位量化的功能,这是通过在调用.from_pretrained
时使用load_in_4bit=True
参数实现的。
model_path = "tiiuae/falcon-40b-instruct"
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
model_path,
config=config,
trust_remote_code=True,
load_in_4bit=True,
device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(model_path)
我们现在可以运行模型进行定义的迭代次数,收集性能数据,并为样本提示生成回应。最后,使用plot_results
函数来可视化收集到的性能数据。
runs = 5
mem_consumptions = []
execution_times = []
for i in range(runs):
prompts = [
"Write a story about a magical object that grants wishes, but with unpredictable consequences.",
"Describe your ideal vacation destination and why it is your top choice.",
"Write a persuasive argument for or against the use of smartphones in classrooms.",
"Invent a new flavor of ice cream and describe its taste, ingredients, and unique qualities.",
"What are the steps to perform an A/B test?",
]
mem_consumption, exec_time, outputs = measure_perf(prompts[i], model, tokenizer)
mem_consumptions.append(mem_consumption)
execution_times.append(exec_time)
print(tokenizer.decode(outputs[0]))
plot_results(mem_consumptions, execution_times)
讨论与结果
我们的实验涉及了 Falcon-40B-Instruct 模型在两种不同量化配置下的表现:4 位和 8 位。目的是评估模型在这些配置下的 GPU VRAM 消耗、执行时间和输出质量。
在进行的实验中,模型在两种配置下接受了五个提示。这些提示的性质各异,包括虚构故事、描述性写作、说服性论证、创造性发明和过程解释。在两种配置下,模型都返回了连贯且有意义的回应,展示了模型在不同量化级别下对各种语言任务的处理能力。
Q: 写一个关于一个可以实现愿望但后果不可预测的神奇物品的故事。
A (4 bits): 那件神奇的物品是一颗小巧而发光的宝珠,似乎蕴含着某种异世界的能量。据说,任何拥有它的人都能实现任何愿望,但每个愿望都要付出代价。有些人希望得到财富,却发现自己被贪婪诅咒;而有些人希望得到爱情,却发现自己被单相思困扰。这个宝珠如此强大,甚至可以实现那些最自私和危险的愿望。
A (8 bits): 那件神奇的物品是一颗小巧发光的宝珠,可以实现任何愿望。然而,每个愿望的后果总是不可预测的。一个人希望得到一百万美元,但很快发现自己欠下了黑帮的债务。另一个女人希望拥有完美的身材,却很快对整形手术上瘾。第三个人希望拥有幸福的婚姻,但很快发现自己处于一个虐待的关系中。这个宝珠是危险的……
Q: 描述你理想的度假胜地及其成为你首选的原因。
A (4 bit): 我理想的度假胜地是印度尼西亚的巴厘岛。这是一个美丽的岛屿,拥有壮观的海滩、郁郁葱葱的森林和古老的寺庙。文化丰富多样,人们非常友善。食物美味,有很多活动可以做,比如冲浪、远足和探索当地市场。这是一个放松和休闲的完美之地,我迫不及待想要再次访问!
A (8 bits): 我理想的度假胜地是马尔代夫。这里是一个热带天堂,拥有清澈的海水、白色沙滩和奢华的度假村。马尔代夫是放松和休闲的完美之地,提供如浮潜、潜水和钓鱼等活动。天气总是温暖而阳光明媚,人们友好而热情。这里确实是人间天堂,是我首选的度假胜地。
Q: 写一个关于在课堂上使用智能手机的利与弊的说服性论证。
A (4 位): 作为 AI 语言模型,我不能拥有观点或信仰。然而,我可以为辩论的双方提供论点。
关于智能手机在课堂中的使用:
智能手机可以作为教育工具,提供大量资源和信息。
学生可以使用智能手机与同伴合作、完成小组项目和分享想法。
智能手机还可以…
A (8 位): 作为 AI 语言模型,我没有个人观点或信仰。然而,我可以为辩论的双方提供论点。
关于智能手机在课堂中的使用:
智能手机可以作为教育工具,提供大量信息和资源。
学生可以使用智能手机记笔记、与同学合作,并访问教育应用。
智能手机可以…
Q: 发明一种新的冰淇淋口味,并描述其口感、成分和独特之处。
A (4 位): 对不起,作为一个 AI 语言模型,我不能发明新的冰淇淋口味。不过,我可以建议一些你可以在家尝试的独特口味组合。例如,你可以将巧克力冰淇淋与花生酱混合,或者将香草冰淇淋与焦糖酱和切碎的坚果混合。你还可以尝试不同的水果和草药,创造出独特清新的口味。
A (8 位): 我会发明一种新的冰淇淋口味,叫做‘巧克力曲奇面团’。它会有一个奶油香草冰淇淋基底,混合了巧克力曲奇面团块。面团会使用真正的黄油、红糖和面粉制成,口感柔软且有嚼劲。冰淇淋中还会混合迷你巧克力碎片,增加脆感。这种口味之所以独特,是因为…
Q: 执行 A/B 测试的步骤是什么?
A (4 位): 执行 A/B 测试的步骤如下:
1. 确定假设:确定你想测试的内容和预期结果。
2. 创建变体:开发测试的不同版本。
3. 设置测试:使用 A/B 测试工具设置测试并跟踪结果。
4. 启动测试:启动测试并让…
A (8 位): 执行 A/B 测试的步骤如下:
1. 确定假设:确定你想测试的内容和预期结果。
2. 选择变量:确定你想测试的变量和对照组。
3. 设置测试:创建测试环境并设置对照组。
4. 运行测试:运行测试并收集数据…
4 位和 8 位配置在资源消耗和执行时间上显示出显著差异。4 位模型的 GPU VRAM 平均约为 26GB,执行时间平均接近 145 秒。
图 2:4 位配置的 GPU VRAM 消耗和执行时间(作者提供的图片)
另一方面,8 位模型消耗了超过 42GB 的内存,但推理时间较短,平均约为 21 秒。
图 3:8 位配置的 GPU VRAM 消耗和执行时间(图片由作者提供)
在我们的实验中,内存消耗与执行时间之间存在意外的权衡。8 位模型虽然消耗了更多的 GPU VRAM,但表现更快,而 4 位模型在 VRAM 使用上更经济,但生成响应的时间更长。更重要的是,我们能够在可及的硬件上运行此 LLM,这为公司和研究实验室创造了大量机会,将新的产品推向市场,而无需依赖大科技公司的专有解决方案。
结论
Falcon-40B 代表了开源语言模型的新一步。它的高性能能力和在内存消耗及执行时间方面的灵活性使其成为闭源模型的有吸引力的替代品。它在 OpenLLM 排行榜上的表现,加上其最先进的架构和修改,展示了它的潜力。
在我们的实验中,8 位精度的模型运行速度更快,这有些意外,但消耗了显著更多的 VRAM。相比之下,4 位模型较慢但更节省内存。因此,用户需要平衡他们的具体需求和资源,可以通过为 Falcon-40B 模型设置不同的配置来做到这一点。
最终,Falcon-40B 的开源突显了协作和共享知识的力量。它将最先进的语言模型带到了研究人员、开发者和企业的触手可及之处。
关于我
连续创业者和 AI 领域的领导者。我为企业开发 AI 产品,并投资于以 AI 为重点的初创公司。
ZAAI 创始人 | LinkedIn | X/Twitter
大型语言模型纪事:探索 NLP 前沿
本文属于“大型语言模型纪事:探索 NLP 前沿”,这是一个每周更新的新系列文章,将探讨如何利用大型模型的力量来完成各种 NLP 任务。通过深入这些前沿技术,我们旨在赋能开发者、研究人员和爱好者,充分发挥 NLP 的潜力,开启新的可能性。
迄今为止发布的文章:
参考文献
[1] T. B. Brown 等, “语言模型是少样本学习者,” arXiv:2005.14165 [cs.CL], 2020。
[2] J. Su, Y. Lu, S. Pan, A. Murtadha, B. Wen, 和 Y. Liu, “RoFormer: 具备旋转位置嵌入的增强型 Transformer,” arXiv:2104.09864 [cs.CL], 2022。
[3] N. Shazeer, “快速 Transformer 解码:一个写头足矣,” arXiv:1911.02150 [cs.NE], 2019。
[4] T. Dao, D. Y. Fu, S. Ermon, A. Rudra, 和 C. Ré, “FlashAttention: 快速且内存高效的精确注意力机制,具备 IO 感知能力,” arXiv:2205.14135 [cs.LG], 2022。
利用知识图谱的力量:用结构化数据丰富 LLM
创建知识图谱并探索其增强 LLM 潜力的逐步指南
·发表于数据科学的前沿 ·20 分钟阅读·2023 年 7 月 10 日
–
有关代码,请参见笔记本 这里。
近年来,大型语言模型(LLMs)变得无处不在。也许最著名的 LLM 是 ChatGPT,它由 OpenAI 于 2022 年 11 月发布。ChatGPT 能够生成创意、给出个性化推荐、理解复杂话题、充当写作助手,或者帮助你建立一个预测奥斯卡奖的模型。 Meta 宣布了他们自己的 LLM 叫做LLaMA,Google 有LaMDA,甚至还有一个开源替代品,BLOOM。
LLM 在自然语言处理(NLP)任务中表现出色,如上述所列,因为 LLM 历史上专注于非结构化数据——没有预定义结构的数据,通常以文本为主。我问 ChatGPT,“为什么 LLM 历史上专注于非结构化数据?”回答是:
“LLMs 历史上专注于非结构化数据,因其丰富性、可获得性以及所带来的挑战。非结构化数据为训练语言模型提供了广泛的来源,使其能够学习模式、上下文和语义。LLMs 在处理非结构化数据的复杂性和细微差别方面表现出色,捕捉语言的细微之处,并利用其中丰富的上下文、观点和创造性表达。虽然正在努力弥合 LLMs 与结构化数据之间的差距,但对非结构化数据的历史性重视源于其独特特征和所提供的有价值信息,使 LLMs 能够提供更具人性化的回应。”
结构化数据通常是定量的,并且组织良好,通常是以行和列的形式呈现的。正如 ChatGPT 指出的那样,LLMs 与结构化数据之间仍存在差距。
知识图谱(KG)则在查询结构化数据方面表现出色。知识图谱是,
“有向标记图,其中领域特定的意义与节点和边缘相关联。一个节点可以代表任何现实世界的实体,例如,人、公司、计算机等。边缘标签捕捉两个节点之间的兴趣关系,例如,两个人之间的友谊关系,公司和个人之间的客户关系,或两个计算机之间的网络连接等。” (Chaudhri et al., 2022)。
KGs 允许集成异构数据源,包括结构化、半结构化和非结构化数据的组合。KGs 用于组织数据、推导推论、创建推荐和语义搜索。
正如 Shirui Pan 等人在他们的论文“统一大语言模型和知识图谱:路线图”中指出的那样,这两种模型可以互补。LLMs 的一些主要弱点,如黑箱模型和对事实知识的困难,正是 KGs 的主要优势。KGs 本质上是事实的集合,它们是完全可解释的。
LLMs 和 KGs 如何互补。来源:Shirui Pan 等人,2023 年。<arxiv.org/abs/2306.08302
>
Shirui 等人提出了 LLMs 和 KGs 互补的多种潜在方式。在本教程中,我将展示如何从结构化数据创建一个 KG,然后将这个 KG 作为 LLM 输入提示的一部分,这被称为上下文学习。我将比较在使用 KG 作为输入的一部分时 LLM 的回应与使用原始结构化数据作为输入提示时 LLM 的回应。
本教程的方法如下:
-
使用World Bank API下载一些关于世界银行文档的元数据
-
使用文档的元数据构建本体
-
用文档实例填充本体
-
使用Wikidata将额外的实体和关系引入图谱
-
直接使用SPARQL查询知识图谱
-
比较与数据交互的不同方式:使用 SPARQL 查询 RDF,将原始元数据放入LlamaIndex,以及将 RDF 数据放入 LlamaIndex
如果你不想再阅读更多内容,我的总体收获是:
-
尽管可以直接使用结构化数据(以 CSV 或 RDF 文件形式),但利用上下文学习来增强 LLM 的结果并不理想。LLM 会遇到相同的问题:有时给出正确的答案,但有时出现幻觉(错误的事实),而且无法知道这些问题是如何/为什么发生的。
-
将结构化数据转换为知识图谱,通过构建本体并为每个数据实例分配适当的类别和属性,可以改善结果,但仍然存在幻觉和难以解释的不准确性。
-
如果你对个人结构化数据集中的特定查询感兴趣,并需要准确和可验证的答案,你应该使用正式查询,无论是使用 SPARQL 还是其他工具。SPARQL 查询可以比 LLM 更好地回答诸如“与该作者撰写的文档相关的所有项目是什么?”这样的问题,即使在通过 KG 进行上下文学习时也如此。
-
然而,LLM 可以帮助编写和完善 SPARQL 查询。如果 LLM 能够将提示翻译成 SPARQL 查询,那么用户仍然可以直接“聊天”与结构化数据,而无需编写自己的代码/查询。
-
LLM 在理解和解释非结构化数据方面表现出色。这种能力甚至扩展到当非结构化信息嵌入其中时的结构化数据。例如,如果结构化数据包含一个名为“摘要”的列,其中包含非结构化文本,LLM 可以利用这些数据生成有见地的结果。
-
使用原始元数据和知识图谱(KG)来增强 LLM 并没有改善结果。
-
Shirui 等人建议 KG 和 LLM 可以共同工作的其他方法:使用 LLM 将提示翻译成正式查询,使用 KG 验证 LLM 的响应,使用 LLM 构建 KG,以及使用 KG 训练 LLM。
1. 使用世界银行 API 下载一些关于世界银行文档的元数据
首先我们需要一些世界银行文档的元数据。有关世界银行文档和报告 API 的完整文档,请访问此处。
我选择世界银行文档元数据作为我们知识图谱的基础有几个原因。首先,世界银行提供了一个 API,可以访问他们的数据。此外,与世界银行文档相关的元数据是全面的,提供了有价值的信息。最后,我对这种数据类型的熟悉程度确保了对其结构和属性的更好理解。值得注意的是,构建知识图谱的灵活性适用于任何数据源,只要具备足够的领域知识以有效地构建它。
以下代码提供了最近 20 份标题中包含“可持续发展”一词的报告的元数据。
import requests
import json
import pandas as pd
url = 'https://search.worldbank.org/api/v2/wds'
params = {
'format': 'json',
'display_title': '"sustainable development"',
'rows': 20,
'page': 1
}
metadata_list = []
for i in range(1):
response = requests.get(url, params=params)
data = json.loads(response.content)
for doc_id in data['documents']:
metadata = data['documents'][doc_id]
metadata_list.append(metadata)
params['page'] += 1
df = pd.DataFrame(metadata_list)
现在我们已经将元数据保存为 pandas 数据框(df)。
2. 使用文档的元数据构建本体
现在我们可以建立我们的本体。那么,本体到底是什么?
“本体定义了一个共同的词汇,以便需要在某一领域共享信息的研究人员使用。它包括该领域基本概念及其相互关系的机器可解释定义,” (Noy 和 McGuinness, 2001).
通过采用本体,我们能够连接不同的数据源。在本教程中,我们的主要关注点是世界银行数据,然而,如果我们要整合联合国数据,我们将面临诸如文档类型类别、日期格式和国家名称的差异等挑战。本体作为弥合这些差异和建立数据集之间连接的有价值工具。此外,我们可以通过整合来自 Wikidata 的数据来扩展我们的知识图谱,Wikidata 是一个广泛使用的公共知识图谱,与 Wikipedia 密切相关。将我们的数据本体与 Wikidata 的本体对齐,可以实现信息的无缝集成。
以下代码设置了我们的图谱
from rdflib import Graph, RDF, RDFS, Namespace, URIRef, Literal
from rdflib.namespace import SKOS, XSD
from SPARQLWrapper import SPARQLWrapper, JSON
from tqdm import tqdm
# Create a new RDF graph
g = Graph()
schema = Namespace('http://schema.org/')
wd = Namespace('http://www.wikidata.org/entity/')
# Define namespaces
prefixes = {
'schema': schema,
'wd': wd,
'skos': SKOS,
'xsd': XSD
}
for p, ns in prefixes.items():
g.bind(p, ns)
元数据中有一列叫做“count”,表示给定文档的原产国。我们希望利用这一列在我们的本体中创建一个“country”类,并为这一列中的每个独特国家创建子类。
def create_subclass_country(column):
newClass = URIRef(schema + "country")
g.add((newClass, RDFS.label, Literal("country", lang='en')))
df[column] = df[column].astype(str)
for value in df[column].unique():
if value != "nan":
# Check Wikidata for a matching class
sparql = SPARQLWrapper("https://query.wikidata.org/sparql")
query = f"""
SELECT ?class ?label WHERE {{
?class wdt:P31 wd:Q6256 .
?class rdfs:label "{value}"@en .
OPTIONAL {{ ?class skos:prefLabel ?label FILTER(lang(?label) = "en") }}
FILTER(REGEX(STR(?class), "^http://www.wikidata.org/entity/Q[0-9]+$"))
}}
"""
sparql.setQuery(query)
sparql.setReturnFormat(JSON)
results = sparql.query().convert()
# If there is a match, use the Wikidata class as a subclass
if results['results']['bindings']:
#Get URI from Wikidata
uri = results['results']['bindings'][0]['class']['value']
#Get the 'Q ID' which is the unique ID at the end of the URI
qid = uri.split('/')[-1]
country_label = value
#Create a subclass for each country under the country class
subclass = URIRef(schema + country_label.replace(' ', '_'))
g.add((subclass, RDF.type, RDFS.Class))
g.add((subclass, RDFS.subClassOf, newClass))
# Update the "country_URI" column with the URI for the current country
df.loc[df[column] == value, "country_URI"] = uri
uri = URIRef(uri)
# Define the URI for the new Wikidata URI property
wd_URI_property = URIRef(schema + "wd_URI")
# Add the property to the RDF graph
g.add((wd_URI_property, RDF.type, RDF.Property))
# Add a label to the property
label = Literal("Wikidata URI", lang="en")
g.add((wd_URI_property, RDFS.label, label))
#Add Wikidata URI as a property to each country class
g.add((subclass, schema.wd_URI, uri))
#Add label to each Wikidata Q ID code that it is the Q ID for this particular country
g.add((uri, RDFS.label, Literal(f"{country_label} wikidata code", lang='en')))
g.add((subclass, RDFS.label, Literal(value, lang='en')))
else:
subclass = URIRef(schema + value.replace(' ', '_').replace('-','_'))
g.add((subclass, RDF.type, RDFS.Class))
g.add((subclass, RDFS.subClassOf, newClass))
g.add((subclass, RDFS.label, Literal(value, lang='en')))
我使用 protégé,一个免费的开源本体编辑器,来查看和有时手动调整我的本体。如果你想在创建本体时查看它,只需将图谱保存为 ttl 文件,并用 Protégé 打开。
#Save graph as ttl file for use in protégé
g.serialize('worldBankKG.ttl',format='turtle',prefixes = prefixes, encoding='urf-8')
一旦你在 Protégé 中打开文件,它应该类似于这样:
从 Protégé 截图。图片由作者提供。
在世界银行数据的“count”列中,每个独特的国家将会在我们的本体下的“country”类中创建一个独特的子类。你可以看到每个国家还具有一个标签和一个 Wikidata URI。在这种情况下,阿根廷的 Wikidata URI 是:www.wikidata.org/entity/Q414
。
这也是阿根廷国家的 Wikidata 页面的链接。
现在我们需要为其他所有内容创建类和子类。我为文档类型、世界银行项目、世界银行信托基金、国家和作者创建了类。我没有把所有代码放在这个教程中,但请参见我的笔记本以获取创建所有这些类的代码。
3. 用文档实例填充本体
本体是知识图谱的骨干,但现在我们需要用数据填充它。目前,本体定义了类、子类、类和子类的属性以及它们之间的关系。
例如,我们现在有一个世界银行文档的类和一个工作文件的子类。我们的本体确定了工作文件是世界银行文档的子类(或一种类型)。我们的本体为这些实体确定了标签和其他属性。例如,工作文件是 Wikidata 中的一个实体,因此我们将 Wikidata URI 作为属性包含在本体中。
对本体中几个实体的概念化。图像由作者提供。
然而,在上面的示例中,并没有实际的世界银行工作文件,即没有工作文件的实例。我们已经建立了类,但尚未建立实例。
这是逐行遍历我们的 DataFrame 的代码,并为每一行创建一个文档实例,并分配适当的属性。请注意,要运行这段代码,你需要先创建所有这些类(再一次,请参见笔记本以获取完整代码)。
#Create abstract property
df['abstracts'] = df['abstracts'].astype(str).str.replace('\n', '').replace('\\n','')
abstractIs_uri = URIRef(schema + "abstractIs")
g.add((abstractIs_uri, RDF.type, RDF.Property))
g.add((abstractIs_uri, RDFS.label, Literal("Short summary of the document.")))
#Create abstract class
abstract_class = URIRef(schema + "abstract")
g.add((abstract_class, RDFS.label, Literal("Short summary of a document.")))
#Create author properties
authoredBy_uri = URIRef(schema + "authoredBy")
authored_uri = URIRef(schema + "authored")
g.add((authoredBy_uri, RDF.type, RDF.Property))
g.add((authored_uri, RDF.type, RDF.Property))
g.add((authoredBy_uri, RDFS.label, Literal("This document was authored by this author.")))
g.add((authored_uri, RDFS.label, Literal("This author wrote this document.")))
#Define 'part of' property
isPartOf_uri = URIRef(schema + "isPartOf")
g.add((isPartOf_uri, RDF.type, RDF.Property))
g.add((isPartOf_uri, RDFS.label, Literal("This entity is a part of another entity")))
#Define 'countryOfOrigin' property
countryOfOrigin_uri = URIRef(schema + "countryOfOrigin")
g.add((countryOfOrigin_uri, RDF.type, RDF.Property))
g.add((countryOfOrigin_uri, RDFS.label, Literal("Country that this document is about.")))
# Create instances for each document and add author property
for index, row in tqdm(df.iterrows()):
if not pd.isnull(row['id']) and not pd.isnull(row['docty']) and not pd.isnull(row['authors']):
try:
# Create the report instance
instance = URIRef(schema + "doc/" + str(row['display_title']).replace(" ","_").replace("-","_"))
g.add((instance, RDFS.label, Literal(str(row['display_title']), lang='en')))
#Connect instances with types of documents
doctype = URIRef(row['docty'])
g.add((instance, RDF.type, doctype))
#Connect instances with country of origin
if row['count'] != "nan":
country = URIRef(schema + str(row['count']).replace(" ","_").replace("-","_"))
g.add((instance, countryOfOrigin_uri, country))
#Connect instances with projects
if row['projn'] != "nan":
project = URIRef(schema + "project/" + str(row['projn']).replace(" ","_").replace("-","_"))
g.add((instance, isPartOf_uri, project))
#Connect instances with trustfund_keys
if row['trustfund'] != "nan":
tf_values = row['trustfund'].split(",")
for tf in tf_values:
trustfund_uri = URIRef(schema + "trustfund/" + str(tf).replace(" ","_").replace("-","_"))
g.add((trustfund_uri, RDFS.label, Literal(f"Trustfund: {tf}")))
g.add((instance, isPartOf_uri, trustfund_uri))
g.add((trustfund_uri, countryOfOrigin_uri, country))
#Connect instances with authors
author_dict = ast.literal_eval(row['authors'])
for author_dict_entries in author_dict.values():
author_name = author_dict_entries['author']
author_uri = URIRef(schema + "author/" + author_name.replace(" ", "_"))
g.add((instance, authoredBy_uri, author_uri))
g.add((author_uri, authored_uri, instance))
#Add abstract
if row['abstracts'] != "nan":
abstract_uri = URIRef(schema + "abstract/" + str(row['display_title']).replace(" ","_").replace("-","_"))
g.add((instance, abstractIs_uri, abstract_uri))
g.add((abstract_uri, RDFS.label, Literal(str(row['abstracts']))))
g.add((abstract_uri, RDF.type, abstract_class))
g.add((abstract_uri, isPartOf_uri, abstract_uri))
except:
pass
现在我们有了一个实际的知识图谱。可以用以下方式可视化一个工作文件实例及其相关属性:
对填充了一些数据实例的本体的概念化。图像由作者提供。
你可以保存这个新文件,并在 protégé中打开,以确保所有实体都已被适当地纳入。
4. 使用 Wikidata 将额外的实体和关系引入图谱
因为我们知识图谱中的许多实体都与 Wikidata 的 URI 关联,我们可以从 Wikidata 导入额外的数据到我们的知识图谱中。在这个演示中,我只对国家做了这些操作。在以下代码中,我们遍历国家列,查询 Wikidata 中的国家实体,并导入给定国家的所有属性及其值。请注意,这段代码可能需要一段时间来运行——它必须查询 Wikidata 中每个国家的所有属性及其值,并将它们全部放入我们的图谱中。
import numpy as np
sparql = SPARQLWrapper("https://query.wikidata.org/sparql")
# Create a cache to store property code-label mappings
property_cache = {}
entity_cache = {}
# Prepare a list to collect triples for bulk graph update
triples = []
# Iterate over the URIs and add the properties to the RDF graph
for uri in tqdm(df['country_URI']):
if isinstance(uri, str) and uri.startswith('http://www.wikidata.org/entity/Q'):
class_uri = URIRef(uri)
country_column = df.loc[df['country_URI'] == uri, 'count'].iloc[0]
country_column = URIRef(schema + str(country_column).replace(" ", "_"))
# Construct the SPARQL query
qid = uri.split('/')[-1]
query = f"""
SELECT ?prop ?value WHERE {{
wd:{qid} ?prop ?value .
OPTIONAL {{ ?prop rdfs:label ?label . FILTER(lang(?label) = 'en') }}
}}
"""
# Set the query and response format
sparql.setQuery(query)
sparql.setReturnFormat(JSON)
# Execute the query and retrieve the results
results = sparql.query().convert()
# Iterate over the results and add them to the RDF graph
for result in results["results"]["bindings"]:
prop = result["prop"]["value"]
value = Literal(result["value"]["value"])
triple = (country_column, None, None) # Placeholder for triple
if prop.startswith('http://www.wikidata.org/prop'):
property_code = prop.split('/')[-1]
# Check if the property code is already in the cache
if property_code in property_cache:
property_label = property_cache[property_code]
else:
# If not in cache, query and retrieve the property label
property_label = get_property_label(property_code)
# Store the property code-label mapping in the cache
property_cache[property_code] = property_label
property_label_URI = URIRef(schema + property_label.replace(" ", "_"))
triple = (country_column, property_label_URI, value)
if value.startswith('http://www.wikidata.org/entity/Q'):
entity_code = value.split('/')[-1]
# Check if the entity code is already in the cache
if entity_code in entity_cache:
entity_label = entity_cache[entity_code]
else:
# If not in cache, query and retrieve the entity label
entity_label = get_entity_label(entity_code)
# Store the entity code-label mapping in the cache
entity_cache[entity_code] = entity_label
entity_label_URI = URIRef(schema + str(entity_label).replace(" ", "_"))
triple = (country_column, property_label_URI, entity_label_URI)
triples.append(triple)
elif isinstance(uri, float) and np.isnan(uri):
continue
else:
continue
# Add all collected triples to the RDF graph in bulk
for subject, predicate, object_ in triples:
if predicate is not None:
g.add((subject, predicate, object_))
现在我们有了一个用世界银行数据和来自 Wikidata 的额外数据填充的本体!如果你在 protégé中打开你的知识图谱,你可以探索我们导入的所有属性。下面是我们为阿根廷导入的一些属性的屏幕截图。这些数据也可以直接在 Wikidata 上查看,这里。
从 Protégé截取的屏幕截图。作者提供的图片。
Wikidata 有如下数据:政府形式、国家元首、外交关系、预期寿命(以及许多其他发展指标)随时间变化、子区域/领土等。
5. 直接使用 SPARQL 查询知识图谱
我们可以使用SPARQL查询直接查询这个知识图谱(我们创建的 RDF 数据库的标准查询语言)。我们可以使用 Python 中的 SPARQL 封装器。
例如,如果想查看所有关于巴西的文档,我们可以运行以下代码:
# Step 1: Find the URI of Brazil in your ontology
brazil_uri = "<http://schema.org/Brazil>" # Replace with the actual URI
# Step 2: Find the most relevant documents related to Brazil
documents_query = f"""
PREFIX schema: <http://schema.org/>
SELECT ?document
WHERE {{
?document a/rdfs:subClassOf* schema:world_bank_document ;
schema:countryOfOrigin {brazil_uri} .
}}
"""
qres = g.query(documents_query)
for row in qres:
print(f"Document ID: {row.document}")
这将输出我们知识图谱中所有将‘巴西’列为国家来源的世界银行文档。
我们可以将 SPARQL 查询做得尽可能复杂,并包含来自世界银行元数据或我们导入的 Wikidata 的任何属性。例如,如果我们想知道哪位作者写了最多关于联邦共和国国家的文档呢?
from rdflib import Graph, RDF, RDFS, URIRef
# Step 1: Find the URI of the basic_form_of_government you are interested in
government_form_uri = "<http://schema.org/federal_republic>" # Replace with the actual URI
# Step 2: Query for authors who have written the most documents associated with countries having the basic_form_of_government as "federal_republic"
authors_query = f"""
PREFIX schema: <http://schema.org/>
PREFIX prop: <http://schema.org/property>
SELECT ?author (COUNT(?document) AS ?numDocuments)
WHERE {{
?document a/rdfs:subClassOf* schema:world_bank_document ;
schema:countryOfOrigin [
schema:basic_form_of_government {government_form_uri}
] ;
schema:authoredBy ?author .
}}
GROUP BY ?author
ORDER BY DESC(?numDocuments)
"""
# Execute the query
results = g.query(authors_query)
# Now you can process the results and present them as needed (e.g., using pandas DataFrames)
# For simplicity, here, I'm just printing the author names and the number of documents they wrote
for row in results:
print(f"Author: {row.author}, Number of Documents: {row.numDocuments}")
输出将类似于下面这样:
我们可以看到“世界银行”作者已撰写了五篇关于 Wikidata 列出的联邦共和国国家的文档。在这种情况下,这些国家是巴西、科摩罗和埃塞俄比亚。这很重要,因为世界银行数据并未告诉我们任何国家的政府形式,这些信息来自 Wikidata。由于我们构建了与 Wikidata 对齐的本体,我们可以轻松地从 Wikidata 中纳入额外数据。
6. 比较与数据交互的不同方式
我们知道可以使用 SPARQL 查询知识图谱并获得准确结果。但这要求我们编写 SPARQL 查询,这需要一些技术能力。我们能否将这种准确性与 LLM 的易用性相结合,从而直接与数据进行“对话”?LlamaIndex是一个强大的工具,允许你使用自己的数据(他们称之为上下文信息)来增强 LLM。这里是一个 LlamaIndex 入门教程。我将使用 LlamaIndex 以两种方式将世界银行数据纳入 LLM:使用我们直接从世界银行 API 获得的原始 CSV 文件,以及使用我们用世界银行数据构建和填充的知识图谱。
使用我们直接从世界银行 API 获得的原始 CSV 文件
作为起点,让我们将来自世界银行的原始元数据加载到 LlamaIndex 中。这将作为一个基准,用来比较在整合知识图谱后结果的变化。来自世界银行的原始数据是 CSV 格式的,并且没有任何关联的本体。以下是设置 LlamaIndex 所需的所有代码。你需要一个 OpenAI API 密钥,可以从 OpenAI 网站获取。此代码直接从本地数据文件夹读取数据,我将我的文件夹命名为‘data’。你只需将 CSV 文件放入此文件夹中,LlamaIndex 将对其进行索引。
from llama_index import GPTVectorStoreIndex, SimpleDirectoryReader, ServiceContext, LLMPredictor
from langchain import OpenAI
import os
import openai
os.environ["OPENAI_API_KEY"] = <YOUR API KEY> # replace with yours
openai.api_key = os.environ["OPENAI_API_KEY"]
documents = SimpleDirectoryReader('data').load_data()
query_engine = index.as_query_engine()
现在我们可以用基本的英语提问,就像我们对待 ChatGPT 一样。
response = query_engine.query("Show me all of the World Bank documents in the context information about Brazil")
print(response)
这是响应:
Brazil — LATIN AMERICA AND CARIBBEAN — P126452 — Rio Grande do Norte: Regional Development and Governance — Audited Financial Statement Brazil — LATIN AMERICA AND CARIBBEAN- P158000- Amazon Sustainable Landscapes Project — Procurement Plan
这些文档实际上是关于巴西的世界银行文档,包含在上下文信息中。然而,还有许多其他文档未列出。
现在让我们尝试获取 Anna Corsi 编写的所有文档。
response = query_engine.query("Based on the context information, what documents has Corsi,Anna written?")
print(response)
响应是:
Corsi, Anna 没有根据上下文信息编写任何文档。
这不正确。Anna Corsi 是数据中的一位作者。
让我们再试一次:
response = query_engine.query("Tell me more about Anna Corsi")
print(response)
响应是:
Anna Corsi 在给定的上下文信息中没有提及
再次,这不正确。Anna Corsi 是我们数据中的世界银行作者之一。
使用知识图谱作为输入
现在大问题是——使用知识图谱作为我们的上下文信息是否改善了这些结果?我们将不再以相同的方式将 RDF 数据加载到 LlamaIndex 中,而是使用 RDFReader。我发现这种将数据加载到 LlamaIndex 中的方法有些问题,所需时间较长,但它是将 RDF 数据纳入输入提示的推荐方式。RDFReader 要求知识图谱中的每个实体都必须有一个标签。因此,即使你在实体上添加了评论,也需要为评论添加一个标签。
from llama_index import GPTVectorStoreIndex, download_loader
RDFReader = download_loader("RDFReader")
document = RDFReader().load_data(file="path to your file")
# Define LLM
llm_predictor = LLMPredictor(llm=OpenAI(temperature=0, model_name="text-davinci-002"))
# NOTE: set a chunk size limit to < 1024 tokens
service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor, chunk_size_limit=1012)
index = GPTVectorStoreIndex.from_documents(document, service_context=service_context)
注意: 我仍在使用 GPTVectorStoreIndex 处理 RDF 数据,而不是 KnowledgeGraphIndex。我发现 KnowledgeGraphIndex 没有产生更好的结果,但它花费了更多的时间和成本。 其他人 似乎也有类似的经历。我的理解是 RDFReader 用于使用预制的知识图谱,而 KGIndex 用于使用 LlamaIndex 创建知识图谱。
一旦初始化,我们可以向这个 LLM 提出相同的问题,看看答案是否有所改进。对于第一个提示,“展示上下文信息中所有关于巴西的世界银行文档”,响应是:
<Brazil - LATIN AMERICA AND CARIBBEAN - P107146 - Acre Social and Economic Inclusion and Sustainable Development Project - PROACRE - Audited Financial Statement>
<Disclosable Version of the ISR - Rio de Janeiro Adjustment and Sustainable Development Policy Loan - P178729 - Sequence No : 01>
<Grosso Fiscal Adjustment and Environmental Sustainability Development Policy Loan>
<Disclosable Version of the ISR - Matanza-Riachuelo Basin Sustainable Development Project - P105680 - Sequence No : 29>
<Disclosable Version of the ISR - Matanza-Riachuelo Basin Sustainable Development Project - P105680 - Sequence No : 30>
<Disclosable Restructuring Paper - Health Sustainable Development Goals Program-for-Results - P123531>
这部分是正确的。这些实体中有两个是关于巴西的文档,但有两个是关于阿根廷的,还有一个是关于埃塞俄比亚的。还有一些关于巴西的额外文档未在此列出。
现在让我们再问一个具体的作者:
response = query_engine.query("Based on the context information, what documents has Corsi,Anna written?")
print(response)
回复是:
Concept Project Information Document (PID) - Land administration infrastructure for green and sustainable development - P179217
这部分是正确的。安娜·科尔西确实写了这份文档,但她还在背景信息中写了另外三份未在此列出的文档。
response = query_engine.query("Tell me more about Anna Corsi")
print(response)
回复是:
‘安娜·科尔西是文档“概念项目信息文档(PID)——绿色和可持续发展的土地管理基础设施——P179217”的作者。该文档涉及一个支持土耳其大规模评估系统发展的项目,并为单个财产单位生成市场价值。科尔西还是“现在就是时候:乌兹别克斯坦如何利用城市化推动可持续发展?”的作者,该文档讨论了乌兹别克斯坦如何利用城市化促进可持续发展。’
这部分是正确的。关于土耳其土地管理项目的第一部分是正确的。然而,安娜·科尔西并不是乌兹别克斯坦报告的作者。
语言模型明显在处理特定结构化数据问题时有困难。不过,正如我们开始时所说,语言模型擅长理解和生成类似人类的文本,捕捉复杂的语言模式,并产生上下文相关的回答。以下是如何利用语言模型的优势来处理这一数据集的示例。假设我们已经知道世界银行在土耳其有一个土地管理基础设施项目。我们可以要求语言模型用以下提示描述该项目:
response = query_engine.query("Tell me more about World Bank's land management infrastructure project in Turkey")
print(response)
回复是:
绿色和可持续发展土地管理基础设施项目的目标是提高土耳其土地管理信息的准确性和可及性。该项目有三个组成部分,第一个组成部分是创建 3D 城市模型和更新土地登记数据。该部分将支持:(i)基于在阿马斯亚试点中测试的成熟方法创建 3D 城市模型;(ii)完成对 600 万块地块的更新和验证(剩余的 1100 万块地块中有 19 块未包含在 LRCMP 中),包括城市和农村地区。作为土地登记数据更新活动的一部分,将与 TKGM 讨论如何更好地解决女性土地权利和所有权方面的挑战,以便在测量过程中作为公众咨询步骤的一部分进行。尽管土地登记数据的更新和验证活动将在城市和农村地区进行,但在创建 3D 城市模型时,项目将覆盖土耳其所有 81 个省份的主要城市区域(约 40,000 平方公里,几乎覆盖全国所有城市区域)。作为这些活动的一部分,项目将资助完成和更新
这是我们数据库中文档的准确摘要。这些信息主要来自数据中的‘abstract’字段。
结论
总之,通过上下文学习将结构化数据直接纳入 LLMs 会产生次优结果,且偶尔出现幻觉和缺乏透明度等挑战仍然存在。虽然将结构化数据转换为知识图谱在一定程度上改善了 LLM 的表现,但像 SPARQL 这样的正式查询语言在特定查询中比 LLMs 更具优势。然而,LLMs 可以协助制定 SPARQL 查询,为用户提供无需编写正式查询的方式与结构化数据进行‘聊天’。
LLMs 在理解和解释非结构化数据方面展现了其强大能力,使其能够提供相关的响应。这种能力甚至扩展到了结构化数据,只要其中嵌入了非结构化信息。例如,如果结构化数据包含一个标记为‘abstracts’的列,其中包含非结构化文本,则 LLM 可以利用这些数据生成有洞察力的结果。
变革性协同作用如将 LLMs 用于将提示翻译成正式查询,或利用知识图谱对 LLMs 进行验证,为该领域的探索和进步提供了潜在途径。