TowardsDataScience 2023 博客中文翻译(二百五十一)

原文:TowardsDataScience

协议:CC BY-NC-SA 4.0

论文解读——《基于潜在扩散模型的高分辨率图像合成》

原文:towardsdatascience.com/paper-explained-high-resolution-image-synthesis-with-latent-diffusion-models-f372f7636d42

虽然 OpenAI 在自然语言处理领域通过其生成文本模型占据了主导地位,但其图像生成模型 DALL·E 现在面临一个有效的开源竞争者——Stable Diffusion。本文将深入探讨基于 Stable Diffusion 的潜在扩散论文。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Mario Larcher

·发表在《数据科学前沿》 ·10 分钟阅读·2023 年 3 月 30 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 13 的一部分来自于《基于潜在扩散模型的高分辨率图像合成》,生成的提示是“潜在空间的油画”。

引言

在我撰写这篇文章时,OpenAI 的聊天机器人 ChatGPT 正在通过与微软产品的整合继续获得关注,这些产品被超过十亿人使用。尽管谷歌最近推出了自己的 AI 助手 Bard,其他公司也在该领域取得了进展,但 OpenAI 依然处于前沿,没有明显的竞争对手。有人可能认为,OpenAI 的 DALL·E,即图像生成模型,在条件和非条件图像生成领域也会同样主导。然而,实际上是一个开源替代品——Stable Diffusion,在受欢迎度和创新性方面领先

本文深入探讨了 Stable Diffusion 背后的科学论文,旨在提供对这一正在革新图像生成领域的模型的清晰而全面的理解。虽然其他文章提供了该技术的高层次解释,但本文超越了表面,探索了经常被忽视的细节。

之前图像生成方法的局限性

在深入分析科学论文《基于潜在扩散模型的高分辨率图像合成》中的方法之前,了解该工作所解决的关键问题是至关重要的。

多年来,图像生成主要通过四类模型来解决:生成对抗网络(GANs)、变分自编码器(VAEs)、自回归模型(ARMs),以及最近的扩散概率模型(DMs)。

生成对抗网络(GANs)

自 2014 年首次出现以来,生成对抗网络(GANs)一直是图像生成领域的主流方法之一。尽管 GANs 在数据变异性有限的情况下显示出有前景的结果,但它们也存在一些问题。最著名的问题是模式崩溃,在这种情况下,生成器生成的输出范围有限,而不是多样化的图像集。

模式崩溃:这种现象发生在生成器可以交替生成有限数量的输出,来欺骗判别器。通常,GANs 在捕捉完整数据分布方面存在困难。

一般而言,它们的训练****常常不稳定的

变分自编码器(VAEs)

变分自编码器(VAEs)是 GANs 的替代方案,提供了几个优势。它们不会遭遇模式崩溃,并且可以高效地生成高分辨率图像。然而,它们的样本质量****不总是可比于 GANs

自回归模型(ARMs)

自回归模型(ARMs)在密度估计方面表现出色,并在这一领域取得了显著成果。然而,它们计算需求高的架构和顺序采样过程限制了它们仅能生成 低分辨率图像

扩散概率模型(DMs)

扩散模型在密度估计和样本质量方面取得了显著进展,但它们在像素空间中通过向与原始图像大小相同的张量添加或去除噪声来操作,导致了推断速度缓慢计算成本高。例如,即使是相对较小的图像,如尺寸为 512x512 的 RGB 图像,也对应于大约 80 万个值的张量,这使得在训练过程中进行梯度传播和在生成过程中进行迭代的方法时,生成更大图像的计算需求也非常高。

条件机制

基于文本描述或其他图像风格生成图像通常是期望的,但将结果条件化到一个或多个输入 在之前的方法中一直是一个挑战

潜在扩散的高级概述

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 3. 来自 高分辨率图像合成与潜在扩散模型

总结科学论文《高分辨率图像合成与潜在扩散模型》中提出的方法,我们可以将其分解为四个主要步骤

  1. 第一步是使用位于图上左上角的编码器 E 提取 图像的紧凑的表示。与其他方法不同,潜在扩散在编码器定义的潜在空间中工作而不是在像素空间中

  2. 接下来,将高斯噪声添加到图中上中部的图像中,作为从 zzT扩散过程的一部分(假设应用了 T 步噪声添加)。

  3. 然后,将 zT 表示通过位于图中下部中央的U-Net。U-Net 的作用是预测 zT-1,这个过程重复 T-1 次,直到我们得到 z,然后通过解码器 D 将其从潜在空间返回到像素空间

  4. 最后,该方法允许通过映射各种输入模态(如语义图或文本)进行任意条件。这通过首先使用专用编码器 τθ 转换输入 y ,然后 将其映射到 U-Net 的中间层,使用与 Transformer 架构相同的 交叉注意力层 实现。

通过这个总体概述,我们现在可以更详细地深入了解这些步骤。

感知图像压缩

潜在扩散明确将图像压缩阶段分离,以去除高频细节(感知压缩),与生成阶段分开,在生成阶段模型学习数据的语义和概念组成(语义压缩)。

目标函数

为了训练用于图像压缩自编码器,作者采用了 《驯化变压器用于高分辨率图像合成》VQGAN使用的方法。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 2. 来源于 《驯化变压器用于高分辨率图像合成》

特别是,用于训练自编码模型 (E, D) 的目标函数是:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

方程 25. 来源于 《基于潜在扩散模型的高分辨率图像合成》

x^重建图像 D(E(x)),Lrec重建损失xx^ 之间的平方误差),Ladv对抗损失,定义为 log*(1 - Dψ(x^)),Dψ* 是一个基于补丁的判别器,优化以区分原始图像和重建的“伪造”图像 x^(所以 Dψ(x) 试图对真实图像 x 输出 1,对重建的“伪造”图像 x^ 输出 0),*,Lreg正则化损失

正则化

作者实验了两种不同的正则化方法

第一种方法涉及低权重的Kullback-Leibler项,类似于标准的 VAE。

Kullback-Leibler (KL) 惩罚:Kullback-Leibler 散度是两种分布之间的一种统计距离。在这种情况下,目标是使潜在变量 z ~ N(Eµ , Eσ²) 的分布接近标准正态分布 N(0, 1)。施加这个约束可以通过使潜在空间更加集中来正则化它,因此,例如,如果 z 接近于 z1 和 z2,则 D(z) 将与 D(z1) 和 D(z2) 有某种共同之处。

在第二种方法中,潜在空间通过 向量量化 层进行正则化。

向量量化 (VQ):VQ 是科学论文 Neural Discrete Representation Learning 中介绍的 VQVAE 以及之前提到的 VQGAN 所使用的方法。从上面的图像可以看出,对于编码器输出 z^ 的每个空间位置,对应的向量(其大小取决于 z^ 的通道数量)会被替换为在可学习的“代码本”中最接近的向量。这会限制解码器在推理过程中可能的输入,只能是代码本向量的组合(即潜在空间的离散化或量化)。

在 VQ 正则化的潜在空间中,z 在量化层之前被提取,并在解码器中吸收量化操作,即它可以被解释为 D 的第一层。

扩散模型

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 2. 来自 Denoising Diffusion Probabilistic Models

由于这篇文章讨论的是潜在扩散模型而不是一般的扩散模型,我将只描述它们最重要的方面。首先,让我们区分 两个过程正向反向

正向过程

正向扩散过程,即图中从右向左的过程,是一个 马尔可夫链,即时间 t 的图像仅依赖于时间 t-1 的图像,而不是所有之前的图像。每一步,xt 都根据以下 转移概率 进行采样:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

Eq. 2. 来自 Denoising Diffusion Probabilistic Models

在上面的公式中,βt 定义了一个 方差调度,可以通过将其视为超参数来学习或保持不变。正向处理的一个有趣特性是可以在任意时间步 t 以封闭形式对 xt 进行采样。使用符号

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

我们有

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

Eq. 4. 来自 Denoising Diffusion Probabilistic Models

所以,总结一下,在正向过程中,我们可以通过从均值和方差由上述公式定义的高斯分布中采样来获取任意时间 t 的图像。

反向过程

给定正向过程,逆向 过程也遵循高斯分布:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

至于方差,作者将其设置为

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

他们通过实验指出,两者

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

产生等效的结果。

在看到均值的参数化之前,让我们重新参数化等式 4 的正向过程:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

此时我们将均值参数化为

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

其中ϵθϵ的估计量,来自xt,具体来说,它是时间条件 U-Net的一个变体。

在这一点上,我们已经具备了采样xt-1条件于xt所需的所有元素,考虑到我们知道逆过程描述开始时引入的高斯分布的所有参数。

不进入数学细节,目标(简化版)是:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

等式 1 来自于 高分辨率图像合成与潜在扩散模型

其中t 从 {1, . . . , T} 中均匀采样。

潜在表示的生成建模

如前所述,Latent Diffusion 的工作方式类似于前面解释的扩散模型。然而,它的不同之处在于它从通过编码器获得的图像的潜在表示 z 开始(潜在空间),而不是从图像 x像素空间)开始。这一细节大大减少了计算负担,因为潜在空间比像素空间更紧凑。

基于此,将xt替换为扩散模型目标中的zt,我们得到新的目标:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

等式 2 来自于 高分辨率图像合成与潜在扩散模型

条件机制

在这项研究之前,对于如何通过输入条件扩散模型的探索是有限的,输入包括类别标签或模糊版本的输入图像。Latent Diffusion 提出的方案具有很高的通用性,涉及 附加信息直接集成到 U-Net 模型的中间层中,使用交叉注意力,这类似于 Transformer 架构。

更具体地说,输入信息(如文本)首先通过一个领域特定的编码器 τθ 转换为中间表示(稍后会提供示例)。然后,将该表示传递通过交叉注意力层并添加到 U-Net 的中间层:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

在方程中,*ϕi(zt)*表示 U-Net 的扁平化中间表示,Ws 是可训练的投影矩阵。尽管论文没有详细说明,代码实现却揭示了交叉注意力层的输出被加到原始 U-Net 层中。这可以在以下代码片段中看到:

x = self.attn2(self.norm2(x), context=context) + x

在这里,attn2 表示交叉注意力层,而上下文指的是τθ(y)。虽然这一过程的完整实现更为复杂,但这是关键的概念元素。欲深入了解这一机制,请参考BasicTransformerBlock模型中使用的模块。

实验

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 5. 来自 High-Resolution Image Synthesis with Latent Diffusion Models

论文进行大量实验探讨了图像生成的各种方法,包括无条件生成、布局到图像合成、空间条件、超分辨率、修补等。为了进一步突出 Latent Diffusion 的两个重要方面,我们将重点关注文本到图像这一广为人知的任务。

首要考虑的关键方面是如何将文本转换为可以传递给交叉注意力层的表示。作者使用BERT-tokenizer并将τθ实现为Transformer以达到这一目标。

第二个重要方面是通过编码器对输入图像进行压缩的程度。作者尝试了不同的下采样因子 f ∈ {1, 2, 4, 8, 16, 32},并得出结论认为 4 和 8 提供了最佳的高质量合成结果条件。上述结果是使用 LDM-8 (KL) 获得的。

结论

Latent Diffusion 及其后续工作的灵感来源于这篇论文,已取得令人惊讶的成果,这些成果曾被认为难以想象。如今,这些模型不再局限于研究实验室,而是被集成到流行产品中,如 Adobe Photoshop。这一发展标志着人工智能领域的一个重要里程碑,并展示了它对我们生活各个方面的潜在影响。

然而,尽管在这一领域取得了显著进展,但仍然存在一些需要解决的挑战。这些挑战包括版权问题,涉及使用图像训练 AI 模型,以及在从互联网爬取大型数据集时出现的固有偏见。然而,尽管存在这些限制,AI 有能力民主化创造力并使个人能够以新颖和引人入胜的方式表达自己,这一潜力不容忽视。

感谢你花时间阅读这篇文章,欢迎留下评论或与我联系,分享你的想法或提出任何问题。要及时了解我最新的文章,你可以在MediumLinkedInTwitter上关注我。

[## 通过我的推荐链接加入 Medium - Mario Namtao Shianti Larcher

阅读 Mario Namtao Shianti Larcher(以及 Medium 上其他成千上万的作者)的每一个故事。你的会员费…

medium.com

论文综述:深度解析 Imagen

原文:towardsdatascience.com/paper-review-a-deep-dive-into-imagen-4e5b4092af13?source=collection_archive---------13-----------------------#2023-02-01

对谷歌令人印象深刻的新型文本到图像生成工具的批判性分析

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Jamie McGowan

·

关注 发表在 Towards Data Science · 13 分钟阅读 · 2023 年 2 月 1 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由 Amanda Dalbjörn 提供,来源于 Unsplash

文本到图像合成是多模态学习领域的一个研究方向,最近在这一领域取得了许多进展[1–4]。这篇综述将重点讨论文章*《具有深度语言理解的逼真文本到图像扩散模型》*[1]。

在这里,作者试图实现最先进的逼真度,并提供对文本到图像合成中更深层次的语言理解的见解。本文的主要输出是一个名为‘Imagen’的模型,它改进了文献中以前的文本到图像合成模型[2–4]。

你可以在这里查看更多关于 Imagen 的信息!

什么是扩散模型?

正如论文标题所示,Imagen 是一个扩散模型。

简而言之,扩散模型是基于生成式 AI的一个例子,其过程是从输入x⁰开始,在每一层t逐渐加入高斯噪声,直到达到纯噪声表示x,其中T是最终层。

这受非平衡热力学的启发,其中状态通过扩散演变为均匀,只要时间足够长。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

显示前向和背景扩散过程的图示。图片来自[16]。

扩散模型通过学习逆转这一过程,试图生成原始的x⁰(在这种情况下,x⁰是一张图像)。请参见上面的图示以获得视觉帮助。

该模型的目标是对描述每一步t反向扩散过程的条件概率进行参数化:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

描述反向扩散过程的方程。图片由作者创建。

其中xᵗ⁻¹(上一个时间步骤)的表示从高斯分布中绘制,特征是均值μ和协方差σ,模型权重为θ。

由于扩散过程在每一步的去噪过程中保持了图像,这导致数据与预测之间的亲密连接,相比其他非扩散基于文本到图像生成器[4–7]。结果通常是,扩散模型[1–3]的输出更逼真

一旦使用基础扩散模型构建一个 64 × 64 像素的图像,Imagen 会利用另外两个超分辨率扩散模型进行上采样 64 × 64 → 256 × 256 → 1024 × 1024。最终结果因此是一个高分辨率的 1024 × 1024 像素图像,如下图所示!

请注意,这张图片实际上来自 DALL-E 2 [2],因为谷歌对 Imagen 有一些限制!概念是一样的,但请确保查看 Imagen 论文以获取实际的图片。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

DALL-E 2 的一个示例输出,文本提示为“时代广场上的滑板上的泰迪熊”。图片来自[2]。

本综述将简要概述以前的工作,然后汇总作者提出的主要贡献和结果,并讨论这些贡献,提供我个人对该工作的意见

之前的工作

虽然实现文本生成图像已经有若干年,但早期的工作困难地将多个文本概念现实地结合到一幅图像中 [5–7]。

基于这些不足,OpenAI 在 [4] 中发布了 DALL-E,它能够将多个看似不相关的概念逐行组合成一幅图像——给定文本提示和图像的起始(第一行像素)。

不到 12 个月后,OpenAI 通过 GLIDE [3] 重新制定了他们的文本到图像合成方法。作者展示了 GLIDE 在各种设置中因照片真实感标题相似性而被人类评估者偏好,从而确立了扩散模型在文本到图像生成中的主导地位。

最后,在 [2] 中,DALL-E 2 通过基于从文本提示中找到的图像嵌入生成图像,从而进一步改进了 GLIDE。

注意到在这个时间段内还有其他进展,但我主要集中在三大主要贡献上,这些贡献构成了 Imagen [1] 的基础。

主要贡献

架构

类似于 GLIDE [3] 和 DALL-E 2 [2],Imagen 是一个扩散模型,在其架构上似乎与 GLIDE 非常接近(即,它将文本嵌入作为输入,并从噪声中生成图像)。然而,Imagen 的一个关键区别是,文本嵌入来自大型现成的语言模型(LMs)。

[1]的主要发现之一是,结合在仅文本数据上训练的大型冻结语言模型,在获取文本到图像合成的文本表示时非常有用。

此外,作者探讨了文本编码器的扩展,发现扩展语言模型的规模显著提高了结果,远超扩展扩散模型的规模。[1]中图 4a 的最左侧图表总结了这一结果,显示 T5-XXL LM [8]生成了更高质量的图像(↓ FID 分数)和更好的标题兼容性(↑ CLIP 分数)。

作者还引入了一种新技术,以避免在无分类器引导的图像生成中出现饱和像素。

分类器引导 被引入以通过预训练模型提高生成图像的质量,该模型在测试时推动输出更忠实于文本输入 [9]。

无分类器引导 [10] 通过从输入噪声生成两个样本(输出),一个有文本条件,一个没有,避免了对预训练模型的需求。

通过在特征空间中找到这两个样本之间的差异,可以找出文本在图像生成中的效果。通过调整这种文本效应,图像生成可以引导至更好的图像-文本对齐(通过不同强度的引导权重 w)。

到目前为止,这些都不算新鲜,然而这种指导的一个问题是,当w很大时,像素可能会饱和,图像保真度会因为更好的图像-文本对齐而受到损害。因此,作者引入了动态阈值,通过在每次采样步骤xᵗ时确定的不同量将饱和像素从[-1, 1]推向内部(因此是动态的)。作者声称,在图像生成中,高度指导下逼真度图像-文本对齐有显著提升。

最后,在模型架构方面,作者提出了一种新的 U-Net [11]变体,该变体比以前的版本更简单更高效。根据我的了解,关键修改是去除了超分辨率模型中的自注意力层,这些模型来自于[11–12]的 U-Net 模型。

DrawBench

对未来文本到图像合成研究的另一个重要贡献是DrawBench的发布。

DrawBench 是一个‘挑战性’评估基准文本提示的集合,探测模型处理复杂概念如组合性、基数和空间关系的能力。

本次发布的理念是提供一个包含一些非常奇怪的文本提示的评估基准,以确保图像之前从未存在过。因此,理论上,这应该将模型推向其想象力和能力的极限,以生成复杂图像。

定量结果

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由Maxim Hopman拍摄,来源于Unsplash

作者在[1]中展示的定量结果对比了不同模型在 COCO [15]和 DrawBench 文本提示上的表现。

作者发现,与 DALL-E 2 [2]、GLIDE [3]、Latent Diffusion [14]和 CLIP-guided VQ-GAN [13]模型进行成对比较时,DrawBench 上的人工评估结果对 Imagen 表现出强烈的偏好(见[1]中的图 3)。这些结果作为标题对齐保真度的衡量标准。

同时,COCO 验证集上的结果似乎在不同模型之间没有表现出太大差异——这可能是作者没有对此过多停留的原因。

然而,对 COCO 数据集的一个有趣观察是,Imagen 在生成逼真的人物方面有有限的能力——尽管作者没有提供任何关于 Imagen 生成人物效果差的定性例子。

讨论

在引言中,[1]的作者提到以下声明:

[Imagen 交付]了前所未有的逼真度和深层次的语言理解能力,在文本到图像合成中。

在调查这一声明的前半部分时,作者展示了 Imagen 与 DALL-E 2 生成图像之间的几种定性比较。他们还提供了人类评估实验的结果,询问人们从单个文本提示或标题中选择最写实的图像。

即使在考虑任何结果之前,作者立即在他们的分析中引入了人类评估实验中固有的主观性。因此,文献[1]中显示的结果必须谨慎考虑,并持有健康的怀疑态度

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

一个来自 DALL-E 2 的示例输出,文本提示为“狗在湖边的绿色田野里玩耍的高质量照片”。图像来自[2]。

为了提供这些结果的背景,作者选择了一些示例比较展示给人类评分者,并将其包含在附录中(一定要查看这些示例——作为动机,我在上面添加了一个来自 DALL-E 2 的示例)。

然而,即使有这些例子,我发现很难对哪张图像更具优先性做出明确判断。考虑到上图中显示的复制例子,我个人认为一些 DALL-E 2 生成的图像比 Imagen 的图像更具照片写实性,这展示了在收集这类结果时主观性的问题。

作者选择询问人类评分者*“哪张图像更具照片写实性?”以及每个“标题是否准确描述了图像?”*在评估过程中。然而,评估这些指标的不连续性质让我感到相当担忧。

例如,如果我们有两张卡通图像(这些图像可能不太现实),并且评分者被要求选择其中一张。就照片写实性指标而言,所选择的图像将具有与从另一批次中选择的更真实的图像相同的现实水平(即不是卡通图像)。

显然,图像批次的标题与可以实现的照片写实性之间存在一定的相互作用。因此,探索根据难度加权某些文本提示,将会有趣,以尝试创建一个连续的度量标准,该标准可以更可靠地汇总

同样,在标题对齐的情况下,评分者在三个类别选项中选择标题是否与生成的图像对齐(有些)。这些实验结果试图支持上述引言的后半部分(声称深层次的语言理解)。

确实,对于标题对齐,可以说在文本提示中的关系和概念是否在图像生成中被捕捉,有一个更明确的答案(即比起照片写实性主观性更低)。

然而,我再次认为应该使用更连续的度量标准,例如 1-10 的对齐评分。根据上述讨论,所有字幕的不同难度等级也会表现为较低的字幕对齐度。有趣的是,探讨让评估者在评价过程中评估字幕或文本提示的难度,并帮助标准化数据集和度量标准。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片来源:Mitchell LuoUnsplash

随着这一研究领域的发展和生成图像变得越来越令人印象深刻和富有创意,这种评估方法自然会变得不那么可靠(当然,这是一个好问题)。因此,希望看到作者讨论询问评估者更多具体问题的潜力,以评估模型所捕捉的创造力、构图性、数量关系和空间关系。

如果生成的两张图像同样令人印象深刻,向评估者提出更多具体的问题可以帮助区分模型在这一非常高的水平上的表现。

作为一个例子,文本到图像生成的一个应用是帮助生成****插图。因此,评估解释文本提示时的创造力和变化性是有充分理由的。

在之前展示的例子中,DALL-E 2 [2] 在解释*‘眼镜’*时比 Imagen 有更多的方式,因此可以认为 DALL-E 2 是更具创意的模型?

从这个角度来看,论文的一个主要批评是所选择的度量标准过于侧重于 Imagen 的优势。不同应用中表现良好的模型的最佳指示(度量)将依赖于具体应用(即没有免费的午餐!)。

因此,我很想听听作者对如何严格评估这些模型的不仅仅是忠实度和字幕对齐度的看法。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片来源:Dragos GontariuUnsplash

根据[1],DrawBench 的发布被认为是对文本到图像研究领域的必要贡献,因为它提供了一套全面的挑战性文本提示场景。

虽然我同意大部分观点,但根据围绕这一论点的讨论,我尚未被说服这是一项全面的基准。如果深入探索 DrawBench,会发现仅包括约 200 个文本提示/字幕,分布在 11 个类别中,这在初看时显得相当少。

这种担忧在与 COCO 数据集 [15] 比较时愈发加深,该数据集包括 330K 张图像,每张图像有 5 个标题,涵盖了更广泛的类别。个人认为,作者应该讨论他们声称这是一个全面集的推理

此外,随着文本到图像合成的快速进展,我认为 DrawBench 是该领域的动态 目标。因此,希望能够看到讨论调整或添加这些标题的可能性。

由于 DrawBench 是与 Imagen 一起呈现的,因此有理由对选择 200 个提示以获得对 Imagen 有利的结果是否存在某种担忧

再次比较 Imagen 和基准模型在 COCO [15] 和 DrawBench 上评估的结果,COCO 的结果似乎在模型之间更接近,而 DrawBench 的结果则显示 Imagen 似乎远高于所有基准模型。

这可能是因为 DrawBench 是一个自然难度较大的提示集,而 Imagen 能够处理这些提示是由于其预训练的 LM,或者 DrawBench 可能对 Imagen 的优势存在偏见?确实,作者在构建 DrawBench 时承认存在一些偏见,未包含任何图像生成中的人物。

最后,当模型(或代码)未发布时,批评研究很容易,尤其是当存在巨大的财务获利潜力时(作者并未提及这一点)。

然而,我认为其中的社会和伦理推理是该论文的最佳 贡献之一,并且突显了在发布强大的开源 AI 软件时需要某种形式的治理

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

Михаил СекацкийUnsplash 上的照片

从更广泛的意义上讲,生成模型自然地反映了社会,这对社会研究小组甚至政府可能是有益的,前提是他们能够访问未经筛选的模型版本。

结论

总结来说,作者在文本到图像合成领域的快速增长中做出了重大贡献。

虽然出于社会和伦理原因目前不对公众开放,但最终的模型‘Imagen’结合了诸如使用现成文本编码器、动态阈值设定和更高效的 U-Net 架构用于基础和超分辨率层等新技术。

我个人很享受阅读这篇论文,我认为所做的贡献是令人兴奋和有趣的领域发展。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由 Arnold Francisca 提供,来自 Unsplash

然而,尽管结果令人印象深刻,但深入研究后我发现,作者往往过度宣传 Imagen 和 DrawBench。因此,观察(也许在未来的出版物中,或从获得 Imagen 访问权限的研究人员中)更广泛的文本到图像生成模型评估将会很有趣。

参考文献

[1] — Chitwan Saharia 等人,照片级文本到图像生成

深度语言理解的扩散模型,arXiv:2205.11487,(2022)。

[2] — Aditya Ramesh 等人,基于 CLIP 潜变量的层次化文本条件图像生成,arXiv:2204.06125,(2022)。

[3] — Alex Nichol 等人,Glide:朝向照片级图像生成与编辑的文本引导扩散模型,arXiv:2112.10741,(2021)。

[4] — Aditya Ramesh 等人,零样本文本到图像生成,ICML,8821 — 8831,PMLR,(2021)。

[5] — Han Zhang 等人,Stackgan++:利用堆叠生成对抗网络进行逼真图像合成,《IEEE 计算机视觉与模式识别学报》,41(8):1947–1962,(2018)。

[6] — Tero Karras 等人,分析与改善 stylegan 图像质量,发表于《IEEE/CVF 计算机视觉与模式识别会议论文集》,8110 — 8119,(2020)。

[7] — Mark Chen 等人,基于像素的生成预训练,ICML,1691 — 1703,PMLR,(2020)。

[8] — Colin Raffel 等人,探索统一文本到文本转换器的迁移学习极限,arXiv:1910.10683,(2019)。

[9] — Prafulla Dhariwal 和 Alexander Nichol,扩散模型在图像合成中的表现优于 GANs,NeurIPS,34,(2021)。

[10] — Jonathan Ho 和 Tim Salimans,无分类器扩散引导,发表于 NeurIPS 2021 深度生成模型与下游应用研讨会,(2021)。

[11] — Alex Nichol 和 Prafulla Dhariwal,改进的去噪扩散概率模型,ICML,8162–8171,PMLR,(2021)。

[12] — Chitwan Saharia 等人,Palette:图像到图像的扩散模型,arXiv:2111.05826,(2021)。

[13] — Katherine Crowson 等人,VQGAN-CLIP:利用自然语言指导的开放域图像生成与编辑,arXiv:2204.08583,(2022)。

[14] — Robin Rombach 等人,高分辨率图像

基于潜在扩散模型的合成,arXiv:2112.10752,(2021)。

[15] — Tsung-Yi Lin 等人,Microsoft COCO:上下文中的常见对象,发表于《欧洲计算机视觉会议论文集》,740 — 755,Springer,(2014)。

[16] — Calvin Luo,理解扩散模型:一个统一视角,arXiv:2208.11970,(2022)。

在 Pandas 中并行化 Python:使用 Spark 的并发选项

原文:towardsdatascience.com/parallelising-python-on-spark-options-for-concurrency-with-pandas-7ca553b9f265?source=collection_archive---------8-----------------------#2023-11-18

当使用 Pandas 时,利用 Spark 的优势

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传Matt Collins

·

关注 发表在Towards Data Science ·8 分钟阅读·Nov 18, 2023

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

照片由Florian SteciukUnsplash上拍摄。

在我之前的工作中,我花了一些时间在一个内部项目上,为我们的托管服务客户预测未来的磁盘存储空间使用情况。每个磁盘都有自己的使用模式,这意味着我们需要为每个磁盘建立单独的机器学习模型,这些模型利用历史数据逐盘预测未来的使用情况。尽管执行预测和选择正确的算法本身就是一个挑战,但在大规模执行时也有其自身的问题。

为了利用更先进的基础设施,我们可以考虑摆脱顺序预测,通过并行化工作负载来加快预测操作。本文旨在比较 Pandas UDF 和 ‘concurrent.futures’ 模块,这两种并发处理的方法,并确定每种方法的适用场景。

挑战

Pandas 是 Python 中用于处理数据集的入口包。在处理 DataFrames 时,我们能够对数据进行分析和评估数据质量,执行探索性数据分析,构建数据的描述性可视化,并预测未来趋势。

尽管这是一个很好的工具,但 Python 的单线程特性意味着在处理更大的数据集时,或者在需要对多个数据子集执行相同分析时,它的扩展性较差。

在大数据领域,我们期望采取更复杂的方法,因为我们还需关注扩展性以保持良好的性能。Spark 以及其他语言使我们能够利用分布式处理来帮助处理更大更复杂的数据结构。

在深入这个具体例子之前,我们可以总结一些数据处理并发需求的用例:

  • 对多个数据文件应用统一转换

  • 为多个数据子集预测未来值

  • 调整机器学习模型的超参数并选择最有效的配置

当我们需要执行如上所述的工作负载时,在 Python 和 Pandas 中最简单的方法是顺序处理这些数据。对于我们的例子,我们将对每个磁盘依次运行上述流程。

数据

在我们的例子中,我们有成千上万的磁盘数据,显示了随时间记录的空闲空间,我们希望预测每个磁盘的未来空闲空间值。

为了更清晰地描述情况,我提供了一个包含 1,000 个磁盘的 csv 文件,每个磁盘有一个月的以 GB 为单位的空闲空间历史数据。这一数据量足以让我们看到不同方法在大规模预测中的影响。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图像:示例 DataFrame

对于像这样的时间序列问题,我们希望使用历史数据来预测未来的趋势,并且我们希望了解对于每个磁盘哪种机器学习(ML)算法将最合适。像 AutoML 这样的工具在寻找一个数据集的适当模型时非常有效,但对于我们这里的 1,000 个数据集来说则有些过于复杂了。

在这种情况下,我们将限制要比较的算法数量为两个,并查看每个磁盘使用均方根误差(RMSE)作为验证指标时最适合使用的模型。关于 RMSE 的更多信息可以在 这里 找到。这些算法包括:

  • 线性回归

  • Fbprophet(将数据拟合到更复杂的线条)

  • Facebook 的时间序列预测模型。

  • 为了更复杂的季节性超参数预测而建立。

现在,我们已经准备好所有组件,如果我们想预测单个磁盘的未来剩余空间。操作步骤如下:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图像:数据生命周期

现在,我们希望将此扩展,对多个磁盘执行此流程,例如我们的示例中的 1,000 个。

作为我们评估的一部分,我们将比较在不同规模下计算不同算法的 RMSE 值的性能。因此,我创建了一个包含前 100 个磁盘子集的子集来模拟这一点。

这应该对不同大小的数据集上的性能提供一些有趣的见解,执行各种复杂度的操作。

引入并发性

Python 以单线程著称,因此不会在某一时间点利用所有可用的计算资源。

因此,我看到了三个选项:

  1. 实现一个 for 循环,顺序计算预测,采用单线程方法。

  2. 使用 Python 的 futures 模块同时运行多个进程。

  3. 使用 Pandas UDFs(用户定义的函数)在 PySpark 中利用分布式计算,同时保持我们的 Pandas 语法和兼容的包。

我希望在不同的环境条件下进行相当深入的比较,因此使用了一个单节点的 Databricks 集群和另一个具有 4 个工作节点的 Databricks 集群,以利用 Spark 进行我们的 Pandas UDF 方法。

我们将采用以下方法来评估线性回归和 fbprophet 模型在每个磁盘上的适用性:

  • 将数据分割为训练集和测试集

  • 使用训练集作为输入,在测试集日期上进行预测

  • 将预测值与测试集中的实际值进行比较,以获取均方根误差(RMSE)分数

我们将在输出中返回两件事情:一个带有预测的修改后的 DataFrame,为我们提供绘制和比较预测与实际值的额外好处,以及一个包含每个磁盘和算法的 RMSE 分数的 DataFrame。

执行此操作的函数如下所示:

我们将比较上述概述的三种方法。我们有几个不同的场景,因此可以填写一个表格,列出我们收集结果的标准:

使用以下组合:

方法

  • 顺序

  • futures

  • Pandas UDFs

算法

  • 线性回归

  • Fbprophet

  • 组合(每个磁盘的两种算法)——最有效的比较方式。

集群模式

  • 单节点集群

  • 标准集群,4 个工作节点

磁盘数量

  • 100

  • 1000

结果以这种格式呈现在本博客的附录中,如果您想进一步查看。

方法

方法 1:顺序

方法 2:concurrent.futures

使用此模块有两个选项:并行内存密集型操作(使用 ThreadPoolExecutor)或 CPU 密集型操作(ProcessPoolExecutor)。关于这一点的描述性解释可以在以下博客中找到。由于我们将处理 CPU 密集型问题,ProcessPoolExecutor 适合我们要实现的目标。

方法 3:Pandas UDFs

现在我们将切换到 Spark 并利用分布式计算来提高效率。由于我们使用的是 Databricks,大部分 Spark 配置已经为我们完成,但我们对数据的通用处理仍有一些调整。

首先,将数据导入到 PySpark DataFrame:

我们将使用 Pandas 分组映射 UDF(PandasUDFType.GROUPED_MAP),因为我们想传入一个 DataFrame 并返回一个 DataFrame。自 Apache Spark 3.0 起,我们不再需要显式声明此装饰器!

由于 PySpark 中的 DataFrame 结构,我们需要将 fbprophet、回归和 RMSE 函数拆分为 Pandas UDFs,但不需要对代码进行大规模重构即可实现。

我们可以使用applyInPandas来生成我们的结果。

注意:上述示例仅演示了使用线性回归的过程以提高可读性。有关完整演示,请参阅完整的笔记本

结果解读

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图片:每种建议方法的预测算法执行时间

我们为不同的方法和不同的环境设置创建了图表,然后按算法和磁盘数量对数据进行分组,以便于比较。

请注意,表格结果见本帖的附录。

我已经总结了这些发现的亮点,如下所示:

  • 正如预期的那样,预测 1,000 个磁盘与 100 个磁盘相比是(通常)一个更耗时的过程。

  • 顺序方法通常是最慢的,无法有效利用底层资源。

  • 对于较小、更简单的任务来说,Pandas UDFs 效率很低。数据转换的开销更高 — 并行化有助于弥补这一点。

  • 顺序和concurrent.futures方法都没有意识到 Databricks 提供的聚类 —— 无法利用额外的计算资源。

结语

当然,背景在确定最成功的方法时起着重要作用,但考虑到 Databricks 和 Spark 通常用于大数据问题,我们可以看到在处理这里今天看到的那些更大更复杂的数据集时使用 Pandas UDFs 的好处。

当处理较小数据集时,使用 Spark 环境可以通过concurrent.futures模块在较小(且成本更低!)计算配置上以极高的效率完成,因此在设计解决方案时请记住这一点。

如果您熟悉 Python 和 Pandas,那么从在初学者教程中看到的顺序 for 循环方法转变,对于任何一种方法都不应该是一种费力的学习曲线。

我们在本文中没有进行深入调查,因为我发现当前版本存在差异和不兼容性,但最近的pyspark.pandas模块在未来肯定会更常见,并且是一个值得关注的方法。该 API(以及由 Databricks 开发的 Koalas,现已退役)利用了 Pandas 的熟悉性以及 Spark 的底层优势。

为了演示我们试图实现的效果,我们只查看了每个磁盘生成的 RMSE 值,而没有实际预测未来的时间序列值。我们在这里设置的框架可以以同样的方式应用于此,具有确定算法使用情况的逻辑,以及在可能情况下使用确定的算法预测未来值,其中包括评估指标(以及其他逻辑,例如磁盘的物理限制)是否适合每种情况。

总是可以在我的GitHub找到这个笔记本。

附录

最初发布于 https://blog.coeo.com,为了本次转载做了调整*。*

LLMs 的参数高效微调 (PEFT):全面介绍

原文:towardsdatascience.com/parameter-efficient-fine-tuning-peft-for-llms-a-comprehensive-introduction-e52d03117f95

对 Hugging Face、Google 的 Vertex AI 和最终 OpenAI 使用的 PEFT 方法进行概念性调查

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 肖恩·史密斯

·发表在 Towards Data Science ·19 分钟阅读·2023 年 8 月 22 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由 DALL-E 创建。《大碗岛上的一个星期天下午》,但每个人都是类人。

大型语言模型(LLMs)的名字就意味着它们很大。这些模型通常具有从 70 亿到 700 亿个参数。以全精度加载一个 700 亿参数的模型需要 280 GB 的 GPU 内存!要训练这个模型,你需要在数百万或数十亿个文档上更新数十亿个标记。更新这些参数所需的计算量很大。这些模型的自监督训练成本很高,公司最高花费可达 1 亿美元

对于我们来说,适应我们的数据以适配这些模型引起了显著的兴趣。面对有限的数据集(相比之下)和缺乏计算能力,我们如何在成本的一小部分下创建可以超越主要竞争者的模型?

这就是参数高效微调(PEFT)研究领域的作用所在。通过各种技术,我们可以增强这些模型的小部分,使它们更适合我们要完成的任务。

阅读本文后,你将概念性地掌握 Hugging Face 中应用的每种 PEFT 技术,并能够区分它们之间的差异。在这篇文章之前,我发现的一个最有帮助的概述来自 Reddit 评论。此外,lightning.ai(pytorch lightning 的创作者)还有另一篇 杰出文章。另外,还有一篇基于大量此文的全面调查,由 Liali 等人撰写 [2]。在我的文章中,我旨在解决在审阅这些材料时识别出的空白点。在撰写本文时,本文作为 Hugging Face 库中所有 PEFT 方法的概念指南。读者的目标是以对该领域的基本理解来接触其他 PEFT 技术的研究文献。

自我反思的时刻:是时候进行微调了吗?

我写了一篇关于 微调 LLMs 的考虑因素 的文章,以及通过 In-Context Learning 实现类似性能的方法。自那时起,Llama 2 已经发布,开源 LLM 世界有了很大进步。这里是一些我可以分享的扩展思考,超出了那篇文章的范围。

微调本质上对你的组织是危险的。最近的一篇论文显示,LLMs 至少可以记住 1% 的训练数据 [1]。如果你有潜在的数据重复,这个 1% 的比例会更高。如果你的微调 LLMs 将被非内部用户使用,问问自己是否可以将你要训练的数据提供给他们。用户可能对你的模型进行恶意操作,比如 提示注入攻击。我在 LinkedIn 上发布了关于这些安全风险的帖子,作为一个快速概述。如果你不能提供你的数据,使用 ICL 的动态观察选择是你的最佳选择之一(详情请参见我的另一篇 文章)。

你还必须优先创建高质量的数据标签用于你的学习任务。如果组织对高质量数据的承诺不足,特别是对于你项目的微调支持,我建议考虑其他方法。模型依赖于高质量的标注输入。如果你的利益相关者对人工标注者的承诺不足,你可能会让所有相关方感到失望。

谁在使用 PEFT?

PEFT 被大多数提供语言模型微调能力的供应商使用。如果供应商尚未使用这些技术,我保证他们有计划使用这些技术。本文涵盖了在撰写时可用的Hugging Face PEFT中的所有技术。Lialin 等人的调查[2]在 Google 关于在 Vertex AI 上调整基础模型的介绍视频中被引用。虽然 Vertex AI 更像是一个黑箱,但我听说过销售推介中提到的适配器、提示微调,以及最近的 LoRa。虽然不清楚他们具体使用了什么,但我们在这里讨论的技术就是支撑这些功能的核心。

OpenAI 确实提供了微调服务,但至今尚未实现任何 PEFT 方法。这一点基于 OpenAI 几个月前要求删除的一篇博客文章。该文章详细说明了 OpenAI 不使用 Adapters 或 LoRa 来使微调更加计算友好。OpenAI 尚未宣布这些功能的实施,因此可以安全地假设这些功能目前还不可用。它已被纳入 OpenAI 的路线图中,并且由于微调比正常模型使用更具盈利性,我怀疑它将在不久的将来推出。

快速 Transformer 复习

我假设这篇文章的读者对 Transformer 架构已经有所了解。你不需要对自注意力或任何组件的细节有深入的了解,但你应该至少浏览过 Vaswani 等人的论文,并且可能对注释版 Transformer有过一些了解(在我看来,这是学习 Transformer 的最佳资源)。

我将包括 Transformer 块的伪代码。如果你对 Transformer 了解不多,只需知道它的核心功能如下:

def self_attention(x):
    k = x @ W_k
    q = x @ W_q
    v = x @ W_v
    return softmax(q @ k.T) @ v

def transformer_block(x):
    """ Pseudo code by author based on [2] """
    residual = x
    x = self_attention(x)
    x = layer_norm(x + residual)
    residual = x
    x = FFN(x)
    x = layer_norm(x + residual)
    return x

那些伪代码中的所有功能都如 Vaswani 等人所述。FFN 是前馈网络,对于我们的目的来说是 2 层。许多随后的 PEFT 技术对 Transformer 块或自注意力进行修改,因此在我们进行指南时,我会参考并修改这段伪代码。

PEFT 方法概览

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

来自 [2] 的方法和类别概述。

我们将通过查看上图中的更广泛类别来逐一介绍每种技术。我们将涵盖的类别包括加性方法、适配器、软提示、重新参数化,以及一种混合方法,它是重新参数化和选择性(而不是稀疏 LoRa)的组合。

加性方法

加性方法可能是最容易理解的。加性方法的目标是添加一组额外的参数或网络层以增强模型。在微调数据时,您只更新这些新添加参数的权重。这使得训练计算上更容易,并且适应较小的数据集(起始时大约 100-500,最高接近 100,000)。

方法:适配器

适配器既是一种方法也是一种类别。这项技术在 Houlsby 等人 [4] 中被介绍。适配器的目标是在 Transformer 子层之后添加小的全连接网络,并学习这些参数。我遵循 [2] 中的定义,并将适配器严格定义为仅向网络添加全连接层。

Houlsby 等人提出了一种对 transformer 块进行简单更新的方法。他们在两个地方添加了全连接层,如下所示的伪代码。

def transformer_block_adapter(x):
    """Pseudo code from [2] """
    residual = x
    x = self_attention(x)
    x = FFN(x)  # adapter
    x = layer_norm(x + residual)
    residual = x
    x = FFN(x)
    x = FFN(x)  # adapter
    x = layer_norm(x + residual)
    return x

方法:(IA)³

通过抑制和放大内部激活的注入适配器,或 (IA)³ 是一种非常有趣的加性方法(添加参数),它通过一些新参数扩展了 transformer 块。它由 Liu 等人 [5] 于 2022 年提出。尽管名字如此,但这不是适配器方法,因为它并没有严格地在 transformer 块的子层之后添加全连接层。

让我们考虑正常 transformer 中的缩放点积注意力:

来自 Vaswani 等人 [3] 的缩放点积注意力

由于我们正在使用加性方法,我们寻求向此网络添加参数。我们希望维度尽可能小。(IA)³ 提出了以下新向量以添加到注意力机制中:

在 (IA)³ 中的缩放点积注意力来自 [5]。这里我们在普通方程中添加了两个列向量 l_k 和 l_v,它们分别与键和值项进行 Hadamard 乘积。

我们只是添加了列向量 l_k 和 l_v,并进行 Hadamard 乘积,即列向量与矩阵之间的乘法(将列向量与矩阵的所有列相乘)。

我们还引入了另一个可学习的列向量 l_{ff},它被添加到前馈层中,如下所示:

(IA)³ 中的前馈更新改编自 [5]。我们可以看到,我们将列向量 l_{ff} 添加到网络中,并与 transformer 块 FFN 的第一层输出进行 Hadamard 乘积。函数 gamma 是激活函数 GELU

在这个示例中,gamma 是应用于权重和输入之间乘积的激活函数。这是 (IA)³ 的一些伪代码:

def self_attention_ia3(x):
    k = x @ W_k
    q = x @ W_q
    v = x @ W_v

    k = l_k @ k  # ia3
    v = l_v @ v  # ia3

    return softmax(q @ k.T) @ v

def transformer_block_ia3(x):
    """Pseudo code from [2]"""
    residual = x
    x = self_attention_ia3(x)
    x = layer_norm(x + residual)
    residual = x
    x = x @ W_1  # normal transformer
    x = l_ff * gelu(x)  # ia3
    x = x @ W_2
    x = layer_norm(x + residual)
    return x

软提示

要理解软提示,我们首先讨论硬提示,一个大多数读者可能都熟悉的概念,即使不是通过名字。在硬提示中,我们会将代表当前任务的数据集整理在一起。当有人通过提出问题与网络互动时,他们可能会用不同的方式表达问题。使用硬提示,这个过程涉及策划一个数据集,涵盖语言模型可以框定特定任务的各种方式。

软提示(Soft-prompting)是一种试图避免创建数据集的技术。在硬提示中,我们是在离散表示中创建数据(选择单词)。在软提示中,我们寻求对输入到模型中的文本进行连续表示。这意味着你需要为正在训练的示例提供一个静态提示。

根据技术的不同,信息添加到网络中的方法也各不相同。核心思想是基础模型并不优化文本本身,而是优化提示文本的连续表示(即某种类型的可学习张量)。这可以是某种形式的嵌入或对该嵌入应用的某种变换。这些技术将在我们继续深入探讨时详细说明。

方法:Prompt-Tuning

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图像来自 Lester 等人 [11] 的 prompt-tuning。这表明,在 prompt-tuning 中,我们将软提示和输入文本的表示(嵌入)连接到预训练模型中。这样做使我们能够通过一个可学习的张量来优化软提示的表示。

Prompt tuning 是 Lester 等人 [11] 提出的技术,属于软提示的范畴。使用软提示时,我们的目标是向基础模型添加更具体于当前任务的信息。通过 prompt tuning,我们通过创建一组提示令牌的参数并将其注入到网络的开头来实现这一点。

为了找到软提示的表示,我们为训练过程中使用的静态提示创建了一组单独的嵌入。我们将输出嵌入与序列嵌入进行连接。我们使用这些新信息传递到语言模型中。创建这种双重信息使我们能够学习软提示的参数化,而无需为同一任务创建多个提示。

def prompt_tuning(seq_tokens, prompt_tokens):
    """ Pseudo code from [2]. """
    x = seq_embedding(seq_tokens)
    soft_prompt = prompt_embedding(prompt_tokens)
    model_input = concat([soft_prompt, x], dim=seq)
    return model(model_input)

通过这种方法进行微调有许多丰富的好处。这组新参数可以非常小,大约是基础模型可调参数的 0.01%。这创造了一个使用相同基础模型的任务特定模型集合的机会,这大大减少了模型的内存需求。有关更多信息,请查看我在LinkedIn上分享的帖子以及[3]中的集成部分。

方法:前缀调整

前缀调整是另一种与提示调整非常相似的软提示技术。在提示调整中,我们创建了一组单独的参数,将输入传递给这些参数,并将输出附加到模型中输入文本的连续表示上。在前缀调整中,我们还从一组单独的提示标记中找到一个连续的表示,这些标记被输入到基础模型中。

前缀调整和提示调整的区别在于,前缀调整的表示会传递到变换器的所有层,而提示调整仅与嵌入层级联。此外,我们还为前缀调整学习了额外的参数,形式为一个全连接网络。训练后,FFN 被丢弃,我们只使用软提示作为输入。

def transformer_block_prefix_tuning(x, soft_prompt):
    """ Pseudo code from [2] """
    soft_prompt = FFN(soft_prompt)
    model_input = concat([soft_prompt, x], dim=seq)
    return model(model_input)

方法:P-Tuning

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

刘等人[6]提供的 P-Tuning 图像。该图展示了提示嵌入的创建过程,整个提示编码与输入嵌入进行级联。提示编码器包括一个嵌入层、LSTM,然后是一些全连接层。

P-Tuning 是刘等人[6]提出的另一种软提示方法,与提示调整(prompt tuning)和前缀调整(prefix tuning)不同。通俗地说,我们可以将 P-Tuning 视为提示调整,但使用 LSTM 对提示进行编码。

P-Tuning 旨在解决作者注意到的两个问题。第一个是传递给模型的词嵌入的离散性。作者认为,如果嵌入是随机初始化的,然后通过随机梯度下降优化,模型可能会陷入局部最小值。第二个是词嵌入的关联性。在提示调整和前缀调整的参数化中,软提示在技术上是相互独立的。作者希望找到一种方法,使提示标记相互依赖。

作者们提出,提示是一个函数,它接受一个上下文 x 和一个目标 y,并将其组织成一个模板 T。作者提供了示例序列“The capital of Britain is [MASK]”。这里的提示是“The capital of … is …”,上下文是“Britain”,目标是[MASK]。我们可以使用这种表述创建两个标记序列:上下文之前的所有内容以及上下文之后和目标之前的所有内容。我们可以学习这些附加信息的表示,并将其缩减为连续输出,输入到语言模型中。

为了以这种方式嵌入提示,我们使用一个由 LSTM 组成的小型网络,输入到一个两层的 FFN 中。我们传递提示标记,包括上下文之前的和目标之前的标记。

def p_tuning(seq_tokens, prompt_tokens):
    """Pseudo code for p-tuning created by Author."""
    h = prompt_embedding(prompt_tokens)
    h = LSMT(h, bidirectional=True)
    h = FFN(h)

    x = seq_embedding(seq_tokens)
    model_input = concat([h, x], dim=seq)

    return model(model_input)

方法:LLaMA-Adapter

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

从 Zhang 等人[7]处获得的 LLaMA-Adapter 图像。我们可以看到零初始化的注意力被用于适配提示,并且这些是唯一被微调的内容。

LLaMA adapter 是 Zhang 等人[7]引入的一种软提示技术,它将更高效的前缀学习版本应用于 Llama 模型。

LLaMA-Adapter 与 Prefix Tuning 有一些关键差异。它们引入了适配提示,这些是附加到变换器层输入的软提示。这些适配提示被插入到 N 个变换器层的最上层 L 处。

作者们还引入了零初始化注意力。通过加法方法,我们引入了一组在权重上有一些随机初始化的新参数。由于这种随机噪声的加入,可能会经历不稳定的微调,这可能在早期阶段导致较大的损失值。为了解决这个问题,作者们引入了一个初始化为 0 的门控因子,该因子与自注意力机制相乘。门控因子和自注意力的乘积称为零初始化注意力。门控值在训练步骤中自适应调整,以创建更平滑的网络参数更新。

 def transformer_block_llama_adapter(x, soft_prompt, gating_factor):
    """LLaMA-Adapter pseudo code created by Author"""
    residual = x

    adaption_prompt = concat([soft_prompt, x], dim=seq)
    adaption_prompt = self_attention(adaption_prompt) * gating_factor  # zero-init attention

    x = self_attention(x)
    x = adaption_prompt * x
    x = layer_norm(x + residual)
    residual = x
    x = FFN(x)
    x = layer_norm(x + residual)

    return x

基于重参数化的方法

基于重参数化的方法专注于找到与基础模型中相同的权重矩阵的低维表示。Hu 等人[8]首次展示了微调与低维表示之间的联系。作者在模型的完整参数和较低维度表示之间建立了联系。根据任务,作者能够用大约 0.0002%的可训练参数实现完全微调模型的 90%的结果。

方法:LoRa

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

从 Hu & Shen 等人[9]处获取的图像。我们可以看到预训练权重以及额外的矩阵 A 和 B。A 被正常初始化,而 B 被初始化为 0。我们仅训练 A 和 B。

微调中最流行的技术之一是基于重参数化的方法,称为低秩适应(LoRa)[9]。LoRa 通过学习一个单独的矩阵来更新权重矩阵,该矩阵表示来自优化的更新。它们进一步创建两个较小维度的权重矩阵来表示这种差异。通过创建较小维度的权重矩阵,我们需要学习的参数更少。

为了训练 LoRa,我们使用梯度下降的基本思想,在这一过程中,我们对一组参数进行渐进调整,使我们更接近目标(损失函数)。在 LoRa 中,我们选择将所有更新隔离到一个单独的矩阵中。这个矩阵,我们称之为 Delta W,表示我们在微调过程中学到的所有参数更新。

让我们将 W_0 的维度设为 dxk(d 行和 k 列)。我们希望更新它的参数,以使其与我们的新目标对齐。你可以通过ΔW 来表示这个参数的更新,它的维度也是 dxk。我们可以使用下面的方程来建模我们的更新规则。

来自[9]的 W_0 的更新规则。我们将 W_0 的变化隔离到 DeltaW 中。然后,我们将 DeltaW 表示为 A 和 B 两个较小维度矩阵的乘积。这样,我们学习的参数更少,但仍然更新 W,这使得微调的计算更加简单。

现在让我们改变更新规则,使得ΔW 由矩阵乘法 AB 建模。我们将矩阵 A 指定为 dxr 的维度,矩阵 B 指定为 rxk 的维度。如果你对矩阵乘法有所了解,你会发现 AB 与 W_0 的维度相同,因此这些矩阵的加法是有效的。这里是 AB 比 DeltaW 更好的原因:矩阵 A 仅有 dxr 的维度,而矩阵 B 有 rxk 的维度。如果我们将 r 设为一个非常小的数值(r=8 是一个典型值),那么 A 和 B 中的参数数量会比ΔW 小得多。如果我们只学习 A 和 B 的参数,我们将学习d*k-d*r-r*k更少的参数。实际上,这使我们只需学习原始网络参数的 0.1–0.5%。

我刚刚描述的过程就是 LoRa 工作原理的精髓。我们不是通过额外的训练步骤来优化矩阵 W,而是通过两个参数远少的新矩阵 A 和 B 来改变矩阵ΔW。这种结果帮助我们优化了更少的参数,使得训练更为高效。

通常我们将这个更新规则应用于变换器块中的自注意力的关键和数值矩阵。我们还添加了一个缩放因子,设为 1/r,以调整更新所提供的信息量。请参见下面的伪代码。

def lora_linear(x, W):
    scale = 1 / r  # r is rank
    h = x @ W
    h += x @ W_a @ W_b  # W_a,W_b determined based on W
    return scale * h

def self_attention_lora(x):
    """ Pseudo code from Lialin et al. [2]."""

    k = lora_linear(x, W_k)
    q = x @ W_q
    v = lora_linear(x, W_v)
    return softmax(q @ k.T) @ v

选择性方法

使用选择性方法,我们选择一些参数进行更新,而不更新其他参数。这些方法的问题在于我们创建了一个稀疏的参数矩阵。稀疏矩阵操作在现代 GPU 上不太被支持,并带来计算挑战。有关稀疏矩阵为何会产生计算挑战的更多信息,请查看[10]。

在选择性方法中,还有一些技术专注于修剪不成功的向量或操控模型偏差。这些方法在训练模型时也会增加额外的复杂性。一般来说,这些方法的实现更具挑战性,因为它们的计算操作比其他操作更为昂贵。

方法:AdaLoRa

这是一种混合方法,结合了重新参数化和选择性方法的思想。Zhang 等人 [12] 通过研究 LoRa 开发了 AdaLoRa,并提出了这样一个问题:“我们如何根据模块的重要性自适应地分配参数预算,以提高参数高效微调的性能?”这意味着“我们如何优先考虑那些能带来更好性能的参数,而不是平等对待所有参数?”

与我们在 LoRa 中使用的两个矩阵 A 和 B 不同,AdaLoRa 使用奇异值分解 (SVD)的近似来将向量空间的维度降低到三个矩阵:P(左奇异向量)、Lambda(奇异值)和 Q(右奇异向量)。使用这三个矩阵,我们可以重建向量空间 Delta 的近似值,即 P * Lambda * Q。使用 SVD 的好处在于奇异值表示了在这个低维空间中向量的重要性。本文的贡献在于应用一些高效实现,以使用与 SVD 相关的方法来考虑哪些权重应当被优化。

在 LoRa 中,我们看到可以用两个矩阵 A 和 B 来近似 delta W。在这里,我们可以用新的近似值 P * Lambda * Q 来替代 A 和 B。由于 Lambda 仅在对角线(奇异值)上有值,我们将其存储为列向量。我们选择矩阵 P(d x r)、Lambda(r x r)和 Q(r x k)的维度,以匹配权重矩阵 W(d x k)的维度。

另一个新颖的结果是使用一种特殊的重标定技术来确定可以被修剪掉的 SVD 元素。基本上,该技术考虑一组三元组(SVD 的每个条目),并确定它们对低维表示的重要性。他们通过使用一个将奇异值与左/右奇异向量相关联的函数来实现这一点。这些函数随后经过一个灵敏度函数,该函数结合了梯度权重乘积的指数移动平均(伪重要性)以及另一个称为不确定性量化的函数,后者也在当前步骤和前一步骤之间进行指数平均。

在修剪 SVD 元素时,通过删除最不重要的三元组,低维度的秩(矩阵的 r 项)会被迭代地改变。他们通过一个全局预算调度器来实现这一点,该调度器在训练步骤中逐渐降低秩 r。预算初始化为目标预算的 1.5 倍,并在经过 t 次预热步骤后按立方方式减少到目标预算。

从概念上讲,这是一个难以理解的方法。如果您有技术背景,我鼓励您阅读论文以了解该方法的内部工作原理。如果您记住这是一个高效的 SVD 实现应用于 LoRa,并结合了修剪不重要的奇异向量,那么在概念层面上,这可能是安全的。

 def adalora_linear(x, W, curr_sv):
    scale = alpha / r  # r is rank
    h = x @ W

    # p, lamda, and q are related to the W matrix
    # curr_sv marks which singular vectors we are currently optimizing. 
    h += x @ p[curr_sv] @ lamda[curr_sv] @ q[curr_sv]
    return scale * h

def self_attention_lora(x):
    """
    AdaLoRa pseudo code created by author. 
    This only shows the difference in the self_attention block. 
    Does not include code for pruning techniques.
    """
    k = adalora_linear(x, W_k)
    q = x @ W_q
    v = adalora_linear(x, W_v)

    return softmax(q @ k.T) @ v

方法比较

为了在一个地方比较所有的方法,我创建了下面的表格来展示它们的可训练参数数量(这些都是网络的附加参数)、方法类型以及对方法的非正式总结。非正式总结是我如何用一句话向一个从未听说过该方法的大学生描述该方法。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

表格改编自 Lialin 等人[2]。作者贡献了 P-Tuning、LLaMA-Adapter 和 AdaLoRa,并在非正式总结栏中进行总结。非正式总结是我如何用一句话向大学生描述这篇论文的方式。

这是您需要的唯一概念性指南吗?

我认为这是您需要的唯一概念性指南,因为阅读后您将理解 PEFT 技术的基础知识。如果您注意到,所有技术都扩展了其他技术的思想。在这次介绍之后,您已经理解了足够的基础知识,您可以自己探索研究论文。然而,如果您最终需要另一个概念性指南来理解这些概念,请在文章的评论中分享,以便其他读者可以找到这些资源!

现在是开始的时候了!

在这次概念性回顾之后,您处于一个很好的起点,可以开始实验这些方法来训练自己的模型。来自Hugging Face有很多很棒的实现指南。如果您希望采取更少的动手方法,您可以使用 Google 的 Vertex AI 模型或使用 OpenAI 的微调服务。

感谢您阅读这篇文章!如果您有其他问题或有不清楚的地方,请留言,我会回复您。如果您想看到更多类似的文章,请在 Medium LinkedIn上关注我。

如果您发现本文中的技术错误,请尽快告知我!我努力确保我发布的信息尽可能准确,但没有人是完美的。

参考文献:

[1] Nicholas Carlini, Daphne Ippolito, Matthew Jagielski, Katherine Lee, Florian Tramer, & Chiyuan Zhang. (2023). 量化神经语言模型的记忆。

[2] Vladislav Lialin, Vijeta Deshpande, & Anna Rumshisky. (2023). 缩小规模以扩大规模:参数高效微调指南。

[3] Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, & Illia Polosukhin (2017). 关注即是你所需要的*。CoRR, abs/1706.03762*。

[4] Neil Houlsby, Andrei Giurgiu, Stanislaw Jastrzebski, Bruna Morrone, Quentin de Laroussilhe, Andrea Gesmundo, Mona Attariyan, & Sylvain Gelly (2019). NLP 的参数高效迁移学习*。CoRR, abs/1902.00751*。

[5] Haokun Liu, Derek Tam, Mohammed Muqeeth, Jay Mohta, Tenghao Huang, Mohit Bansal, & Colin Raffel. (2022). 少样本参数高效微调优于上下文学习且更经济。

[6] Xiao Liu, Yanan Zheng, Zhengxiao Du, Ming Ding, Yujie Qian, Zhilin Yang, & Jie Tang (2021). GPT 也能理解*。CoRR, abs/2103.10385*。

[7] Renrui Zhang, Jiaming Han, Chris Liu, Peng Gao, Aojun Zhou, Xiangfei Hu, Shilin Yan, Pan Lu, Hongsheng Li, & Yu Qiao. (2023). LLaMA-Adapter:零初始化注意力的语言模型高效微调。

[8] Armen Aghajanyan, Luke Zettlemoyer, & Sonal Gupta (2020). 内在维度解释了语言模型微调的有效性*。CoRR, abs/2012.13255*。

[9] Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, & Weizhu Chen (2021). LoRA:大语言模型的低秩适应*。CoRR, abs/2106.09685*。

[10] Trevor Gale, Matei Zaharia, Cliff Young, & Erich Elsen. (2020). 深度学习的稀疏 GPU 内核。

[11] Brian Lester, Rami Al-Rfou, & Noah Constant. (2021). 参数高效提示调优的规模效应。

[12] Qingru Zhang, Minshuo Chen, Alexander Bukharin, Pengcheng He, Yu Cheng, Weizhu Chen, & Tuo Zhao. (2023). 参数高效微调的自适应预算分配。

帕累托、幂律与胖尾

原文:towardsdatascience.com/pareto-power-laws-and-fat-tails-0355a187ee6a

他们在统计学中不会教你的东西

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Shaw Talebi

·发表于 Towards Data Science ·12 分钟阅读·2023 年 11 月 11 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

黑天鹅。图片来自 Canva。

统计学是数据科学和分析的基础。它为我们提供了一个强大的工具箱,以客观地回答复杂的问题。然而,当应用于特定类型的数据——幂律时,许多我们喜爱的统计工具变得无用。

在这篇文章中,我将提供一个对初学者友好的 Power Laws 指南,并描述使用传统统计方法分析这些规律时遇到的 3 个主要问题。

目录

  1. 背景高斯分布、帕累托的 80–20 规则、幂律,以及体重与财富之间的差异。

  2. 统计学 101 的 3 个问题你需要(很多)更多的数据。

  3. 胖尾避免争议并量化高斯分布与帕累托分布之间的差距。

称量你的咖啡师

自然界中的许多量往往围绕典型值聚集。例如,如果你坐在一个(繁忙的)咖啡店里,测量所有进进出出的咖啡师和顾客的体重,你会(最终)观察到如下图所示的模式。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

示例高斯分布。 技术说明:在测量成年人的体重时,每个性别会出现类似高斯分布的模式。图片由作者提供。

这个图是高斯分布的一个例子,你可能在统计学 101 或商业统计中遇到过。高斯分布的美在于,我们可以用一个数字——均值——捕捉到大部分核心信息(例如咖啡师的体重)。

更进一步,我们可以通过标准差方差等测量方式来获取数据的分散程度

这些在入门统计学中每个人都会学到的概念给了我们一个强大的数据分析方法。然而,并非所有我们关心的量都具有围绕典型值聚集的这种定性特征。

帕累托原理(80–20 法则)

你可能听说过所谓的“80–20 法则”,其标语是“80%的销售额来自 20%的客户”。然而,这一概念并非源于销售和市场营销,而是起源于维尔弗雷多·帕累托对意大利土地所有权的研究(约 1890 年)[1]。

帕累托观察到意大利约 80%的土地由 20%的人口拥有。这一简单观察表明,统计特性与我们所熟悉和喜爱的高斯分布有很大不同

也就是说,“80–20 法则”是帕累托分布的结果。下图对此进行了说明。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

帕累托分布,其中 20%的人口占据 80%的体量。图像由作者提供。

高斯分布与帕累托分布之间的关键区别在于帕累托分布没有“典型值”,我们无法用来有效总结分布。

换句话说,虽然知道一个意大利男人的平均体重(约 175 磅)能让你对下一次罗马之行有个大致的预期,但知道一个意大利城市的平均人口(约 7500)却是无用的

功率法则分布

帕累托分布是功率法则这一更广泛分布类别的一部分。我们可以如下定义功率法则[2]。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

功率法则分布类别的定义[3]。图像由作者提供。

其中PDF()表示随机变量X的概率密度函数。xX的特定值。L(x)是一个在[x_min, ∞]范围内缓慢变化的正函数。x_min是功率法则有效的最小值(即PDF(x) = 0 当 x<x_min*)[2]。α是一个数字(通常在 2 到 3 之间)。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

各种α值的功率法则分布示例。注意:α = 1.16 大致意味着 80–20 法则。图像由作者提供。

正如上面的图示所示,功率法则与高斯分布在定性上非常不同。这形成了一种高斯型与帕累托型分布之间的二分法。换句话说,高斯和功率法则分布为定性地分类现实世界中的事物提供了概念性基准。

中庸斯坦与极端斯坦

作者Nassim Nicholas Taleb通过他称之为“中庸斯坦”与“极端斯坦”的两个类别描述了高斯型和帕累托型事物之间的二分法。

Mediocristan是类似高斯分布的事物的土地。其公民的一个基本特性是没有单个观察值会显著影响总体统计数据[3]。例如,假设你在罗马旅行时称重每个斗兽场的游客,并计算平均体重。如果你加入地球上最重的意大利人,这个平均值将几乎不变(+0.5%)。

在这种概念景观的另一端是Extremistan,在这里我们看到相反的统计特性。即,在 Extremistan 中,单个观察值可以(并且通常会)驱动总体统计数据。考虑到在斗兽场的相同游客,但不是测量他们的体重,而是询问他们的净资产并计算平均值。与之前不同的是,如果我们将世界上最富有的意大利人乔瓦尼·费雷罗(巧克力+榛子家族)加入样本,这个平均值将会剧烈变化(+2500%)。

为了更好地理解这些类别,请参考下面图像中列出的例子。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

来自 Mediocristan 和 Extremistan 的项目,分别[3]。图像由作者提供。

正如你所见,Extremistan 中的帕累托式居民并不是一个小的或微不足道的集合。事实上,我们关心的许多事物并非像我们在 STAT 101 中研究的高斯曲线。

虽然这看起来可能过于技术性和教学性,但使用我们熟悉的统计技术和直觉来分析来自 Extremistan 的数据存在重大局限性,甚至在某些情况下可能有显著的风险。

STAT 101 思维的 3 个问题

正如我们在罗马斗兽场看到的,来自 Mediocristan(例如体重)的数据具有与 Extremistan(例如财富)相反的特性。

使用 STAT 101 技术分析幂律(即来自 Extremistan 的数据)的最大问题之一是像均值、标准差、方差、相关性等数量几乎没有实际意义

这一切都源于一个核心问题——数据不足

在统计学中,我们学习到大数法则,它表明如果我们取 N 个随机样本,样本均值将趋近于N真实均值。这对任何分布(具有有限均值)都是正确的:高斯分布、幂律分布、均匀分布,等等。

然而,事实证明这种渐近行为在某些分布中比其他分布发生得更缓慢(例如,幂律分布比高斯分布更慢)。在实践中,由于我们(必然地)拥有有限的数据集,这可能会造成问题。在这里,我突出三个这样的问题。

问题 1:均值是无意义的(以及许多其他指标)

每当我们想比较两组值(例如四月与五月的销售额、洛杉矶与纽约的交通事故、对照组与治疗组的患者结果)时,我们通常会计算均值。这为我们提供了一种将多个值压缩成一个代表性数字的直观方法。

对于遵循良好高斯分布的数据,这种方法效果极佳,因为在小样本量(N=~10)中可以准确估计均值。然而,当处理遵循幂律分布的数据时,这种方法会失效

我们可以通过比较高斯分布和幂律分布的样本均值随样本量的增加,来观察这一点,下面的图表展示了 N=100、N=1,000 和 N=10,000 的情况。幂律和高斯样本均值分别用橙色和蓝色表示。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

三种不同样本量的样本均值收敛情况。图片由作者提供。

如我们所见,幂律样本均值比高斯分布更为不稳定(且有偏)。即使样本量增加到 N=100,000,幂律的准确性仍然远远低于高斯分布在 N=100 时的表现。这在下面的图表中有所体现。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

尽管均值在 N=1,000,000 时略有稳定,但与高斯分布相比仍然有显著偏差。图片由作者提供。

这种不稳定的行为不仅限于均值。它还适用于许多常用的统计量。下图展示了中位数、标准差、方差、最小值、最大值、1st 和 99th 百分位数、峰度和熵的类似收敛图。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

三种样本量的其他指标收敛图。自上而下:中位数、标准差、方差、最小值、最大值、1st 和 99th 百分位数、峰度和熵。图片由作者提供。

如我们所见,某些指标相较于其他指标更稳定。例如,中位数、最小值和百分位数相对稳定。而标准差、方差、最大值、峰度和熵似乎无法稳定在一个数字上。

在这一组中,我想特别指出最大值,因为这一量度可能在小样本中看似收敛, 但随着 N 的增大,它可能会跳跃一个数量级(如 N=10,000 图中所见)。这特别危险,因为它可能导致虚假的可预测性和安全感。

为了将这一点与现实世界联系起来,如果基础数据是例如,疫情死亡人数,那么过去 100 年的最大疫情将比过去 1,000 年的最大疫情小 10 倍。

例如,过去 100 年中最致命的疫情是西班牙流感(约 5000 万人死亡)[4],因此如果疫情死亡人数遵循幂律分布,我们可以预计在接下来的 1,000 年中,会有一个疫情造成 5 亿人丧生(抱歉,这个例子有些黑暗)。

这突出了 Extremistan 数据的一个关键属性,即罕见事件驱动了整体统计数据

然而,这并不仅仅是统计指标的问题。罕见事件的严重性也影响了我们有效进行预测的能力。

问题 2:回归分析无效

回归归结为基于过去数据进行预测。然而,正如我们在问题 1 中看到的,当处理幂律分布时,我们可能没有足够的数据来准确捕捉真实的统计数据。

当使用幂律分布的变量进行回归时,如果α <= 2,这一点会变得更加严重。这是因为α <= 2 意味着该分布具有无限方差,这破坏了流行回归方法(例如最小二乘回归)的一个关键假设。

然而,在实际工作中,计算出的方差永远不会是无限的(数据必然是有限的)。这引发了类似于问题 1 的问题:结果可能看起来稳定,但随着数据的增加不会持续

换句话说,你的 R²在开发模型时可能看起来很棒,但随着样本量的增加,R²很快会恶化并接近实际值 R²=0。

我们可以通过一个(人工)示例来观察这一点。假设我们有两个变量 X 和 Y,它们线性相关(即 Y = mX + b),其中 X 服从正态分布,并且有一个遵循幂律分布的加性噪声项。当我们在小样本量(N=100)中进行回归时,拟合效果看起来非常好。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

涉及具有加性噪声且遵循幂律分布的预测变量的线性回归拟合,样本量较小(N=100)。图片由作者提供。

然而,当我们收集更多数据(N=100,000,000)时,R²会正确地下降到实际值(即 R²=0)。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

随着样本量的增加,R²接近实际值(即 R²=0)。图片由作者提供。

问题 3: 回报概率偏离

在这一点上,你可能会想,“Shaw… 有什么大不了的?如果我的模型不能预测一些罕见事件,那又怎么样?它大多数时候是正确的。

我同意你的观点。当处理来自极端国家的数据时,大多数数据不在尾部,因此大多数时间容易是正确的。然而,在预测结果和做出决策时,概率只是部分故事。

故事的另一半是回报。换句话说,不仅仅是关于你有多频繁地正确*(错误),还包括当你正确(错误)时会发生什么*。

例如,如果你被提供一种每日多种维生素,它 99.9%的时间效果很好,但有 0.1%的时间可能致命,你可能会选择另一种品牌(或更健康的食物)。

单纯依赖概率来做决策在处理幂律和“80-20 规则”时尤其有害。考虑以下商业示例。

假设我们有一家软件公司,提供 3 种产品:1)有广告的免费版,2)高级版,3)企业版,其中每种产品的客户和收入分布见下表。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

每种产品的客户和收入分布。图片由作者提供。

公司希望推出一个更新,将处理时间提高 50%。作为一家前沿的数据驱动技术公司,他们调查了活跃用户,发现95%的客户更喜欢更新后的软件。掌握了这些数据,公司批准了软件更新。

然而,六周后,公司陷入了混乱,因为收入下降了 50%。

事实证明,更新后,3 位客户取消了服务,因为更新移除了对他们用例至关重要的遗留数据集成。但这些不是普通客户。这些是公司的前 3 大客户(约 1%),占公司收入的约 50%(考虑到他们的定制增销)。

这就是当只关注概率(95%的客户喜欢更新)时可能犯的(致命)错误。故事的寓意是,当处理来自极端领域的稀有事件数据时,错一次可能抵消对 99 次正确的判断(甚至更多)

生成图表的代码 👇

[## YouTube-Blog/power-laws 在主分支 · ShawhinT/YouTube-Blog

代码用于补充 YouTube 视频和 Medium 博客帖子。 - YouTube-Blog/power-laws 在主分支 · ShawhinT/YouTube-Blog

github.com](https://github.com/ShawhinT/YouTube-Blog/tree/main/power-laws?source=post_page-----0355a187ee6a--------------------------------)

极端主义中的争议

幂律,如高斯分布,是一种理想化的数学抽象。然而,现实世界却很混乱,几乎不可能完全符合我们美丽而精确的构造。这引发了一些争议,即某个特定的分布是否真正是幂律。

争论的一个点是财富是否是幂律(如帕累托的工作所示)或仅仅是一个对数正态分布 [5]。

部分争议可能通过观察到对数正态分布在低σ时表现如高斯分布,在高σ时表现如幂律来解释[2]。

然而,为了避免争议,我们可以离开(暂时某些给定数据是否符合幂律的问题,转而关注胖尾

胖尾性 — 测量中等主义和极端主义之间的距离

胖尾是一个比帕累托和幂律分布更一般的概念。我们可以把它看作“胖尾性”是稀有事件推动分布总体统计的程度。从这个角度看,胖尾性在从不胖尾(即高斯分布)到非常胖尾(即帕累托 80–20)的谱系上。

这与之前讨论的 Mediocristan 与 Extremistan 的概念直接对应。下面的图像展示了不同分布在这一概念景观中的可视化 [2]。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

Mediocristan 与 Extremistan 的地图。注意:由于肥尾现象存在于一个范围内,因此将分布标记为“肥尾”与否具有一定的主观性。图片作者。

尽管没有准确的肥尾度量,但我们可以在实践中使用许多指标和启发式方法来感知给定分布在 Mediocristan 与 Extremistan 地图上的位置。以下是一些方法。

  • 幂律性:使用幂律尾部指数,即α——α越低,尾部越肥 [2]

  • 非高斯性:峰度(对于α ≤ 4 的幂律失效)

  • 对数正态分布方差

  • 塔勒布的κ指标 [6]

要点

面对肥尾数据的主要挑战是可能没有足够的数据来准确捕捉其潜在的统计属性。这些信息为数据从业者提供了一些启示。

  • 绘制分布例如直方图、PDF 和 CDF

  • 问问自己——这些数据来自 Mediocristan 还是 Extremistan(或介于两者之间的某个地方)

  • 在构建模型时,问问自己——正确预测的价值和错误预测的成本是什么?

  • 如果处理(非常)肥尾数据,不要忽视稀有事件。相反,弄清楚如何利用它们(例如,你能否为前 1%的客户做一个特别促销以推动更多业务?)

👉 更多关于幂律和肥尾的信息:幂律拟合 | 量化肥尾

[## 用 Python 检测现实世界中的幂律]

用示例代码分解最大似然法的方法

towardsdatascience.com](/detecting-power-laws-in-real-world-data-with-python-b464190fade6?source=post_page-----0355a187ee6a--------------------------------)

资源

联系我的网站 | 预约电话 | 问我任何问题

社交媒体YouTube 🎥 | LinkedIn | Twitter

支持请我喝杯咖啡 ☕️

[## 免费获取我撰写的每个新故事的访问权限]

免费获取我撰写的每个新故事的访问权限。附言:我不会与任何人分享你的电子邮件。通过注册,你将创建一个…

shawhin.medium.com

[1] 帕累托原则。 (2023 年 10 月 30 日). 载于 维基百科en.wikipedia.org/wiki/Pareto_principle

[2] arXiv:2001.10488 [stat.OT]

[3] 塔勒布, N.N. (2007). 黑天鹅:极不可能事件的影响. 纽约;兰登书屋。

[4] www.archives.gov/exhibits/influenza-epidemic/

[5] arXiv:0706.1062 [physics.data-an]

[6] Taleb, N. N. (2019). 你需要多少数据?一种操作性的、前渐近的尾部厚度度量。国际预测学杂志35(2),677–686。 doi.org/10.1016/j.ijforecast.2018.10.003

Parquet 最佳实践:在不加载数据的情况下发现你的数据

原文:towardsdatascience.com/parquet-best-practices-discover-your-data-without-loading-them-f854c57a45b6

元数据、行组统计、分区发现和重新分区

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Arli

·发表于 Towards Data Science ·8 分钟阅读·2023 年 1 月 3 日

如果你想亲自体验 Medium,可以考虑通过 注册会员 来支持我和其他成千上万的写作者。这只需每月$5,它对我们写作者的支持巨大,而且你可以访问 Medium 上所有精彩的故事。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

Jakarta ParquetUnsplash 提供的照片

这篇文章是关于 Parquet 系列文章中的下一篇。如果你没有 Parquet 知识,应该先查看之前的 Parquet 文章 ,但它也是对更高级用户的很好的提醒。如果你想重现这篇文章的输入数据,代码在文末。

Apache Parquet 是一种用于大数据框架的列式存储格式,如Apache HadoopApache Spark。它旨在通过使用列式存储格式压缩高效的方式存储数据,从而提高大数据处理的性能。

Parquet 的采用持续增加,因为越来越多的组织转向大数据技术来处理和分析大型数据集。随着这种持续发展,学习一些最佳实践以及如何浏览Parquet文件变得至关重要。

在本教程中,我们将展示如何作为Parquet用户在不依赖常见的暴力加载方式的情况下,深入洞察你的Parquet数据。

案例研究

为此,我们提供了一个案例研究,其中一个数据工程师给你提供了贷款申请者的数据,你需要使用这些数据创建预测模型。但首先,你需要*“技术性地”*发现数据。数据量非常大。

确实,准备数据的数据工程师告诉你,Parquet 文件夹的大小为1TB仅用于教育目的,这在我们的示例中并非如此),因此如果你尝试加载所有内容,你的机器将会遇到内存错误。

不用担心,我们会提供最有效的方式来理解大型Parquet数据,甚至不需要将Parquet数据加载到内存中。

这意味着需要回答以下问题:

  • 这个文件夹中的Parquet文件是什么样的?

  • 变量里面有哪些?类型是什么?一些统计数据

  • 数据是如何分区的?

我们还会教你如何重新格式化分区,如果你发现数据分区的方式有问题的话。

阅读第一个 Parquet 文件

你在这个教程中需要的导入:

import pyarrow as pa
import pyarrow.parquet as pq
import os

首先,我们想了解文件夹*‘APPLICATIONS_PARTITIONED’*包含什么,这里存储了数据。

由于你不知道数据是如何分区的,因此不能盲目地加载整个文件夹,因为你将会加载所有的Parquet文件,这不是你想做的(记住 1TB 的大小),而是你需要对数据进行概览。

这里,我给你一个函数get_first_parquet_from_path(),它会返回目录中的第一个Parquet文件。该函数将扫描每个目录和子目录,直到找到一个Parquet文件,并返回该单个文件的完整路径。

def get_first_parquet_from_path(path):
    for (dir_path, _, files) in os.walk(path):
        for f in files:
            if f.endswith(".parquet"):
                first_pq_path = os.path.join(dir_path, f)
                return first_pq_path

看起来是个很酷的函数,让我们把它付诸实践。

path = 'APPLICATIONS_PARTITIONED'
first_pq = get_first_parquet_from_path(path)
first_pq
#Output : APPLICATIONS_PARTITIONED/NAME_INCOME_TYPE=Commercial associate/CODE_GENDER=F/6183f182ab0b47c49cf56a3e09a3a7b1-0.parquet

从路径中我们可以注意到,这里按NAME_INCOME_TYPECODE_GENDER进行分区,知道这一点很重要。

现在要读取这个路径以获取行数和列数,以及宝贵的Schema,你可以这样做:

first_ds = pq.read_table(first_pq)
first_ds.num_rows, first_ds.num_columns, first_ds.schema

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

执行时间不到 1 秒,原因是read_table()函数读取一个Parquet文件并返回一个PyArrow Table对象,该对象代表你的数据,作为一个由Apache Arrow开发的优化数据结构。

现在,我们知道有 637800 行和 17 列(+2 来自路径),并对变量及其类型有了概览。

等等,我之前告诉过你,我们不需要在内存中加载任何东西来发现数据。所以这里有一个方法,可以在不读取任何表的情况下做到这一点。

元数据

我在部分欺骗你,因为我们不会加载任何数据,而是会加载所谓的元数据

Parquet文件格式的上下文中,metadata指的是描述文件中存储的数据的结构特征的数据。这包括每列的数据类型、列的名称、表中的行数和模式等信息。

让我们使用pyarrow.parquet中的read_metadata()read_schema()函数:

ts=pq.read_metadata(first_pq)
ts.num_rows, ts.num_columns, pq.read_schema(first_pq)

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

这会给你与read_table()相同的输出。

然而,我们注意到执行时间上有很大差异,因为这里接近瞬时。这并不奇怪,因为读取metadata就像是读取Parquet文件中一个非常小的部分,它包含了你需要的所有信息来概述数据。

统计数据

现在假设我想多了解一下列,我该怎么办?

你可以从文件的第一个Row Group读取统计数据。

Parquet文件格式中,Row Group是将行作为一个单位存储在一起的集合,并分成更小的块以便于查询和处理。

parquet_file = pq.ParquetFile(first_pq)
ts=parquet_file.metadata.row_group(0)
for nm in range(ts.num_columns):
    print(ts.column(nm))

上面的代码会给你一个不太美观的输出,这里有一些代码可以将其格式化为一个漂亮的 DataFrame:

beautiful_df = pd.DataFrame()
for nm in range(ts.num_columns):
    path_in_schema = ts.column(nm).path_in_schema
    compressed_size = ts.column(nm).total_compressed_size
    stats = ts.column(nm).statistics
    min_value = stats.min
    max_value = stats.max
    physical_type = stats.physical_type
    beautiful_df[path_in_schema] = pd.DataFrame([physical_type, min_value, max_value, compressed_size])
df = beautiful_df.T
df.columns = ['DTYPE', 'Min', 'Max', 'Compressed_Size_(KO)']

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

在 DataFrame 中,你可以看到列的类型、最小值、最大值和压缩大小。从这个文件中得到的一些学习点:

  • 字符串列被转换为BYTE_ARRAY

  • 字符串列的最小值和最大值按字母顺序排序。

  • 布尔型的压缩大小不比BYTE_ARRAY好多少。

  • 最年轻的申请者 21 岁,最年长的是 68 岁。

要注意不要将统计数据泛化,这只是来自第一个parquet文件!

很好,现在我们对数据有了很好的理解,包括列的信息、类型、模式,甚至统计数据,但我们是否遗漏了什么?

分区

是的,我们不知道数据的分区!如前所述,我们可以从文件路径中至少猜测到分区列:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

数据按NAME_INCOME_TYPECODE_GENDER分区。但我们不知道其他分区值。假设我们想查看其他NAME_INCOME_TYPE

但我会提供你一段代码,这样你可以以更系统的方式获取分区,以及所有可能的分区值:

def get_all_partitions(path):
    partitions = {}
    i = 0
    for (_, partitions_layer, _) in os.walk(path):
        if len(partitions_layer)>0:
            key = partitions_layer[0].split('=')[0]
            partitions[key] = sorted([partitions_layer[i].split('=')[1] for i in range(len(partitions_layer))])
        else:
            break
    return partitions

让我们运行这个函数,它返回一个字典,其中键对应于分区列,值是与每个分区列关联的分区值

ps = get_all_partitions(path)
ps.keys(), ps.values()

我们现在知道数据工程师首先按Income_Type分区,然后按Gender分区。所有分区列的值如下:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

既然我们已经了解了分区列和分区值的知识,我们可以读取另一个感兴趣的分区。

假设我们想读取所有*‘Pensioner’的数据,无论Gender*是什么。

从上一个教程中,我们知道我们可以通过读取 Parquet 文件夹*‘APPLICATIONS_PARTITIONED/NAME_INCOME_TYPE=Pensioner’*来做到这一点

df_pensioner = pd.read_parquet('APPLICATIONS_PARTITIONED/NAME_INCOME_TYPE=Pensioner/')

重新格式化分区

实际上,我们不打算按性别拆分数据,而且数据的大小允许我们在没有过多运行时间的情况下读取两个性别的数据。

不要过度分区数据,因为通常,执行时间会随着文件夹中的分区数量增加而增加。因此,你必须记住,分区即使使数据在功能上更易读,也可能有潜在的缺点。(来自官方文档 512MB — 1GB 是分区的最佳大小)。

在这里,假设在检查数据后,我们认为性别的子文件夹足够小,并且发现性别的功能划分没有用处。我们决定将数据集重新格式化,仅按NAME_INCOME_TYPE进行分区:

pq_table = pq.read_table('APPLICATIONS_PARTITIONED')
pq.write_to_dataset(pq_table, 'APPLICATIONS_REPARTITIONED', partition_cols=['NAME_INCOME_TYPE'])

我们刚刚在PyArrow Table 对象中读取了数据,然后我们写了一个Parquet 文件,仅按NAME_INCOME_TYPE分区,不再按性别分区。如果我们现在运行get_all_partitions()函数,值为:

partitions = get_all_partitions('APPLICATIONS_REPARTITIONED')
partitions.keys(), partitions.values()

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

我们注意到,我们不再按性别进行分区。

总之,你刚刚了解了如何浏览 Parquet 文件,了解有关数据的一切:例如列名、大小、模式、统计信息以及如何获取分区名称和值。你还发现了如何重新格式化分区以使其在技术上和功能上更为正确。

感谢阅读,下次故事见!

生成我们所用输入数据的完整代码:

继续阅读我其他的 Parquet 文章:

## 简单的 Parquet 教程和最佳实践

实用教程,开始你的 Parquet 学习

towardsdatascience.com [## Parquet 最佳实践:筛选的艺术

理解如何筛选 Parquet 文件

pub.towardsai.net](https://pub.towardsai.net/parquet-best-practices-the-art-of-filtering-d729357e441d?source=post_page-----f854c57a45b6--------------------------------)

通过我的推荐链接,你可以无额外费用地订阅 Medium。

[## 使用我的推荐链接加入 Medium — Arli

阅读 Arli 和成千上万其他 Medium 作者的每一个故事。你的会员费用直接支持 Arli 和…

medium.com

使用 Python 解析 HL7

原文:towardsdatascience.com/parsing-hl7-with-python-961e19c4d962

使用 python-hl7 提取健康等级 7 数据的指南

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Tara (Boyle) Hunter

·发布于 Towards Data Science ·阅读时间 4 分钟·2023 年 3 月 6 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

照片由 Christina Victoria Craft 提供,来源于 Unsplash

解析 HL7 消息对于许多医疗数据专业人员来说是一个重要任务。HL7 在医院和其他医疗机构中常用于在不同系统、应用程序和提供者之间交换病人数据。

根据 维基百科的定义,HL7 或健康等级 7 “指一套用于在不同医疗提供者使用的软件应用程序之间交换电子健康信息的国际标准”。这些标准使医疗提供者能够以一致且易于解释的方式共享临床信息。

在这里,我们将学习如何使用 Python 解析 HL7 消息。

介绍

在开始解析之前,我们需要对 HL7 消息结构有一个基本的了解。虽然 HL7 有不同的版本,这里我们将重点关注最常用的 2.x 系列版本。

HL7 消息通常是在病人事件(如接收、出院或转院)发生时创建和发送的。

HL7 消息由多个部分组成。每一行称为一个段落。段落由三字符标签标识,并由回车符(\r)分隔。段落包含由管道符(|)分隔的字段。字段包含由脱字符(^)分隔的组件。最后,组件可以包含由 & 符号分隔的子组件。

常见的 HL7 段落

  • MSH — 消息头:消息的第一个段落,包含有关发送和接收系统、消息类型和消息控制 ID 的信息。

  • PID — 病人识别:该段包含病人识别信息,包括病人 ID、姓名和出生日期。

  • DG1 — 诊断:该段包含与病人诊断相关的信息。

  • Z 段:这些段不是原始 HL7 标准的一部分。它们用于创建自定义消息并增加 HL7 的灵活性。自定义段的格式和内容由发送和接收应用程序协商。

关于这些和其他段类型的更多信息可以在这里找到。

示例消息

在这里我们将处理这个虚构的入院消息:

## made up sample admit message
MSH|^~\&|HIS|HOSPITAL|LAB|LAB|20230131111929||ADT^A01|1000027|P|2.3||||
EVN|A01|20220131111924
PID|1||0012345678^^^MRN^MRN||Doe^John^R||19700101|M|||123 Main St.^^Anytown^CA⁹¹²³⁴^USA|||||||||||||||||||||
PV1|1||^^¹⁰⁰¹|||||||||||||||||||1||||||||||||||||||||||||||||||||||
DG1|1||123456789^Diagnosis^I9||Confirmed
ZCP|1|Custom Segment Data

在此消息中,MSH 段包含消息头信息,包括发送应用程序、发送机构、接收应用程序和接收机构。

EVN 段包含事件信息,表明这是一个 ADT A01 消息,并在 2023 年 1 月 31 日 11:19:24 发送。

PID 段包含病人识别信息,包括病人标识符、姓名、出生日期和地址。

PV1 段包含病人访问信息,如病人位置和主治医生。

DG1 段包含一个虚构的诊断代码“123456789”和描述“诊断”。诊断状态字段标记为“已确认”。

最终我们看到自定义 ZCP 段,包含段 ID 为“ZCP”和两个数据字段“1”和“自定义段数据”。

解析 HL7 消息

现在我们对 HL7 消息有了基本了解,接下来开始解析吧!

在这里我们将使用[python-hl7](https://python-hl7.readthedocs.io/en/latest/) 。它提供了简单的方法来使用 Python 解析 HL7 消息。

python-hl7可以通过pip轻松安装。

pip install hl7 

安装完成后,我们可以导入库并解析我们的示例消息:

import hl7

## sample message
msg = '''
MSH|^~\&|HIS|HOSPITAL|LAB|LAB|20230131111929||ADT^A01|1000027|P|2.3||||
EVN|A01|20220131111924
PID|1||0012345678^^^MRN^MRN||Doe^John^R||19700101|M|||123 Main St.^^Anytown^CA⁹¹²³⁴^USA|||||||||||||||||||||
PV1|1||^^¹⁰⁰¹|||||||||||||||||||1||||||||||||||||||||||||||||||||||
DG1|1||123456789^Diagnosis^I9||Confirmed
ZCP|1|Custom Segment Data
'''

parsed = hl7.parse(msg)
print(parsed)

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

解析的 HL7 消息 — 作者提供的图像

在这里我们可以看到解析后的消息。我们可以检查消息的长度以确保它按预期被解析:

print(len(parsed))
>>> 1

在这里,我们看到解析后的消息长度为 1。由于我们可以在上述消息中数到 6 个不同的段,这似乎不正确。

这是由于我们的换行符问题。python-hl7期望换行符是回车符(\r),而不是换行符(\n)。我们可以通过将换行符替换为回车符来解决这个问题:

msg = msg.replace('\n','\r')
parsed = hl7.parse(msg)
print(len(parsed))
>>> 6

这看起来好多了!现在我们的长度为 6,与我们对消息的视觉检查一致。

现在让我们提取病人的姓名:

## get patient name
print(parsed[2][5])
>>> [[['Doe'], ['John'], ['R']]]

## get only first name
print(parsed[2][5][0][1][0])
>>> 'John'

我们成功提取了病人的全名和名字!

我们还可以通过指定段来访问这些信息:

## we can access the same information by specifying segement
parsed.segments('PID')[0][5]
>>> [[['Doe'], ['John'], ['R']]]

## and only the first name
parsed.segments('PID')[0][5][0][1][0]
>>> 'John'

这种方法的一个好处是可以让我们的代码更容易理解。

结论

在这里,我们了解了 HL7 消息以及如何使用 Python 的python-hl7库来解析它们。

这个库提供了一种方便的方式来使用 Python 读取 HL7 消息,并且支持大多数 HL7 消息类型和段落。

使用合适的工具,开发人员可以有效地和高效地用 Python 解析 HL7 消息!由于 HL7 是医疗保健领域广泛使用的标准,解析这些消息的知识对从事医疗行业的数据专业人士来说是一个有价值的技能。

在人道主义数据集中解析不规则电子表格(借助 GPT-3 的帮助)

原文:towardsdatascience.com/parsing-irregular-spreadsheet-tables-in-humanitarian-datasets-with-some-help-from-gpt-3-57efb3d80d45?source=collection_archive---------2-----------------------#2023-02-24

处理不规则 Excel 表格,无需使用硬编码规则

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 马修·哈里斯

·

关注 发布于 数据科学前沿 · 26 分钟阅读 · 2023 年 2 月 24 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

由 DALL-E2 根据提示“10 张木桌的画作”创作。上图中有 9 张桌子。

简短说明

作为 之前的研究 的一部分,使用了来自 人道主义数据交换 的数据,我不得不分析成千上万的 Excel 文件,这些文件中的表格常常难以解析成数据库表。文件来自全球数百个组织时,合并单元格、不规则布局、层次化列和注释难以通过基于规则的解析来预见。在这篇文章中,我探讨了使用 GPT-3 的零-shot、单-shot 和推理补全来重新格式化不规则(小型)表格,并微调模型以预测表格属性,从而用于准确解析。

在我的旅行中,有不少次需要查看大量 Excel 文件,以了解它们包含的数据、数据的结构如何,以及将其清理成可以处理的形式所需的工作。大部分情况下,只要数据规则且列标题整齐,这个过程相当简单。然而,现实从未那么简单,这些文件中的表格往往以不完美的格式存在,难以解析成可以上传到关系数据库的数据框。Excel 支持许多功能,如数据透视表和单元格合并,人们使用这些功能创建各种各样的布局,包括空白行、随机文本等等!

这里有一个例子来说明……

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

Excel 中的不规则表格示例,带有空白顶部行、标签和合并单元格。对人类来说完全可读,但对数据科学来说是解析的挑战。该文件来自人道主义数据交换

如果我们直接将上述文件读入 Pandas 中……

import pandas as pd
df = pd.read_excel(filename)
df = df.fillna("")
display(df)

我们得到这个……

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

Pandas 数据框在解析 Excel 表格后的示例,其中包含空行和合并单元格,以指示层次列。示例数据来自人道主义数据交换

将其加载到数据库中会导致数据几乎无法使用,因为……

  1. 右上角单元格中有一个表格标题。

  2. 列‘Unnamed: 1’的标题实际上是第一列第 5 行的内容“你拥有的土地的平均面积是多少……”

  3. 列‘Unnamed:2’和‘Unnamed:3’是分为’N‘ 数值和‘%’ 百分比值的汇总总数。

  4. 大多数列是层次化的,合并单元格位于未合并单元格之上。

这也不那么糟糕,对吧?

当然,可以向Pandas read_excel提供参数,将层次列转换为索引,然后可以将其合并为一行。或者,我们可以使用Openpyxl中关于 Excel 自身的合并单元格的信息进行操作。然而,这些方法需要对表格有了解——特别是标题在哪里结束、数据从哪里开始以及层次列的结构——这是我们在处理成千上万的电子表格时可能不总是拥有的奢侈品。对大量文件进行基于规则的解析可能耗时且脆弱,需要随着新布局的出现而持续维护。

其实,我并不是唯一一个遇到这个问题的人!解析不规则表格是一项正在积极研究的挑战。例如,微软的作者展示了利用卷积神经网络开发的一个名为‘TableSense’的算法的出色成果[1]。这种技术将 Excel 表格视作图像来处理,但具有更丰富的特征化,因为每个单元格可能具有多种属性和数据类型,还包括格式化和合并特征。非常酷。我希望像这样的精彩工作能尽快纳入微软的产品中,但在此之前,我想探索一些其他的方法。

值得注意的是,我的使用案例不仅仅是识别表格在工作表中的范围(参见微软论文的训练数据),还包括表格中的元素,以便将不规则的格式转换为可以轻松导入数据库的格式。主要挑战是 Excel 中的层次列,将这些层次列展平成一个单独的行,从而捕捉上层合并单元格中的信息。听起来解决起来很简单,但挑战是:标题在哪里结束,数据从哪里开始?这对我们人类来说显而易见,但令人惊讶的是,当用代码处理工作表时,这样简单的事情在现实世界中可能会变得非常嘈杂。

鉴于最近对生成式 AI 和大型语言模型(LLMs)的关注,我想知道也许OpenAI 的 GPT-3可能会接受这个挑战。这些模型在从互联网提取的大量数据上进行了训练,其中包括表格和 CSV 文件,因此它们可能在处理我们这些疯狂人类拼凑的表格的某些细节方面会很有用。

提示 GPT-3 清理(一个小的)表格

我们将首先尝试将问题作为零样本和少量样本任务解决,然后再转向使用微调技术。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

零样本、单样本和少样本任务,与传统的微调对比。上面的面板展示了用语言模型执行任务的四种方法。来源于 Brown 等人 [2]。

GPT-3 是在从网络上抓取的文本上训练的,所以我们不能用 Excel 提示它(还不行!),因此我们首先必须将我们的表格转换成一种网络上常见的形式,例如 CSV 字符串……

df = pd.read_excel('sample.xlsx', sheet_name='Sheet1')
df = df.fillna("")
print(df.to_csv())
,Unnamed: 0,Unnamed: 1,Unnamed: 2,Unnamed: 3,Unnamed: 4,Unnamed: 5,Unnamed: 6,Unnamed: 7,Unnamed: 8,Unnamed: 9,Unnamed: 10,Unnamed: 11
0,Table 3: Number of acreage under irrigation,,,,,,,,,,,
1,,,OVERALL,,Sub county,,,,,,,
2,,,,,Chepalungu,,,,Bomet Central,,,
3,,,,,Male,,Female,,Male,,Female,
4,,,N,%,N,%,N,%,N,%,N,%
5,What is the average size of land you own that is currently under irrigation?,0 - 2 acres,22,2.8%,4,2.2%,10,3.8%,3,1.7%,5,2.9%
6,,2 - 5 acres,6,.8%,2,1.1%,2,.8%,0,0.0%,2,1.2%
7,,5 - 10 acres,1,.1%,0,0.0%,0,0.0%,0,0.0%,1,.6%
8,,More than 10 acres,0,0.0%,0,0.0%,0,0.0%,0,0.0%,0,0.0%
9,,None,760,96.3%,176,96.7%,251,95.4%,170,98.3%,163,95.3%
10,,Total,789,100.0%,182,100.0%,263,100.0%,173,100.0%,171,100.0%

附注:我还尝试了 Markdown 和 HTML 表格,但发现 CSV 在我的用例中效果最好。

值得注意的是,对于这项分析,我们处理的表格是稀疏的,即列数少于 100。这意味着前 10 行可以很容易地在 GPT-3 提示中表示。这对我在援助数据交换中分析的大多数 Excel 表格来说是合适的,但可能不适用于其他情况。此外,这项分析不考虑同一 Excel 工作表上有多个表格的情况……这将是稍后博客文章的内容。🙂

零样本提示

现在让我们看看 GPT-3 是否可以仅通过一个提示重新格式化我们凌乱的表格,这是一项零样本任务 [2],我们没有提供示例,只是提供了要重新格式化的 CSV 文件……

import openai as ai

# Open AI API key should be put into this file
ai.api_key_path = "./api_key.txt"

csv_as_str = df.to_csv()

prompt = (
    "Reformat this table to be a simpler markdown table with "
    + "no hierarchical columns, no pivoting, values and percentages in different columns, "
    + "and no blank cells\n\n"
    + csv_as_str
)

completions = ai.Completion.create(
    engine="text-davinci-003",
    temperature=0.0,
    prompt=prompt,
    max_tokens=999,
    n=1,
    stop=None,
)

Markdown(completions.choices[0].text)

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

它丢弃了不必要的行,将数据转换为一个规范的表格,带有列标题,但仔细观察会发现,它丢失了一些关键信息,如按性别的分类。这是经典的幻觉现象,看起来很可信,但却是错误的。

让我们玩玩温度参数。较低的值使模型更加确定性(对于相同的提示每次都给出相同的结果),而较高的值则更随机。使用更高的温度值,我们得到……

prompt = (
    "Reformat this table to be a simpler markdown table with "
    + "no hierarchical columns, no pivoting, values and percentages in different columns, "
    + "and no blank cells\n\n"
    + csv_as_str
)

completions = ai.Completion.create(
    engine="text-davinci-003",
    temperature=1.0,
    prompt=prompt,
    max_tokens=999,
    n=1,
    stop=None,
)

Markdown(completions.choices[0].text)

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

看起来不错!几乎所有的正确列标题都来自我们 CSV 文件中的合并单元格,这实际上相当惊人。然而,抽查几个单元格显示,尽管许多是正确的,但也有一些不正确。此外,上面的“总体”被分成了男性和女性,这是不正确的。

另一个问题是,调用 GPT-3 完全相同的提示会产生不同的结果,因为高温值……

completions = ai.Completion.create(
    engine="text-davinci-003",
    temperature=1.0,
    prompt=prompt,
    max_tokens=999,
    n=1,
    stop=None,
)

Markdown(completions.choices[0].text)

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

不无道理,尽管值不正确,但布局完全不同。可重复性对我们的任务非常重要,我们应该能够在每次处理运行中以完全相同的方式处理表格数据。

所以高温似乎不是这个用例的好选择。

如果我们在表格中提供更多上下文会怎么样?CSV 并不是很具表现力,例如,层级标题中的合并列告诉人类这些列是分组的,但 CSV 文件并未捕捉到这一点……

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

1,,,OVERALL,,Sub county,,,,,,,
2,,,,,Chepalungu,,,,Bomet Central,,,
3,,,,,Male,,Female,,Male,,Female,
4,,,N,%,N,%,N,%,N,%,N,%

在上述示例中,GPT-3 必须推断出合并的行标题右侧的空白列与这些标题对应,并且很多时候它确实能够做到这一点。然而,由于我们知道 Excel 文件中哪些单元格是合并的,我们可以稍微帮助一下。

为了在 CSV 中表示这个,我们可以取消合并合并的单元格,并用其合并值填充 …

def pad_merged_cells(sheet):
    """
    Unmerge merged cells and fill with merged value.

    Input Parameters
    ----------------
    sheet: Obj
        Openpyxl sheet object

    Output Parameters
    -----------------
    df: Dataframe
        Pandas dataframe of the table
    """

    dd = pd.DataFrame(sheet.values)

    # Scan for maxn rows
    maxn = 10

    hasmerged = False
    if len(sheet.merged_cells.ranges) > 0:
        hasmerged = True

    if hasmerged:
        merge_list = []
        for merge in sheet.merged_cells.ranges:
            merge_list.append(merge)

        for cell_group in merge_list:
            min_col, min_row, max_col, max_row = range_boundaries(
                str(cell_group))
            top_left_cell_value = sheet.cell(row=min_row, column=min_col).value
            sheet.unmerge_cells(str(cell_group))
            for row in sheet.iter_rows(
                min_col=min_col, min_row=min_row, max_col=max_col, max_row=max_row
            ):
                for cell in row:
                    cell.value = top_left_cell_value

    # Extract data and save to dataframe
    data = []
    for row in sheet.iter_rows(min_row=1):
        row_data = []
        for cell in row:
            if cell.value is None:
                row_data.append(None)
            else:
                row_data.append(cell.value)
        if any(row_data):
            data.append(row_data)

    df = pd.DataFrame(data)

    # Remove duplicate columns
    df = df.T.drop_duplicates().T

    # Remove duplicate rows
    df = df.drop_duplicates()

    # Fill NaN with blank string for easier viewing
    df = df.fillna("")

    return df, sheet, hasmerged

wb = openpyxl.load_workbook(filename)
sheet = wb['Sheet1']
merged_table, sheet, hasmerged = pad_merged_cells(sheet)

display(merged_table)

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

表格中合并的单元格被取消合并,并用合并值填充,以在 CSV 文件格式中提供上下文。

,Table 3: Number of acreage under irrigation,,,,,,,,,,,
1,,,OVERALL,OVERALL,Sub county,Sub county,Sub county,Sub county,Sub county,Sub county,Sub county,Sub county
2,,,OVERALL,OVERALL,Chepalungu,Chepalungu,Chepalungu,Chepalungu,Bomet Central,Bomet Central,Bomet Central,Bomet Central
3,,,OVERALL,OVERALL,Male,Male,Female,Female,Male,Male,Female,Female
4,,,N,%,N,%,N,%,N,%,N,%
5,What is the average size of land you own that is currently under irrigation?,0 - 2 acres,22,2.8%,4,2.2%,10,3.8%,3,1.7%,5,2.9%
6,What is the average size of land you own that is currently under irrigation?,2 - 5 acres,6,.8%,2,1.1%,2,.8%,0,0.0%,2,1.2%
7,What is the average size of land you own that is currently under irrigation?,5 - 10 acres,1,.1%,0,0.0%,0,0.0%,0,0.0%,1,.6%
8,What is the average size of land you own that is currently under irrigation?,More than 10 acres,0,0.0%,0,0.0%,0,0.0%,0,0.0%,0,0.0%
9,What is the average size of land you own that is currently under irrigation?,None,760,96.3%,176,96.7%,251,95.4%,170,98.3%,163,95.3%
10,What is the average size of land you own that is currently under irrigation?,Total,789,100.0%,182,100.0%,263,100.0%,173,100.0%,171,100.0%

CSV 文件现在捕获了叠加的合并列标题。让我们看看这是否能改善情况,首先温度=0.0 …

csv_as_str_merged = merged_table.to_csv()

prompt = (
    "Reformat this table to be a simpler markdown table with "
    + "no hierarchical columns, no pivoting, values and percentages in different columns, "
    + "and no blank cells\n\n"
    + csv_as_str_merged
)

completions = ai.Completion.create(
    engine="text-davinci-003",
    temperature=0.0,
    prompt=prompt,
    max_tokens=999,
    n=1,
    stop=None,
)

Markdown(completions.choices[0].text)

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

同样的,但温度=1.0,只是为了好玩 …

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

稍微好了一些,但总是有些地方不太对。缺失的类别,单元格值偏移,如果我们需要准确表示源数据,两个表格都无法使用。

此时,我尝试了各种组合:

  • 提示

  • 温度

  • 使用 Markdown、HTML 和 CSV 定义输入表格

  • 提示 GPT-3 生成用于解析的 Python 代码,而不是解析表格

有时该过程能够生成列标题和数值完美的表格,但通常这需要高温度值,因此不可重复。大多数情况下,结果看起来合理,但数据不正确。

公平地说,我们真的对 GPT-3 提出了很高的要求,这是一项复杂的零样本任务。我对它的表现感到非常满意,也许通过更好的提示和问题的重新框定 —— 或 GPT-4!—— 结果可能会有所改善,但我没有能够实现所需的结果。

单次提示

现在,让我们在提示中提供一个示例。我从人道数据交换获取了一个类似的 Excel 文件 …

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

我们将在单次提示中使用的表格。此文件来源于 人道数据交换

我们希望这被处理成如下所示 …

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

我们的示例文件在重新格式化后的样子

显然,这是一个不切实际的“真实世界”示例,因为格式和内容与我们尝试处理的表格非常相似,但这是一个很好的初步测试。

将我们的输入表格转换为 CSV 并取消合并合并的单元格,如上所述,我们得到 …

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

我们现在可以构建我们的单次提示(假设温度为零以便可重复) …

from io import StringIO

wb = openpyxl.load_workbook(prompt_sample_table1, data_only=True)
sheet = wb["Sheet1"]
example_before, sheet, hasmerged = pad_merged_cells(sheet)
example_before_csv = example_before.to_csv()
example_after, hasmerged, report = parse_excel_sheet(sheet)
example_after_markdown = example_after.to_markdown()
example_after_csv = example_after.to_csv()

example_before_csv = """
 ,0,1,2,3,4,5,6,7
0,Table 16: % of infants on Minimum Dietary Diversity,,,,,,,
1,,,OVERALL,OVERALL,Sub county,Sub county,Sub county,Sub county
2,,,OVERALL,OVERALL,Chepalungu,Chepalungu,Bomet Central,Bomet Central
3,,,N,%,n,%,n,%
4,Infants         on Dietary Diversity,Infants  on  Minimum  Dietary Diversity,37,17.5%,24,17.9%,13,16.7%
5,Infants         on Dietary Diversity,Infants not on Dietary Diversity,175,82.5%,110,82.1%,65,83.3%
6,Infants         on Dietary Diversity,Total,212,100.0%,134,100.0%,78,100.0%
"""

example_after_markdown = (
    """
 |    |                                      |                                         |   OVERALL - N | OVERALL - %   |   Sub county - Chepalungu | Sub county - Chepalungu - %   |   Sub county - Bomet Central | Sub county - Bomet Central - %   |
|---:|:-------------------------------------|:----------------------------------------|--------------:|:--------------|--------------------------:|:------------------------------|-----------------------------:|:---------------------------------|
|  1 | Infants         on Dietary Diversity | Infants  on  Minimum  Dietary Diversity |            37 | 17.5%         |                        24 | 17.9%                         |                           13 | 16.7%                            |
|  2 | Infants         on Dietary Diversity | Infants not on Dietary Diversity        |           175 | 82.5%         |                       110 | 82.1%                         |                           65 | 83.3%                            |
|  3 | Infants         on Dietary Diversity | Total                                   |           212 | 100.0%        |                       134 | 100.0%                        |                           78 | 100.0%                           |
""".replace(
        ":|", "|"
    )
    .replace("|:", "|")
    .replace("\n", "\n<RETURN>")
)

example_after_csv = """
 , , ,OVERALL - N,OVERALL - %,Sub county - Chepalungu,Sub county - Chepalungu - %,Sub county - Bomet Central,Sub county - Bomet Central - %
1,Infants         on Dietary Diversity,Infants  on  Minimum  Dietary Diversity,37,17.5%,24,17.9%,13,16.7%
2,Infants         on Dietary Diversity,Infants not on Dietary Diversity,175,82.5%,110,82.1%,65,83.3%
3,Infants         on Dietary Diversity,Total,212,100.0%,134,100.0%,78,100.0%
"""

table_to_parse_padded = """
,0,1,2,3,4,5,6,7,8,9,10,11
0,Table 3: Number of acreage under irrigation,,,,,,,,,,,
1,,,OVERALL,OVERALL,Sub county,Sub county,Sub county,Sub county,Sub county,Sub county,Sub county,Sub county
2,,,OVERALL,OVERALL,Chepalungu,Chepalungu,Chepalungu,Chepalungu,Bomet Central,Bomet Central,Bomet Central,Bomet Central
3,,,OVERALL,OVERALL,Male,Male,Female,Female,Male,Male,Female,Female
4,,,N,%,N,%,N,%,N,%,N,%
5,What is the average size of land you own that is currently under irrigation?,0 - 2 acres,22,2.8%,4,2.2%,10,3.8%,3,1.7%,5,2.9%
6,What is the average size of land you own that is currently under irrigation?,2 - 5 acres,6,.8%,2,1.1%,2,.8%,0,0.0%,2,1.2%
7,What is the average size of land you own that is currently under irrigation?,5 - 10 acres,1,.1%,0,0.0%,0,0.0%,0,0.0%,1,.6%
8,What is the average size of land you own that is currently under irrigation?,More than 10 acres,0,0.0%,0,0.0%,0,0.0%,0,0.0%,0,0.0%
9,What is the average size of land you own that is currently under irrigation?,None,760,96.3%,176,96.7%,251,95.4%,170,98.3%,163,95.3%
10,What is the average size of land you own that is currently under irrigation?,Total,789,100.0%,182,100.0%,263,100.0%,173,100.0%,171,100.0%
"""

prompt = (
    "Reformat this table to only have a single header row: \n\n"
    + example_before_csv
    + "\n\n"
    + "Result: \n\n"
    + example_after_csv
    + "\n\n"
    + "Reformat this table to only have a single header row: \n\n"
    + table_to_parse_padded
    + "\n\n"
    + "Result: \n\n"
)

print("\n\n", prompt, "\n\n")

completions = ai.Completion.create(
    engine="text-davinci-003",
    temperature=0.0,
    prompt=prompt,
    n=1,
    stop=None,
    max_tokens=2068,
    top_p=1,
    frequency_penalty=0,
    presence_penalty=0,
)

print("\n========== Model prediction:\n")

display(pd.read_csv(StringIO(completions.choices[0].text)))

这是生成的提示 …

Reformat this table to only have a single header row: 

 ,0,1,2,3,4,5,6,7
0,Table 16: % of infants on Minimum Dietary Diversity,,,,,,,
1,,,OVERALL,OVERALL,Sub county,Sub county,Sub county,Sub county
2,,,OVERALL,OVERALL,Chepalungu,Chepalungu,Bomet Central,Bomet Central
3,,,N,%,n,%,n,%
4,Infants         on Dietary Diversity,Infants  on  Minimum  Dietary Diversity,37,17.5%,24,17.9%,13,16.7%
5,Infants         on Dietary Diversity,Infants not on Dietary Diversity,175,82.5%,110,82.1%,65,83.3%
6,Infants         on Dietary Diversity,Total,212,100.0%,134,100.0%,78,100.0%

Result: 

 , , ,OVERALL - N,OVERALL - %,Sub county - Chepalungu,Sub county - Chepalungu - %,Sub county - Bomet Central,Sub county - Bomet Central - %
1,Infants         on Dietary Diversity,Infants  on  Minimum  Dietary Diversity,37,17.5%,24,17.9%,13,16.7%
2,Infants         on Dietary Diversity,Infants not on Dietary Diversity,175,82.5%,110,82.1%,65,83.3%
3,Infants         on Dietary Diversity,Total,212,100.0%,134,100.0%,78,100.0%

Reformat this table to only have a single header row: 

,0,1,2,3,4,5,6,7,8,9,10,11
0,Table 3: Number of acreage under irrigation,,,,,,,,,,,
1,,,OVERALL,OVERALL,Sub county,Sub county,Sub county,Sub county,Sub county,Sub county,Sub county,Sub county
2,,,OVERALL,OVERALL,Chepalungu,Chepalungu,Chepalungu,Chepalungu,Bomet Central,Bomet Central,Bomet Central,Bomet Central
3,,,OVERALL,OVERALL,Male,Male,Female,Female,Male,Male,Female,Female
4,,,N,%,N,%,N,%,N,%,N,%
5,What is the average size of land you own that is currently under irrigation?,0 - 2 acres,22,2.8%,4,2.2%,10,3.8%,3,1.7%,5,2.9%
6,What is the average size of land you own that is currently under irrigation?,2 - 5 acres,6,.8%,2,1.1%,2,.8%,0,0.0%,2,1.2%
7,What is the average size of land you own that is currently under irrigation?,5 - 10 acres,1,.1%,0,0.0%,0,0.0%,0,0.0%,1,.6%
8,What is the average size of land you own that is currently under irrigation?,More than 10 acres,0,0.0%,0,0.0%,0,0.0%,0,0.0%,0,0.0%
9,What is the average size of land you own that is currently under irrigation?,None,760,96.3%,176,96.7%,251,95.4%,170,98.3%,163,95.3%
10,What is the average size of land you own that is currently under irrigation?,Total,789,100.0%,182,100.0%,263,100.0%,173,100.0%,171,100.0%

Result: 

这是 GPT-3 的完成结果,转换为数据框以便更容易显示 …

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

从单次提示生成的表格,重新格式化具有层次结构标题的表格(完成结果是 CSV,这里为便于展示转换为 pandas 数据框)

很好!当提供一个示例时,GPT-3 能够完美地重新格式化我们的新表格。然而,这不是一个很好的测试,因为示例表格和测试表格在结构和内容上非常相似,但有趣的是,即使示例中没有男性/女性的层级,GPT-3 仍能正确地折叠这个额外的层级。

让我们使用相同的示例表格来重新格式化一个具有不同布局和内容数据的表格 …

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

使用相同的代码处理得到的是 …

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

这很接近,标题完全正确,但农场列向左移动了。我们的单次提示在重新格式化非常相似的表格时表现不错,但稍微的变化导致了较差的结果。

单次提示,带有推理

关于提示工程已经有相当多的研究。一个非常好的资源可以在 OpenAI Cookbook 的 提升可靠性的技术 [3] 中找到。提高结果的最有效方法之一是包含推理在示例提示中 [4]。以我们之前的表格为例,调整提示以包括推理 …

prompt = (
    "We need to reformat this table to only have a single header row: \n\n"
    + example_before_csv
    + "\n"
    + "Let's think step by step \n"
    + "Row 1 is just an index row, it has no text or data \n"
    + "Row 2 contains just label text \n"
    + "Rows 3 to 5 contain column headers \n"
    + "Rows 6 onwards contain data \n"
    + "Columns are separated by commas, there should be 7 commas on each row \n"
    + "If we combine each colummn of rows 3 to 5 by concatenating vertically, we get \n"
    + example_after_csv
    + "\n\n"
    + "We need to reformat this table to only have a single header row: \n\n"
    + table_to_parse_padded
    + "\n\n"
    + "Let's think step by step \n\n"
)

完整的提示如下 …

We need to reformat this table to only have a single header row: 

 ,0,1,2,3,4,5,6,7
0,Table 16: % of infants on Minimum Dietary Diversity,,,,,,,
1,,,OVERALL,OVERALL,Sub county,Sub county,Sub county,Sub county
2,,,OVERALL,OVERALL,Chepalungu,Chepalungu,Bomet Central,Bomet Central
3,,,N,%,n,%,n,%
4,Infants         on Dietary Diversity,Infants  on  Minimum  Dietary Diversity,37,17.5%,24,17.9%,13,16.7%
5,Infants         on Dietary Diversity,Infants not on Dietary Diversity,175,82.5%,110,82.1%,65,83.3%
6,Infants         on Dietary Diversity,Total,212,100.0%,134,100.0%,78,100.0%

Let's think step by step 
Row 1 is just an index row, it has no text or data 
Row 2 contains just label text 
Rows 3 to 5 contain column headers 
Rows 6 onwards contain data 
Columns are separated by commas, there should be 7 commas on each row 
If we combine each colummn of rows 3 to 5 by concatenating vertically, we get 

 , , ,OVERALL - N,OVERALL - %,Sub county - Chepalungu - N,Sub county - Chepalungu - %,Sub county - Bomet Central - N,Sub county - Bomet Central - %
1,Infants         on Dietary Diversity,Infants  on  Minimum  Dietary Diversity,37,17.5%,24,17.9%,13,16.7%
2,Infants         on Dietary Diversity,Infants not on Dietary Diversity,175,82.5%,110,82.1%,65,83.3%
3,Infants         on Dietary Diversity,Total,212,100.0%,134,100.0%,78,100.0%

We need to reformat this table to only have a single header row: 

,0,1,2,3,4,5
0,Random text here,,,,,
1,,Cows,Cows,Cowboy Hats,Cowboy Hats,Cowboy Hats
2,,TOT,Farm,Small ,Medium,Large
3,,97,6666,0.65,,0.13
4,,93,Bar M,,0.2,
5,,140,Daisy,0.3,0.89,
6,,,Plew,0.2,0.5,0.1
7,,119,Birdie,0.2,0.1,
8,,29,Kit Kat,,0.55,

Let's think step by step 

这导致 GPT-3 对我们输入表格的完成结果是 …

Row 1 is just an index row, it has no text or data 
Row 2 contains just label text 
Rows 3 to 8 contain column headers 
Rows 9 onwards contain data 
Columns are separated by commas, there should be 6 commas on each row 
If we combine each colummn of rows 3 to 8 by concatenating vertically, we get 

 , ,Cows - TOT,Cows - Farm,Cowboy Hats - Small ,Cowboy Hats - Medium,Cowboy Hats - Large
1,,97,6666,0.65,,0.13
2,,93,Bar M,,0.2,
3,,140,Daisy,0.3,0.89,
4,,,Plew,0.2,0.5,0.1
5,,119,Birdie,0.2,0.1,
6,,29,Kit Kat,,0.55,

这很正确!重新格式化后的表格正是我们想要的 …

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

如果我们在单次提示中提供推理,结果会有所改善

也就是说,我们提供的任务并不是很好,因为尽管内容与提供的示例不同,但标题布局仍然相似。事实上,如果我们稍微调整一下要重新格式化的表格并添加一个额外的“有机”列 …

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

向输入中添加一个额外的列

预测现在不正确 …

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

只是标题行中多了一个额外的逗号,这导致所有内容向右移动。

我们可能会继续通过更多推理来优化提示,或应用 更高级的技术 自动构建我们的提示工作流,但真正的问题是一个示例并不足以捕捉我们可能遇到的所有表格格式变体。尽管 GPT-3 在仅有一个示例的情况下表现得非常好,但对于这个任务来说,它还不够好(至少就目前的框架而言)。

少量示例…. 或者说不是

下一个方法可能是提供多个示例。然而,表格片段需要大量的令牌(稍后会详细说明),所以如果我们必须在提示中提供多个示例,再加上结果中的令牌,就会触及 Open API 的令牌限制。对于 davinci 模型,目前的限制为4,000个令牌。此外,由于我们按令牌收费,对于像DataKind这样的小型非营利组织,发送和接收大量令牌可能会变得昂贵。更长的提示还有性能影响,因此对于这个任务没有探索少样本提示。

所以我决定暂时跳过少样本学习。

微调

探索零样本和单样本提示很有趣,如果这些方法在这个用例中有效,将会取得惊人的结果。未来,随着模型的改进,这可能会成为一个可行的选项,但目前,重新定义任务可能更有意义。

另一种方法是通过微调提供大量示例。正如 OpenAI 所述:

微调可以通过提供以下内容来让你更好地利用 API 提供的模型:

  1. 比提示设计产生更高质量的结果

  2. 能够在比提示中能容纳的更多示例上进行训练

  3. 由于提示较短而节省令牌

  4. 更低延迟的请求

起初,我考虑通过提供 GPT-3(i)原始表格的提示(合并单元格未合并)和(ii)作为重新格式化表格的完成项来进行微调。然而,这种方法的挑战在于,它仍然使用了大量的令牌,尤其是我们现在需要使用数百个示例。

与其传递原始表格片段,不如尝试使用该表格的属性,并让 GPT-3 预测我们可以用来解析的关键进一步属性……

重新定义任务 — 使用表格属性作为提示

作为一个人(好吧,大部分是人),当我扫描 Excel 中的表格时,我可以通过查看值来识别结构,并决定数据的位置。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

确定表格中的数据部分是将其解析成规则表格结构的关键

一旦我知道数据开始的行,就很容易从上面的行推断出标题层次,并将它们合并成一个单一的标题行,以创建一个整齐、规则的表格来使用……

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

处理后的表格具有平面标题,容易导入关系数据库

确定数据的开始位置乍一看似乎很简单,只需在 openpyxlpandas.read_excel 中稍作处理即可。然而,如果需要处理成千上万的具有不同标题布局、空白行等的电子表格,开发一套用于准确识别每个工作表中数据开始位置的规则将是一项挑战。

这很复杂,因为:

  • 列标题可能有很高的变化性,看起来像数据。

  • 空白单元格和注释容易混淆解析规则。

  • 数据不总是数字的,它可以是分类的,看起来很像列标题。

  • 一些列标题是数字,可能看起来像数据,例如年份。

那么,我们应该使用哪些表格属性/特征来预测数据首次出现的行号呢?

我列出了一个我认为可能有用的表格属性的简短清单……

import openpyxl

def get_sheet_attributes(sheet, maxn):
    """
    Returns a set of table attributes for a given sheet

    Input Parameters:
        sheet: Obj
            Openpyxl sheet object
        maxn: int
            Number of rows to scan at start of sheet

    Returns:
        null_cells_in_rows: list of ints
            Count of NULL records in forst maxn rows
        float_cells_in_rows: list of ints
            Count of numeric records in first maxn rows
        unique_vals_in_rows: list of ints
            Count of unique values in first maxn rows
        year_vals_in_rows: list of ints
            Count of year values in first maxn rows
        hxl_row: int
            Row number of HXL header row
        first_float_row: int
            Row number of row with most numeric records
        first_not_null_row: int
            Row number of row with most non-null records

    """
    dd = pd.DataFrame(sheet.values)

    null_cells_in_rows = list(
        dd[0:maxn].apply(lambda x: x.isnull().sum(), axis="columns")
    )
    float_cells_in_rows = []
    unique_vals_in_rows = []
    year_vals_in_rows = []
    report_json = {}
    hxl_row = None
    for index, row in dd[0:maxn].iterrows():
        unique_vals = list(row.unique())
        unique_vals = [i for i in unique_vals if i is not None and str(i) != "nan"]
        unique_vals_in_rows.append(len(unique_vals))
        float_count = 0
        year_count = 0
        if check_hdx_header(list(row)):
            hxl_row = index
        for col in dd.columns:
            val = row[col]
            # Handle numbers that come through as strings
            if isinstance(val, str):
                val = val.replace(",", "").replace(" ", "")
                if val.isnumeric():
                    val = int(val)
            # Check for year values
            if (
                ((isinstance(val, int) or isinstance(val, float)) and val % 1 == 0)
                and val > 1900
                and val < 2100
            ):
                year_count += 1
                continue
            # Check for HXL tags
            if isinstance(val, float) or isinstance(val, int) or "^=" in str(row[col]):
                float_count += 1
        float_cells_in_rows.append(float_count)
        year_vals_in_rows.append(year_count)

    max_floats = max(float_cells_in_rows)
    min_nulls = min(null_cells_in_rows)
    first_float_row = 0
    if sum(float_cells_in_rows) > 0:
        for i in range(1, len(float_cells_in_rows)):
            # Use a ratio or special case where we go from zero to some
            if float_cells_in_rows[i] / max_floats > 0.5 or (
                float_cells_in_rows[i] > 0 and float_cells_in_rows[i - 1] == 0
            ):
                first_float_row = i
                break
    first_not_null_row = np.argmin(null_cells_in_rows)

    report = f"Nulls in first {maxn} rows: {str(null_cells_in_rows)}\n"
    report += f"Numeric first {maxn} rows: {str(float_cells_in_rows)}\n"
    report += f"Unique values in first {maxn} rows: {str(unique_vals_in_rows)}\n"
    report += f"Year values in first {maxn} rows: {str(year_vals_in_rows)}\n"
    report += f"HXL row: {str(hxl_row)}\n"

    report += f"\nFirst reduced nulls row: {str(first_not_null_row)}\n"
    report += f"First increased numeric row (excluding years): {str(first_float_row)}\n"

    report_json = {
        "null_cells_in_rows": null_cells_in_rows,
        "float_cells_in_rows": float_cells_in_rows,
        "unique_vals_in_rows": unique_vals_in_rows,
        "year_vals_in_rows": year_vals_in_rows,
        "hxl_row": hxl_row,
        "first_float_row": first_float_row,
        "first_not_null_row": first_not_null_row,
    }

    return report, report_json

wb = openpyxl.load_workbook(filename, data_only=True)
for s in wb.sheetnames:
    sheet = wb[s]
    report, report_json = get_sheet_attributes(sheet, maxn)
    print(report) 

这会产生这样的输出……

Nulls in first 10 rows: [12, 11, 10, 10, 8, 2, 0, 1, 1, 1]
Numeric first 10 rows: [0, 0, 0, 0, 0, 0, 5, 5, 5, 5]
Unique values in first 10 rows: [0, 1, 2, 2, 2, 2, 12, 8, 6, 3]
Year values in first 10 rows: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
HXL row: None

First reduced nulls row: 6
First increased numeric row (excluding years): 6

这些将是我们用于微调模型的提示。

为了创建微调文件的补全,我使用了肯尼亚的人道主义数据交换数据集(有关如何提取 Excel 文件的更多细节,请参见这里)。解析文件并循环遍历每个工作表,我生成了提示。

我使用了以下逻辑来估算数据开始的行号,使用了上述表格参数……

# Make a guess at which row is the data row
datarow = max_not_null_row
# Sometimes we have header rows where none are null, in this case we want to use the row with the most floats
if max_float_row > datarow:
    datarow = max_float_row
# HXL row is always the row before the data row
if hxl_row is not None:
    datarow = hxl_row
# If we a row with a lot of year values below datarow, use that
if year_vals_in_rows[datarow] > 3:
    datarow = datarow + 1

这种基于规则的方法实际上表现得相当不错,但它并不完美,因此需要 GPT-3。尽管如此,它在创建一个大多数补全都准确的测试集时很有用,我只需调整几个逻辑不成立的部分即可。

对于我的训练集,我使用了来自 10 个人道主义提供组织的多个标记为“Kenya”的 Excel 表格中的每个组织的一个表格,其中使用上述基于规则的方法进行了首次数据行的预测。我随后审查了这份清单,并与实际的工作表进行了比较,以纠正电子表格表格开始于不同的行的情况。我排除了本研究中存在多个表格的情况,此后我得到了 232 个这样的微调提示……

{"prompt": "Nulls in first 15 rows: [9, 8, 7, 7, 3, 1, 2, 2, 2, 2, 2]\nNumeric first 15 rows: [0, 0, 0, 0, 0, 3, 3, 3, 3, 3, 3]\nUnique values in first 15 rows: [0, 1, 2, 2, 3, 8, 7, 7, 6, 6, 5]\nYear values in first 15 rows: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\nHXL row: None\nFirst reduced nulls row: 5\nFirst increased numeric row (excluding years): 5", "completion": "Data starts at row: 5\n", "meta_data": "./data/Kenya/kenya-hand-washing-statistics-in-bomet-county_118ea93f-83ce-4b86-b1c4-ca54ea9acc8a/Hand_washing_practices_xlsx_efc74f32_ac23_463a_924b_d53c3656b406/Hand washing practices.xlsx"}
{"prompt": "Nulls in first 15 rows: [2, 1, 1, 2, 1, 1, 2, 1, 1, 1, 2, 1, 1, 2, 1]\nNumeric first 15 rows: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\nUnique values in first 15 rows: [0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1]\nYear values in first 15 rows: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\nHXL row: None\nFirst reduced nulls row: 1\nFirst increased numeric row (excluding years): 0", "completion": "Data starts at row: 1\n", "meta_data": "./data/Kenya/shcchealthcare-dataset_02995168-3644-4b78-92be-cdf67275b39d/2018_SHCC_Overview_Data_xlsx_d053b42a_7d31_41b5_a6d9_c8b0a424241c/2018 SHCC Overview Data.xlsx"}
{"prompt": "Nulls in first 15 rows: [6, 3, 3, 3, 3, 3, 3, 3, 3, 3, 5, 3, 7, 6, 3]\nNumeric first 15 rows: [0, 0, 3, 3, 3, 3, 3, 3, 3, 3, 1, 3, 0, 0, 0]\nUnique values in first 15 rows: [1, 4, 4, 4, 4, 4, 4, 4, 4, 4, 2, 4, 0, 1, 4]\nYear values in first 15 rows: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\nHXL row: None\nFirst reduced nulls row: 1\nFirst increased numeric row (excluding years): 2", "completion": "Data starts at row: 2\n", "meta_data": "./data/Kenya/eastern-southern-africa-refugees-and-idps-situation-and-response-dec2019_e1f9f55e-08db-4166-a787-c7ea9969dc4d/UNICEF_ESARO_Regional_refugee_and_idp_db_2019_November_27_2019_xlsx_0696b7f3_6368_403e_bcb7_eccdc617961f/UNICEF ESARO Regional refugee and idp db 2019 November 27.2019.xlsx"}

附注:在上面的内容中,你可能注意到我为每个提示添加了一个“meta_data”元素。这不是 JSONL 提示记录的必要部分,但我这样做是为了能够轻松将每个提示与文件关联以便于调试。包含这些额外数据的提示文件似乎仍然被 OpenAI 接受,我认为只要有“prompt”和“completion”元素,它就会接受!

然后我微调了一个 DaVinci 模型……

 ai.api_key_path="./api_key.txt"

train_file = './prompts.json'

print("Uploading training file ...")
training_id = cli.FineTune._get_or_upload(train_file, True)

print("Fine-tuning model ...")
create_args = {
    "training_file": training_id,
    "model": "davinci"
}
resp = ai.FineTune.create(**create_args)
job_id = resp["id"]
status = resp["status"]

print(f'Fine-tunning model with jobID: {job_id}.')

我手动检查了微调状态,如下所示……

ai.api_key_path="./api_key.txt"
result = ai.FineTune.retrieve(id=job_id)

print(result['status'])

然后完成后,检索了模型……

model = result["fine_tuned_model"]

对于测试集,我使用了来自训练集之外组织(标记为‘肯尼亚’)的每个 Excel 文件中的一个表格,首先运行上述基于规则的预测生成提示和完成项,然后纠正返回的错误值。再次排除指定了多个表格的 Excel 表格。这给我提供了 72 个提示的测试集。

 def make_gpt3_prediction(prompt, model, temperature=0.99, max_tokens=13):
    """
    Wrapper to call GPT-3 to make a prediction (completion) on a single prompt.
    Also calls post_process() to clean up the prediction.

    Parameters
    ----------
    prompt : str
        Prompt to use for prediction
    model : str
        GPT-3 model to use
    temperature : float
        Temperature to use for sampling
    max_tokens : int
        Maximum number of tokens to use for sampling

    Returns
    -------
    result : dict
        Dictionary with prompt, predicted, predicted_post_processed
    """
    result = {}
    result["prompt"] = prompt
    model_result = ai.Completion.create(
        engine=model,
        prompt=prompt,
        temperature=temperature,
        max_tokens=max_tokens,
        top_p=1,
        frequency_penalty=0,
        presence_penalty=0,
        stop=["\n"],
        logprobs=1,
    )
    result["predicted"] = model_result["choices"][0]["text"].replace(" ", "")
    result["logprobs"] = model_result["choices"][0]["logprobs"]["top_logprobs"]
    return result

def output_prediction_metrics(results, prediction_field="predicted_post_processed"):
    """
    Prints out model performance report if provided results in the format:

    [
        {
            'prompt': ' \'ISO3\' | "[\'RWA\', \'RWA\', \'RWA\', \'RWA\', \'RWA\', \'RWA\', \'RWA\', \'RWA\']"',
            'predicted': ' #country+code+iso3+v_iso3+',
            'expected': '#country+code'
        },
        ... etc ...
    ]

    Parameters
    ----------
    results : list
        See above for format
    prediction_field : str
        Field name of element with prediction. Handy for comparing raw and post-processed predictions.
    """
    y_test = []
    y_pred = []
    for r in results:
        if "expected" not in r:
            print("Provided results do not contain expected values.")
            sys.exit()
        y_pred.append(r[prediction_field])
        y_test.append(r["expected"])

    print(f"There were {len(y_test)} predictions made.")
    print(f"\nPrediction using field {prediction_field} ...\n")
    print(f"Accuracy: {round(accuracy_score(y_test, y_pred),2)}")
    print(
        f"Precision: {round(precision_score(y_test, y_pred, average='weighted', zero_division=0),2)}"
    )
    print(
        f"Recall: {round(recall_score(y_test, y_pred, average='weighted', zero_division=0),2)}"
    )
    print(
        f"F1: {round(f1_score(y_test, y_pred, average='weighted', zero_division=0),2)}"
    )

# File generated by downloading and processing HDX files. See this blog post
# for more details: https://medium.com/towards-data-science/predicting-metadata-for-humanitarian-datasets-using-gpt-3-b104be17716d
country='Kenya'
resources = pd.read_pickle(hdx_resources_pkl_file)

df = resources[(resources["resource_format"]=='XLSX')][["resource_format","file","sheet","dataset_name","dataset_org_title"]]
df.drop_duplicates(inplace=True)
orgs = df["dataset_org_title"].unique()

# Number of rows to use when calculating table row parameters
maxn = 15

# Determine test/train split, 0:10 used for training, 11:len(orgs) for test
dataset_orgs_cutoff = 10

for dataset_org in orgs[dataset_orgs_cutoff: len(orgs)]:
    rows = df.loc[df['dataset_org_title']== dataset_org]
    row = rows.iloc[0]  # Take one sheet from each org to get more variation
    filename = row["file"]
    sheetname = row["sheet"]

    wb = openpyxl.load_workbook(filename, data_only=True)
    for s in wb.sheetnames:
        sheet = wb[s]

        # Extract table attributes 
        report = get_sheet_attributes(sheet, maxn)

        report_elements = report.split('\n\n')
        prompt = report_elements[0] + report_elements[1]
        completion = report_elements[2]

        # Make our GPT-3 prediction
        res = make_gpt3_prediction(prompt, model, temperature=0.0)

        predicted = res["predicted"].split(':')[1].strip()
        actual = completion.split(':')[1].strip()

        results.append({
            "prompt": prompt,
            "predicted": predicted,
            "expected": actual
        })

output_prediction_metrics(results, prediction_field="predicted")

附注:我之前的博客文章中预测 HXL 标签时,我必须通过对数概率过滤完成项,但在这项研究中没有必要。

GPT-3 在我们的测试集中预测了第一行数据的结果如下……

Prediction using field predicted ...

Accuracy: 0.97
Precision: 1.0
Recall: 0.97
F1: 0.99

所以 GPT-3 在预测第一行数据的位置上表现不错。

综合起来

步骤 1 — 读取我们的数据

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

示例电子表格,具有不同的层级标题和单元格中的备注

步骤 2 — 取消合并的列并填充合并值

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

Pandas 数据框在通过‘pad_merged_cells’函数处理后,用于取消合并并填充合并值

步骤 3 — 计算表格参数以生成 GPT-3 提示

Nulls in first 10 rows: [20, 20, 20, 21, 10, 8, 19, 9, 21, 0]
Numeric first 10 rows: [0, 0, 0, 0, 0, 0, 0, 0, 0, 14]
Unique values in first 10 rows: [1, 1, 1, 0, 11, 13, 2, 4, 0, 21]
Year values in first 10 rows: [0, 0, 0, 0, 0, 0, 0, 0, 0, 1]
HXL row: None

First reduced nulls row: 9
First increased numeric row (excluding years): 9

步骤 4 — 调用 GPT-3 预测数据行的起始位置

GPT-3 prediction: 9

步骤 5 — 现在我们知道了数据行的开始位置,将上方的列标题连接成一行

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

解析后的表格,具有折叠的层级列,没有随机标签。现在可以导入到数据库中。

这是一个我们可以上传到关系数据库的好表格。有关完整代码,请参见下面的参考部分。

诚然,手动解析这个表格并指定一些与我们发现的表格参数相关的规则是很容易的,但上述过程的重点是它可以应用于人道主义数据交换数据集中成千上万的 Excel 表格的广泛表格布局。

结论与未来工作

尽管零次和一次提示具有很大的潜力,但在用 CSV 表格进行提示时,这种方法尚未对这个特定任务奏效。随着大型语言模型的进步,这种情况可能会改变——我很期待 GPT-4 的表现——但目前看来,微调是更好的选择,它可以预测关键的表格属性,用于重新格式化。当然,这种方法需要一些预处理,以确定提示的表格参数。值得注意的是,使用表格‘特征’时,它更像是分类任务而不是文本完成,可能会更适合这样框架。不过,无论如何,这种技术在使用人道主义数据交换 Excel 文件时表现良好。

我认为将这项工作扩展到处理 Excel 工作表上有多个表格的情况将非常有趣。这需要比我在这项研究中使用的更多的表格特征,比如单元格格式和列(而不是行)属性。

更多有趣内容敬请期待!

参考文献

[1] Haoyu Dong 等人,TableSense: 使用卷积神经网络进行电子表格表格检测 (2021)

[2] Brown 等人,语言模型是少样本学习者 (2020)。

[3] OpenAI Cookbook: 提高可靠性的技术

[4] Kojima 等人,大型语言模型是零样本推理者

这个分析的代码可以在这个笔记本中找到。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值