TowardsDataScience 2023 博客中文翻译(一百零二)

原文:TowardsDataScience

协议:CC BY-NC-SA 4.0

Delta Lake:保持快速和清洁

原文:towardsdatascience.com/delta-lake-keeping-it-fast-and-clean-3c9d4f9e2f5e

曾经想过如何提高 Delta 表的性能吗?亲身体验如何保持 Delta 表的快速和清洁。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Vitor Teixeira

·发布于数据科学前沿·阅读时间 11 分钟·2023 年 2 月 15 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

如何保持 Delta 表快速和清洁的简化流程图(作者提供的图片)

保持 Delta 表的快速和清洁对维护数据管道的效率非常重要。Delta 表可能会随着时间的推移变得非常庞大,导致查询性能下降和存储成本增加。然而,有几种操作和权衡可以积极影响表的速度。

在这篇博客文章中,我们将使用people10m 公共数据集,该数据集在 Databricks Community Edition 上可用,展示如何利用 Delta 操作保持表的快速和清洁,同时解释幕后发生的情况。

分析 delta 日志

我们将从检查数据集的内容开始。默认情况下,它在 Databricks 上可用,你可以在这里访问它。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

数据集的小样本

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

来自原始 Delta 表的文件

我们有 16 个 parquet 条目和一个*_delta_log文件夹,其中包含所有交易日志,这些日志堆积在一起形成我们的 delta 表。

如果我们检查日志的内容,可以看到一个 JSON 文件,描述了 Databricks 创建这个 Delta 表时写入的第一次交易。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

从分析中,我们可以看到这个交易包括几个操作:

提交信息

{
    "commitInfo": {
        "timestamp": 1602173340340,
        "userId": "360903564160648",
        "userName": "stephanie.bodoff@databricks.com",
        "operation": "WRITE",
        "operationParameters": {
            "mode": "ErrorIfExists",
            "partitionBy": "[]"
        },
        "notebook": {
            "notebookId": "1607762315395537"
        },
        "clusterId": "1008-160338-oil232",
        "isolationLevel": "WriteSerializable",
        "isBlindAppend": true,
        "operationMetrics": {
            "numFiles": "8",
            "numOutputBytes": "221245652",
            "numOutputRows": "10000000"
        }
    }
}

commitInfo 包含有关提交的所有信息:执行了什么操作、由谁执行、在哪里执行以及在什么时间。operationMetrics 字段显示写入了 8 个文件,总共 1000000 条记录。

Protocol

{
    "protocol": {
        "minReaderVersion": 1,
        "minWriterVersion": 2
    }
}

protocol 操作用于增加读取或写入给定表所需的 Delta 协议版本。这允许排除那些使用旧协议的读者/写者,因为旧协议可能缺少正确解释事务日志所需的功能。

Metadata

{
    "metaData": {
        "id": "ee2db204-0e38-4962-92b0-83e5570d7cd5",
        "format": {
            "provider": "parquet",
            "options": {}
        },
        "schemaString": "{\"type\":\"struct\",\"fields\":[{\"name\":\"id\",\"type\":\"integer\",\"nullable\":true,\"metadata\":{}},{\"name\":\"firstName\",\"type\":\"string\",\"nullable\":true,\"metadata\":{}},{\"name\":\"middleName\",\"type\":\"string\",\"nullable\":true,\"metadata\":{}},{\"name\":\"lastName\",\"type\":\"string\",\"nullable\":true,\"metadata\":{}},{\"name\":\"gender\",\"type\":\"string\",\"nullable\":true,\"metadata\":{}},{\"name\":\"birthDate\",\"type\":\"timestamp\",\"nullable\":true,\"metadata\":{}},{\"name\":\"ssn\",\"type\":\"string\",\"nullable\":true,\"metadata\":{}},{\"name\":\"salary\",\"type\":\"integer\",\"nullable\":true,\"metadata\":{}}]}",
        "partitionColumns": [],
        "configuration": {},
        "createdTime": 1602173313568
    }
}

metadata 操作包含所有表的元数据。它在表的第一个操作中是必需的,因为它包含了表的定义。对表元数据的后续修改将产生新的操作。

Add

{
    "add": {
        "path": "part-00000-373539c8-e620-43e4-82e0-5eba22bb3b77-c000.snappy.parquet",
        "partitionValues": {},
        "size": 27825521,
        "modificationTime": 1602173334000,
        "dataChange": true,
        "stats": "{\"numRecords\":1249744,\"minValues\":{\"id\":3766824,\"firstName\":\"Aaron\",\"middleName\":\"Aaron\",\"lastName\":\"A'Barrow\",\"gender\":\"F\",\"birthDate\":\"1951-12-31T05:00:00.000Z\",\"ssn\":\"666-10-1008\",\"salary\":-20858},\"maxValues\":{\"id\":5016567,\"firstName\":\"Zulma\",\"middleName\":\"Zulma\",\"lastName\":\"Zywicki\",\"gender\":\"M\",\"birthDate\":\"2000-01-30T05:00:00.000Z\",\"ssn\":\"999-98-9985\",\"salary\":180841},\"nullCount\":{\"id\":0,\"firstName\":0,\"middleName\":0,\"lastName\":0,\"gender\":0,\"birthDate\":0,\"ssn\":0,\"salary\":0}}"
    }
}
{
    "add": {
        "path": "part-00001-943ebb93-8446-4a6c-99f7-1ca12ec2511b-c000.snappy.parquet",
        "partitionValues": {},
        "size": 27781558,
        "modificationTime": 1602173334000,
        "dataChange": true,
        "stats": "{\"numRecords\":1249537,\"minValues\":{\"id\":1267751,\"firstName\":\"Abbey\",\"middleName\":\"Abbey\",\"lastName\":\"A'Barrow\",\"gender\":\"F\",\"birthDate\":\"1951-12-31T05:00:00.000Z\",\"ssn\":\"666-10-1005\",\"salary\":-20925},\"maxValues\":{\"id\":2517287,\"firstName\":\"Zulma\",\"middleName\":\"Zulma\",\"lastName\":\"Zywicki\",\"gender\":\"F\",\"birthDate\":\"2000-01-30T05:00:00.000Z\",\"ssn\":\"999-98-9981\",\"salary\":165757},\"nullCount\":{\"id\":0,\"firstName\":0,\"middleName\":0,\"lastName\":0,\"gender\":0,\"birthDate\":0,\"ssn\":0,\"salary\":0}}"
    }
}
...

add 操作,顾名思义,是通过添加单个 逻辑文件 来修改表中的数据。它包含相应文件的元数据以及一些可以用于优化的数据统计信息,我们将在文章的进一步部分讨论这些优化。

日志包含 8 个 add 条目,从 part-00000 到 part-00007,为简化起见已被截断。

如果你希望了解更多关于 Delta 协议的信息,请参阅:github.com/delta-io/delta/blob/master/PROTOCOL.md

Setup

现在我们已经分析了事务日志和数据集,我们将其复制到自己的目录中,以便我们可以修改表。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

保持清洁

Vacuum

第一个明显的答案是 VACUUM 命令。它的作用是删除不再影响我们的 Delta 表的文件,前提是配置了 delta.deletedFileRetentionDuration,默认为 7 天。

在分析数据集和 Delta 日志后,我们发现有 16 个文件,所有文件都比默认保留时间间隔要旧,但日志中仅引用了 8 个文件。这意味着理论上,如果我们运行命令,另外 8 个文件将被清理。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

让我们检查一下底层文件系统中的结果。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

运行 VACUUM 后的文件,使用默认设置

令人惊讶的是,文件并没有被清理。发生了什么事?

我发现令人惊讶的是,VACUUM 内部使用的时间戳不是 add 操作中的事务日志文件中引用的时间戳,而是文件的 modificationTime。这样做的原因是为了避免读取大量 JSON 文件以查找应选择删除的文件。也就是说,在复制/迁移 Delta 表时,请确保保持 modificationTime 不变。

鉴于我们刚刚复制了整个数据集,modificationTime 现在的时间,因此不会被选择删除,至少在 7 天内不会。如果我们尝试删除,将会收到以下警告:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

出于测试目的,我们将delta.retentionDurationCheck.enable=false,以便我们可以演示命令的实际效果,但这应该谨慎使用,因为如果其他活动的读取器或写入器依赖于被删除的数据,可能会导致表损坏。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

VACUUM 后的文件

看,这样一来,一切看起来都更整洁了。那么事务日志呢?现在有 4 个新的 JSON 文件,每个文件代表一个新的事务。

每次请求VACUUM时,事务日志中会生成两个新的提交,分别包含VACUUM STARTVACUUM END操作。

{
    "commitInfo": {
        "timestamp": 1676202617353,
        "userId": "8019820830300763",
        "userName": "vitor",
        "operation": "VACUUM START",
        "operationParameters": {
            "retentionCheckEnabled": true,
            "defaultRetentionMillis": 604800000
        },
        "notebook": {
            "notebookId": "1087108890280137"
        },
        "clusterId": "0102-173902-b3a5lq4t",
        "readVersion": 0,
        "isolationLevel": "SnapshotIsolation",
        "isBlindAppend": true,
        "operationMetrics": {
            "numFilesToDelete": "0"
        },
        "engineInfo": "Databricks-Runtime/11.3.x-scala2.12",
        "txnId": "6b875d5e-4c0e-4724-a87b-a0a6bbfd8419"
    }
}

第一个没有影响任何文件,因此numFilesToDelete为 0。

{
    "commitInfo": {
        "timestamp": 1676206833338,
        "userId": "8019820830300763",
        "userName": "vitor",
        "operation": "VACUUM START",
        "operationParameters": {
            "retentionCheckEnabled": false,
            "specifiedRetentionMillis": 0,
            "defaultRetentionMillis": 604800000
        },
        "notebook": {
            "notebookId": "1087108890280137"
        },
        "clusterId": "0102-173902-b3a5lq4t",
        "readVersion": 2,
        "isolationLevel": "SnapshotIsolation",
        "isBlindAppend": true,
        "operationMetrics": {
            "numFilesToDelete": "8"
        },
        "engineInfo": "Databricks-Runtime/11.3.x-scala2.12",
        "txnId": "42f93d56-8739-46d5-a8f9-c2c1daffe0ec"
    }
}

第二个标记了 8 个文件进行删除,因此numFilesToDelete为 8。

总之,VACUUM作业对于减少存储成本是必不可少的。然而,我们需要确保定期安排这些作业(它们不会影响任何正在运行的作业),因为它们默认不会被安排。此外,我们还需要确保调整保留值,以便我们希望进行时间旅行时考虑modificationTime,并在迁移 Delta 表时加以考虑。

优化

我们需要注意的下一个命令是OPTIMIZE。这个命令的作用是将小文件压缩成较大的文件,同时保持所有数据完整,并重新计算增量统计信息。它可以显著提高查询性能,特别是当数据是通过流式作业写入时,根据触发间隔,可能会生成很多小文件。

目标文件大小可以通过调整delta.targetFileSize来改变。请记住,设置此值并不能保证所有文件都达到指定大小。该操作将尽最大努力接近目标大小,但这在很大程度上取决于我们处理的数据量以及并行性。

在这个例子中,我们将其设置为 80MB,因为数据集远小于默认大小 1GB。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

让我们在运行命令后分析一下事务日志的提交情况:

{
    "commitInfo": {
        "timestamp": 1676215176645,
        "userId": "8019820830300763",
        "userName": "vitor",
        "operation": "OPTIMIZE",
        "operationParameters": {
            "predicate": "[]",
            "zOrderBy": "[]",
            "batchId": "0",
            "auto": false
        },
        "notebook": {
            "notebookId": "1087108890280137"
        },
        "clusterId": "0102-173902-b3a5lq4t",
        "readVersion": 2,
        "isolationLevel": "SnapshotIsolation",
        "isBlindAppend": false,
        "operationMetrics": {
            "numRemovedFiles": "8",
            "numRemovedBytes": "221245652",
            "p25FileSize": "59403028",
            "minFileSize": "59403028",
            "numAddedFiles": "3",
            "maxFileSize": "88873012",
            "p75FileSize": "88873012",
            "p50FileSize": "87441438",
            "numAddedBytes": "235717478"
        },
        "engineInfo": "Databricks-Runtime/11.3.x-scala2.12",
        "txnId": "55389d3e-4dd5-43a9-b5e1-de67cde8bb72"
    }
}

总共删除了 8 个文件,添加了 3 个。我们的新目标文件大小为 80MB,因此所有文件都被压缩成三个新文件。正如提交信息所示,日志中还包含了 8 个remove操作和 3 个add操作,为了简化起见,这些操作被省略了。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

你可能会想知道OPTIMIZE命令在这个特定数据集上是否真的做了有用的事情,所以让我们尝试运行一个简单的查询。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

使用OPTIMIZE后,我们改善了扫描时间,因为我们读取了更少的文件。然而,我们仍然在尝试查找薪资大于 80000 的记录时读取了整个数据集。我们将在文章的下一部分解决这个问题。

总结来说,应该定期安排OPTIMIZE任务,因为查询读取可以从减少读取文件数量中获益。Databricks 建议每天运行,但实际上这取决于更新的频率。请注意,OPTIMIZE可能需要一些时间,并会增加处理成本。

Z-Order 优化

Z-Ordering是一种用于将相关信息放置在同一组文件中的技术。

当文件被写入 Delta 表时,最小值、最大值和计数统计数据会自动添加到 add action 中的stats字段,如前所述。这些统计数据用于在查询表时进行数据跳过。数据跳过是一种优化,旨在优化包含WHERE子句的查询。默认情况下,数据集的前 32 列会收集统计数据。可以通过调整delta.dataSkippingNumIndexedCols到所需的数字来更改这一点。请注意,这可能会影响写入性能,特别是对于长字符串,建议将其移动到模式的末尾,并将属性设置为低于其索引的数字。

OPTIMIZE示例中,我们看到即使收集了这些统计数据,我们也不能真正利用它们,仍然会读取所有文件。这是因为我们没有任何明确的排序,薪资在所有文件之间基本是随机的。

通过在OPTIMIZE中添加ZORDER-BY列,我们可以轻松解决这个问题:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

让我们分析事务日志:

{
    "commitInfo": {
        "timestamp": 1676217320722,
        "userId": "8019820830300763",
        "userName": "vitor",
        "operation": "OPTIMIZE",
        "operationParameters": {
            "predicate": "[]",
            "zOrderBy": "[\"salary\"]",
            "batchId": "0",
            "auto": false
        },
        "notebook": {
            "notebookId": "1087108890280137"
        },
        "clusterId": "0102-173902-b3a5lq4t",
        "readVersion": 2,
        "isolationLevel": "SnapshotIsolation",
        "isBlindAppend": false,
        "operationMetrics": {
            "numRemovedFiles": "8",
            "numRemovedBytes": "221245652",
            "p25FileSize": "113573613",
            "minFileSize": "113573613",
            "numAddedFiles": "2",
            "maxFileSize": "123467314",
            "p75FileSize": "123467314",
            "p50FileSize": "123467314",
            "numAddedBytes": "237040927"
        },
        "engineInfo": "Databricks-Runtime/11.3.x-scala2.12",
        "txnId": "0e9b6467-9385-42fa-bc1a-df5486fc997f"
    }
}

两个OPTIMIZE命令之间存在一些差异。我们首先注意到的是,如预期的那样,现在在operationParameters中有一个zOrderBy列。此外,尽管我们指定了相同的目标文件大小,但由于列的统计数据,OPTIMIZE结果为 2 个文件而不是 3 个文件。

以下是第一个文件的add操作。统计数据显示该文件包含所有薪资在-26884 和 73676 之间的记录。因此,我们的查询应该完全跳过这个文件,因为薪资值超出了我们WHERE子句的范围。

{
    "add": {
        "path": "part-00000-edb01f4d-18f1-4c82-ac18-66444343df9b-c000.snappy.parquet",
        "partitionValues": {},
        "size": 123467314,
        "modificationTime": 1676217320000,
        "dataChange": false,
        "stats": "{\"numRecords\":5206176,\"minValues\":{\"id\":1,\"firstName\":\"Aaron\",\"middleName\":\"Aaron\",\"lastName\":\"A'Barrow\",\"gender\":\"F\",\"birthDate\":\"1951-12-31T05:00:00.000Z\",\"ssn\":\"666-10-1010\",\"salary\":-26884},\"maxValues\":{\"id\":9999999,\"firstName\":\"Zulma\",\"middleName\":\"Zulma\",\"lastName\":\"Zywicki\",\"gender\":\"M\",\"birthDate\":\"2000-01-30T05:00:00.000Z\",\"ssn\":\"999-98-9989\",\"salary\":73676},\"nullCount\":{\"id\":0,\"firstName\":0,\"middleName\":0,\"lastName\":0,\"gender\":0,\"birthDate\":0,\"ssn\":0,\"salary\":0}}",
        "tags": {
            "INSERTION_TIME": "1602173334000000",
            "ZCUBE_ZORDER_CURVE": "hilbert",
            "ZCUBE_ZORDER_BY": "[\"salary\"]",
            "ZCUBE_ID": "493cfedf-fdaf-4d34-a911-b4663adefec7",
            "OPTIMIZE_TARGET_SIZE": "83886080"
        }
    }
}

通过在 Z-Ordering 文件后再次运行查询,我们可以看到只读取了一个文件,另一个文件被修剪掉了。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

尽管 Z-Ordering 在数据跳过方面看起来是一个改变游戏规则的技术,但必须正确使用才能提高效率。下面我们将列出使用 Z-Ordering 时必须考虑的一些关键点:

  1. Z-Ordering 仅适用于高基数的列,如果列的基数较低,我们无法从数据跳过中受益。

  2. 我们可以在 Z-Order 上指定多个列,但每增加一列,其数据跳过的效果会降低。

  3. 确保仅在有统计数据的列上进行 Z-Order 操作。记住列的索引,只有前 32 列会被分析。

分区

另一种可以使用的技术是物理分区。虽然 Z-ordering 将具有相似值的数据分组到同一文件中,但分区将数据文件分组到同一文件夹下。

与 Z-Ordering 相反,分区在低基数列上效果最佳。如果我们选择其他列,可能会导致无限的分区,最终生成大量小文件,从而引发性能问题。

我们将使用性别作为分区列,因为它是数据集中唯一具有低基数的列。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

通过这样做,我们最终得到了两个文件夹,每个文件夹对应一个性别。这种类型的分隔对于具有低基数并且在大表中经常用于WHERE子句的列非常有用。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

假设我们现在希望能够根据性别和薪资提取见解。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

OPTIMIZE 可以与分区列配对使用,如果我们只想优化数据的一个子集。下面我们将分析 Z-Ordered 表中有无分区的数据跳过,以展示如何同时利用这两种方法。我们已经减少了目标文件大小,以展示我们的数据在不同文件下按性别拆分后的差异。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

如上所示,如果没有分区,我们必须读取两个文件才能获得结果。通过根据薪资进行 Z-Order,我们能够跳过 3 个文件,但必须完全读取这些文件以提取请求的性别。使用分区后,我们能够跳过整个分区,基本上“免费”过滤性别,并且由于 Z-Ordering 跳过了 3 个文件。

正如我们所见,同时使用这两种方法有其好处,但需要经过仔细考虑,因为它可能只对非常大的表格产生显著差异。

结论

总之,保持 Delta 表的清洁对于维持数据管道的性能和效率至关重要。清理和优化 Delta 表有助于回收存储空间并提高查询执行时间。深入了解每个操作的细节对于正确的微调非常重要,否则可能会导致不必要的存储和处理成本。

参考文献

docs.databricks.com/delta/vacuum.html

docs.databricks.com/delta/optimize.html

docs.databricks.com/delta/data-skipping.html

docs.databricks.com/tables/partitions.html

docs.databricks.com/delta/best-practices.html

www.databricks.com/blog/2018/07/31/processing-petabytes-of-data-in-seconds-with-databricks-delta.html

Delta Lake — 分区、Z-Order 和 Liquid Clustering

原文:towardsdatascience.com/delta-lake-partitioning-z-order-and-liquid-clustering-944030ff1828

Delta 中的不同分区/聚类方法是如何实现的?它们在实际中是如何工作的?

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Vitor Teixeira

·发表于 Towards Data Science ·10 分钟阅读·2023 年 11 月 8 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

照片由 frame harirak 提供,发布在 Unsplash

使大数据变得困难的问题之一一直在于它的名字,它就是“大”。分区,特别是当做得好时,一直是一种通过将需要读取的数据减少到一个子集来提高对大量数据的查询执行时间的方法。然而,分区数据是复杂的,需要仔细考虑和一些前期规划,因为今天适用的要求可能不适合未来的要求。例如,在 Hive 风格的分区中,列可能需要更改或增加其基数,从而使数据过度分区(小文件问题),这会导致数据的完全重组,这并不理想。

Z-Order 聚类是另一种用于数据跳过的技术,同样避免了全面的数据扫描。然而,这种技术有一些局限性。其中之一是新摄取的数据默认情况下未排序,用户需要重新聚类,这意味着已经聚类的数据将被重新聚类和重写,增加了操作所花费的时间。Z-Order 用户还需要每次运行命令时定义聚类列,因为它们不是表属性的一部分。

这就是 Liquid Clustering 进入游戏的地方。前提是它可以无缝地融入当前的数据布局,并且能够适应未来的需求,而无需重写任何已经聚类的数据。

在这篇文章中,我们将解释 Delta 中不同数据修剪策略的细节及其应用方式。

分区修剪 — Hive 风格的分区

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

Hive 样式分区 — 作者提供的图片

Hive 样式分区是一种将表组织成小块的方式。这些数据块被组织成几个子文件夹,包含分区值的数据。

dbfs://people10m/gender=M/data_0.json
dbfs://people10m/gender=M/data_1.json
dbfs://people10m/gender=F/data_0.json
dbfs://people10m/gender=F/data_1.json

这种方法不是 Delta 的原生方法,即不属于 Delta 协议的一部分。然而,由于 Delta 建立在 Apache Spark 之上,旧的 Hive 样式分区在某些场景下也可以很好地工作。

几种机制处理这种类型的分区,使得它对最终用户完全不可见。在 Apache Spark 中,当用户读取数据集时,gender 列会自动添加到模式中,并带有相应的值,可以像常规列一样进行查询。这种技术称为 分区发现,由 DataSource’s resolveRelation 处理,它从给定的基本路径推断分区列。另一方面,当用户使用 partitionBy 保存 DataFrame 时,会执行 InsertIntoHadoopFsRelationCommand 作为执行计划的一部分,这会调用 FileFormatWriter,为每个底层 RDD 的分区生成一个写入作业(从最终模式中排除分区列并为其创建桶)。

在上述示例中,由于查询仅选择性别为 F 的数据,它将只需要实际扫描该文件夹,从而有效跳过数据,因为它只读取数据集中的一半文件。这称为 分区剪枝

这种方法有一些缺点,特别是当选择具有非常高基数的分区列或多个分区级别时,这会导致许多小文件,从而导致更差的读取性能。此外,一旦定义了这种分区策略,就不能在不重写所有数据的情况下更改,因为它是在物理层面上定义的。

I/O 剪枝 — Z-Order

另一种有效跳过数据的技术是对文件级统计数据进行过滤。在这种技术中,每个文件都有可用的统计数据,可以作为是否值得读取文件的指标。默认情况下,Delta 存储前 32 列的最小值、最大值和空值计数的统计数据。

people10m公共数据集中的单列id为例。如果我们使用repartitionByRange在该列上将数据排序为 5 个不同的文件,最小/最大统计分布可能类似于以下内容:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

按列 ID 范围分区后的文件 — 图片作者

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

选择公司的前 20,000 名员工 — 图片作者

运行上述查询将产生一个良好的计划,因为我们的查询仅筛选该列且所有文件包含不重叠的 ID 集合。这样,数据库引擎更容易选择正确的文件进行扫描,而不会产生假阳性。

如果我们想在查询中添加另一列怎么办?

假设我们还想按员工的薪资进行筛选。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

选择薪资大于 40,000 的公司前 20,000 名员工 — 图片作者

在我们按两列范围分区文件之后,我们最终得到如下结果:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

按列 ID 和薪资范围分区后的文件 — 图片作者

薪资与 ID 没有直接关系,使用之前的线性方法将文件组织成有效的数据跳过方式将导致数据仅按第一列排序。通过简单地筛选薪资大于 40,000,我们最终会读取所有五个文件,而不仅仅是一个。

我们如何解决这个问题?是否有办法在保持位置性的同时将多个统计信息分组到单一维度,从而使我们的范围分区正常工作?

如果你猜测了 Z-Ordering,你猜对了。如果你猜测了填充空间曲线,你就更对了!

什么是空间填充曲线,我需要关注它吗?空间填充曲线是一种遍历嵌入空间中所有点的曲线。一些曲线能够将这些高维点映射到一个维度,同时保持在原始空间中的邻近性。听起来复杂?其实并不复杂。下面我们将详细介绍这些曲线的工作原理。

Z-Order 曲线

Z-Order 曲线是 Delta 中空间填充曲线聚类的第一次实现,因此得名。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

级别 1 Z-Order 曲线 — 图片作者

Z-Order 值,即形成 Z 形状的曲线的点,是使用称为位交错的技术计算得出的。位交错是一种使用位表示 N 维坐标的方法。例如,如果我们使用 4 位表示(0000 到 1111),我们能够通过逐位分配给每个轴来编码 4x4 网格坐标。接下来,我们将通过一个更直观的示例来展示这种技术。

在 Delta 中,Z-Ordering 用于以使数据跳过操作有效的方式对数据进行分组。所有 Z-Order 列都被“标记”为使用RangePartitionId表达式进行范围分区。该表达式只是一个占位符,将由一个优化器处理,该优化器将对 RDD 进行采样,以找到列的范围边界。(如果你曾经尝试对一个相当大的数据集进行多次 Z-Order,你可能会注意到其文件统计数据是不确定的。这是因为 Delta 使用水库抽样来避免在计算范围 ID 时读取整个数据集)。然后,所有计算出的范围被转换为字节并交错,这样就得出了行的 Z-Order 值。

以下我们将以简化的方式说明 Z-Order 在 Delta 中的工作原理,以一个 6 条记录和 3 个分区的组为例。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

针对 6 条记录到 3 个不同范围 ID 的 Z-Order 优化—作者提供的图像

希尔伯特曲线

曲线在保持局部性的能力越强,我们由于误报需要读取的文件就越少。这就是为什么希尔伯特曲线在保持局部性至关重要的场景中更常使用的原因。

在撰写时,希尔伯特曲线尚未在 Delta 的开源版本中实现。然而,它们是 Databricks Z-Order 实现中使用的默认曲线,因为它们相比 Z-Order 曲线在处理高维数据时提供了更好的数据局部性。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

希尔伯特曲线—作者提供的图像

希尔伯特曲线可以以四种不同的方式出现,每种方式都是在上述基础上旋转 90º得到的。

但为什么希尔伯特曲线在保持局部性方面比默认的 Z-Order 曲线更好?

希尔伯特曲线的相邻点之间的距离始终为 1。与 Z-Order 不同,这意味着这些跳跃可能会生成具有较大最小/最大差异的 Z-Order 文件,从而使其无用。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

Z-Order 曲线上的相邻点之间的距离—作者提供的图像

该算法有几个实现,但在这篇文章中,我将介绍 John Skilling 在“编程希尔伯特曲线”中的一个整洁的迭代方法。这个算法可能会让人困惑,因为它包含了一些位操作。如果你不需要了解细节,可以直接跳到下一节。

请注意,由于 Databricks 代码是专有的,以下示例可能不代表当前实现。

J. Skilling 编码方法将位交错并使用格雷码进行编码。这样,每次只改变一个位,因此遍历网格时只会在垂直或水平方向进行。然后,它遍历编码后的位,并应用一系列位交换和反转,最终返回坐标的位表示,可以通过解交错来恢复。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

将笛卡尔点转换为希尔伯特索引的 Skilling 变换 — 来源于 编程希尔伯特曲线

类似于 Z-Order,我们需要一种将任意维度的坐标组编码为单一点的方法。为了实现这一点,我们将运行之前的算法,但要反向运行,以便可以检索希尔伯特曲线中的点。然后有两个循环,一个循环将遍历编码位,从最重要到最不重要,直到 p-2,其中 p 是每个轴上的位数,另一个内循环将从最不重要的位迭代到 n-1,其中 n 是维度数。根据当前的位,我们将交换位或反转它们。最后,我们需要对位进行格雷解码,就能得到我们的点。

接下来,我们将介绍如何对坐标 (2, 0) 进行编码,它表示希尔伯特曲线中的点号 14。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

将笛卡尔坐标转换为希尔伯特曲线点的算法 — 作者提供的图像

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

4x4 希尔伯特曲线 — 作者提供的图像

从这里开始,我们假设过程与 Z-Order 实现相同,其中数据被划分范围,并且相近的记录被写入同一文件。

液体聚类

那么,液体聚类究竟是什么?它不过是希尔伯特曲线加上一个名为 ZCube 的新特性,使得增量聚类成为可能!

OPTIMIZE ZORDER BY 命令要求完全重写数据,这对大型表非常昂贵。此外,当 OPTIMIZE ZORDER 命令中出现问题时,一切需要从头开始,这有时会非常麻烦。

什么是 ZCubes?

ZCubes 是由相同 OPTIMIZE 任务生成的一组文件。这样,一个大型表的 OPTIMIZE 任务可以分成几个不同的任务,这些任务会生成新的 ZCube,并在增量日志中生成新的条目,以实现增量聚类。每个新优化的文件将包含 AddFile 元数据中的 ZCUBE_ID 属性,这将使其可能区分优化和未优化的文件(即没有关联 ZCube 的文件)。

有两个新的可配置 ZCube 属性:

  • MIN_ZCUBE_SIZE 设置 ZCUBE 的最小尺寸。低于此尺寸的 ZCUBE 将被视为 OPTIMIZE 任务的一部分,新文件可以被合并,直到尺寸达到此阈值(默认为 100GB)。这些立方体被称为 部分 ZCubes

  • TARGET_CUBE_SIZE 设置完成的立方体的目标尺寸,包含超过目标尺寸的文件。这些立方体被称为 稳定 ZCubes

如果 Delete 命令使大量文件无效,从而使其小于 MIN_ZCUBE_SIZE,稳定 ZCubes 可能会重新变为部分 ZCubes。

它如何无缝适应新的分区列?

当用户更改聚类列时,只有包含相同聚类列的 ZCubes 会被考虑进行优化。其他立方体保持不变,新立方体会被创建。

这在实践中是如何工作的?

当发出 OPTIMIZE table 命令时,Delta 会选择有效的文件用于 ZCube 生成,这些文件是部分 ZCube 的一部分(可以进一步优化),以及新文件。然后,进行规划步骤,将文件打包到多个 ZCubes 下,这些 ZCubes 是相互独立运行的 OPTIMIZE 任务。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

启用液态聚类的 OPTIMIZE 流程 — 作者图示

如何启用/禁用液态聚类?

--New tables
CREATE TABLE <table>
USING delta
CLUSTER BY (<col1>, <col2>,)

--Existing tables
ALTER TABLE <table>
CLUSTER BY (<col1>, <col2>,)

--Remove liquid clustering
ALTER TABLE <table>
CLUSTER BY NONE

由于聚类列是在表级别定义的,OPTIMIZE 命令不需要定义任何参数。

注意:这仍在 提议 中,可能会有所变化。

结论

在这篇博客文章中,我们详细讨论了 Delta Lake 中可用的不同分区和聚类选项。我们讨论了 Hive 风格分区、Z-Order 及其当前问题,展示了液态聚类如何解决这些问题。

液态聚类非常有前途,因为它使用起来更简单,具有增量和更好的聚类性能,并且支持在没有任何开销的情况下更改分区列。如果你对性能感兴趣,这里有几个性能比较,你也可以尝试使用 Databricks Runtime 13.3+。Databricks 推荐将所有当前的分区列和 ZOrder 列更改为聚类列,以获得更好的性能。

如果你在使用开源 Delta,尽管液态聚类功能不可用,请确保查看我之前的帖子,了解如何保持你的表格快速而干净:

## Delta Lake— Keeping it fast and clean

是否曾想过如何提升 Delta 表的性能?手把手教你如何保持 Delta 表快速而干净。

[towardsdatascience.com

参考文献

docs.databricks.com/en/delta/clustering.html

docs.google.com/document/d/e/2PACX-1vREkVPDxqlKrwnaQ7Et1EnaiCF-VhFXCwit7bGSomWKtGEfkxbuGhX4GP3cJ20LgllYfjzsjr2lyY5y/pub#kix.301alpimymwh

pubs.aip.org/aip/acp/article-abstract/707/1/381/719611/Programming-the-Hilbert-curve

en.wikipedia.org/wiki/Z-order_curve

en.wikipedia.org/wiki/Hilbert_curve

民主化 AI:MosaicML 对开源 LLM 运动的影响

原文:towardsdatascience.com/democratizing-ai-mosaicmls-impact-on-the-open-source-llm-movement-7972ff12dd92

高质量基础模型如何为整个行业开启新可能性……

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Cameron R. Wolfe, Ph.D.

·发布于 Towards Data Science ·13 min 阅读·2023 年 10 月 15 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

(照片由Raimond Klavins 提供,来源于Unsplash

最近,我们回顾了许多关于创建开源大型语言模型(LLM)的当前研究。在所有这些工作中,模型是通过一个包含几个简单组件的共同框架创建的;见下文。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

创建和优化大型语言模型(来自[12, 13])的多步骤过程

尽管这个框架有几个步骤,但第一步可以说是最重要的。通过广泛的高质量预训练创建一个更强大的基础模型,可以在通过监督微调(SFT)和从人类反馈中进行强化学习(RLHF)时实现更好的结果。然后,下游应用由于使用了改进的模型而表现得更好。预训练(基础)模型是任何 LLM 应用程序的共同起点。

直到最近,开源基础模型要么与其专有对手相比表现不佳,要么只能用于研究。然而,这种情况随着 MosaicML 发布的 MPT-7B 和 MPT-30B [1, 2]的出现发生了变化。这些开源基础模型达到了令人印象深刻的性能水平,商业使用免费,并且配备了用于训练、微调和评估 LLM 的完整高效软件套件。这些开源工具使得可以以显著降低的成本探索多种专业应用场景,从而成为 AI 从业者的强大资源。

更快的 LLM 和更长的上下文长度

MPT-7B/30B 模型基于典型的 仅解码器变换器 架构。然而,进行了一些关键修改,包括:

在本节中,我们将深入了解这些组件,每个组件的工作原理,以及它们对 LLM 的影响。要全面理解本节的细节,可能需要回顾以下概念:

  • 自注意力 [link]

  • 因果自注意力(由仅解码器 LLM 使用) [link]

ALiBi 实现了上下文长度的外推

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

在 LLM 中嵌入一个令牌序列(由作者创建)

在一个普通的变换器架构中,我们通过首先 令牌化 原始文本并查找每个令牌的嵌入(令牌化器词汇表中的每个令牌都有一个唯一的嵌入)来创建一个输入令牌序列。然后,我们将位置嵌入添加到每个令牌嵌入中,从而将位置信息注入到序列中每个令牌的嵌入中;见上文。这是必要的,因为自注意力操作对序列中每个令牌的位置是无感知的。尽管位置嵌入工作良好,但有一个大问题:它们难以推广到比训练期间见过的更长的序列

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

(来自 [6])

解决方案。 带有线性偏差的注意力**(**ALiBi)[6]通过完全去除位置嵌入来解决这个问题。相反,位置信息通过在自注意力操作中对键-查询注意力分数添加一个加性惩罚注入到变换器中;见上文。我们应当回顾,自注意力计算序列中每对令牌之间的注意力分数。ALiBi 通过为这个分数添加一个与令牌对之间的距离成比例的静态、非学习的偏差(或惩罚)来操作;见下文。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

计算特定令牌对的键-查询注意力分数(由作者创建)

这种方法之所以有影响,是因为它依赖于令牌之间的成对距离,而不是序列中令牌的绝对位置。这一量度不那么依赖于基础序列的长度,并允许 ALiBi 对比训练期间看到的序列更长的序列进行更好的泛化;见下文。正如我们将看到的,使用 ALiBi 的 MPT 模型可以训练以支持比大多数开源替代方案更大的上下文长度甚至可以推断到长度为 84K 标记的序列

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

(来自 [6])

更快的推理

由于使用低精度层归一化和 FlashAttention [7],MPT 模型具有非常快的训练和推理速度(即,比使用标准HuggingFace 推理管道的同等规模的 LLaMA 模型快1.5-2X)。更进一步,这些模型的权重可以迁移到像FasterTransformerONNX这样的优化模块,以实现更快的推理。

低精度层归一化。简单来说,低精度层归一化以 16 位精度执行LayerNorm模块的操作。尽管这种方法在某些情况下可能导致损失峰值,但它改善了硬件利用率,从而加速了训练和推理。使用低精度层归一化对模型的最终性能也几乎没有影响。

Flash attention。在其经典形式中,自注意力是一个O(N²)操作,其中N是输入序列的长度。为了提高该操作的效率,已经提出了许多近似注意力变体,例如:

大多数这些技术的目标是推导出一种“线性”注意力变体——一种具有O(N)复杂度的类似/近似操作。尽管这些变体在理论上减少了FLOPs许多在实际场景中并没有实现任何墙钟速度的提升!Flash attention 通过以 IO 感知的方式重新构建注意力操作来解决这个问题;见下文。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

(来自 [7])

FlashAttention 的硬件实现细节超出了本文的范围。然而,结果高效的注意力实现带来了各种积极的好处。例如,FlashAttention 可以:

  • 将 BERT-large [10]的训练时间提高 15%

  • 将 GPT-2 的训练速度提高3X [11]

  • 为 LLMs 启用更长的上下文长度(由于更好的内存效率)

关于 FlashAttention 的更多细节,请查看 这里

MPT-7B:一个商业可用的 LLaMA-7B

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

训练 MPT-7B 模型及其各种衍生模型的总计算成本(来自 [1])

在 [1] 中提出的 MPT-7B 是一个开源、商业可用的语言基础模型,性能广泛匹配类似规模的开源基础模型,如 LLaMA-7B [3](它是不可商业使用的!)。根据 Chinchilla [4] 的经验教训,MPT-7B 在一个大的语料库上进行预训练——总计一万亿个标记——这些文本是多样的、公开可用的。用于训练、微调和评估 MPT-7B 的代码完全开源,使得这个模型成为实践者们调整自己专门化大语言模型以解决各种不同下游应用的一个很好的资源或起点

创建基础模型

由于其修改后的架构,MPT-7B 具有几个理想的属性,例如能够泛化到更长的上下文长度和更快的推理速度。此外,我们在 [1] 中看到,这种修改后的架构消除了 MPT-7B 预训练过程中的损失峰值,使得模型可以在没有任何人工干预的情况下进行预训练(假设任何硬件故障都在大语言模型的训练代码中自动处理)!

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

MPT-7B 在其预训练过程中仅经历硬件故障,这些故障可以自动解决(来自 [1])

训练过程。 尽管大多数大语言模型使用 AdamW 优化器 进行训练,MPT 采用了 Lion 优化器 [8],这提高了训练过程的稳定性。整个训练框架基于 PyTorch 的 完全分片数据并行 (FSDP) 包,不使用管道或张量并行。简单来说,MPT-7B 的训练框架是 完全开源的,使用了流行/常见的组件,但进行了几个有用的修改,以提高训练的稳定性。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

(来自 [1])

数据。 用于训练 MPT-7B 的文本语料库是由公开数据集(主要是英语数据)定制混合而成;见上文。在 [1] 中,我们看到用于训练 MPT-7B 的数据量非常大 — 总计 1T 个标记。作为对比,开源模型如 PythiaStableLM 分别在 300B 和 800B 个标记上进行预训练。有趣的是,我们看到 [1] 的作者采用了一种非常特定的分词器 — GPT-NeoX-20B BPE 分词器 — 进行模型训练。这个分词器是受欢迎的,因为它在一个大规模、多样化的数据集上进行训练,并且比其他流行的分词器更一致地处理空格。

“这个分词器具有许多令人满意的特性,其中大多数与代码分词相关:在包括代码的数据混合上训练,应用一致的空格分隔(不像 GPT2 分词器根据前缀空格的存在不一致地进行分词),并且包含了对重复空格字符的处理。” — 来源于 [1]

作为从业者,我们应该始终关注模型所使用的分词器。这个选择 — 尽管通常被忽视或忽略 — 会对我们的结果产生极大影响。例如,基于代码的语言模型需要一个以特定方式处理空格的分词器,而多语言模型则有各种独特的分词考虑因素。

它的表现如何?

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

(来源于 [1])

MPT-7B 在标准基准测试中与各种开源模型进行比较(例如,LLaMA,StableLMPythiaGPT-NeoXOPTGPT-J)。如上所示,LLaMA-7B 相较于开源替代方案取得了显著的改进,而 MPT-7B 的表现与 LLaMA 相匹配或超过其表现。近期的开源大型语言模型比其前身要好得多!LLaMA-7B 和 MPT-7B 相比其他开源模型都是极其高效的基础模型。然而,MPT-7B 可以用于商业用途,而 LLaMA 仅能用于研究。

MPT-7B 的衍生模型

除了发布 MPT-7B 基础模型外,[1] 的作者还利用 MPT 的开源训练代码来微调多个不同的基础模型衍生版本(见下文)。与从头开始预训练一个大型语言模型相比,微调的成本非常低(即,时间和成本减少 10–100 倍,甚至更多)。因此,开发 MPT-7B 的大部分时间和精力都投入到了创建基础模型上,该模型作为微调下述模型的起点。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

(来自 [1])

MPT-StoryWriter-65K(商业版) 是 MPT-7B 的一个版本,经过了在非常长上下文长度的数据上进行微调。特别是,文献中的作者 [1] 利用了包含虚构书籍摘录的 books3 dataset,以创建一个用于微调的数据集(即仅使用下一个令牌预测目标),上下文长度为 65K 令牌。由于使用了 ALiBi [6] 和 FlashAttention [7],MPT-StoryWriter-65K 可以有效地在如此大的输入上进行训练,能够处理《了不起的盖茨比》的全部内容(68K 令牌)以编写后记(见上文),甚至可以推广到处理长度达 84K 令牌的序列。

“我们期望 LLMs 将输入视为需要遵循的指令。指令微调是训练 LLMs 以这种方式执行指令跟随的过程。通过减少对巧妙提示工程的依赖,指令微调使 LLMs 更加易于访问、直观和立即可用。” ——来自 [1]

MPT-7B-Instruct(商业版)MPT-7B-Chat(非商业版) 是 MPT-7B 的指令调整版本。指令版是在 Dolly-15KHelpful and Harmless dataset 数据上进行微调的,而聊天模型则使用了来自 ShareGPTHC3AlpacaEvol-Instruct 等来源的数据进行训练。如上文所述,指令调整是指在预训练语言模型的基础上,对其风格或行为进行修改,使其更加直观和易于访问,通常着重于指令跟随或问题解决。

MPT-30B:一个开源的 GPT-3 替代品

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

MPT-30B 在所有性能类别上都优于 MPT-7B(来自 [2])

在其提出后不久,MPT-7B 模型在 AI 研究界获得了显著认可 —— 它甚至在 HuggingFace 上累计下载量超过了 300 万次!MPT-7B 的成功并不令人意外,因为它为极受欢迎的 LLaMA-7B 模型提供了一个商业上可用的替代品。借此势头,MosaicML 的研究人员推出了一个稍大的模型,称为 MPT-30B [2],该模型被发现能与 GPT-3 [9] 的表现相匹敌或超过。因此,MPT-30B 的提出延续了为任何人提供强大基础 LLM 的商业可用版本的趋势。

深入了解 MPT-30B

MPT-30B 共享与 MPT-7B 相同的修改过的解码器架构,使用了 FlashAttention 和低精度层归一化,以提高效率。总体而言,这些模型非常相似,除了 MPT-30B 更大一些。值得注意的是,MPT-30B 的大小选择得非常具体。这个大小的模型可以在单个 GPU 上使用 8 位或 16 位精度进行部署,而像 Falcon-40B 这样的替代品稍微大了一些,无法以这种方式部署。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

(来自 [2])

有什么不同? MPT-30B 与 MPT-7B 主要有两方面的不同:

  • 预训练数据混合

  • 上下文长度

MPT-30B 的预训练数据集类似于 MPT-7B,但数据的混合略有不同;见上文。此外,MPT-30B 部分使用 8K 上下文长度进行训练,而大多数其他开源模型(例如,LLaMA、Falcon 和 MPT-7B)使用较短的 2K 令牌上下文长度进行训练。更具体地说,我们看到 MPT-30B 使用一种训练课程,模型首先使用 2K 上下文长度进行训练,然后在训练后期切换到 8K 上下文长度。在第二阶段,数据集中代码的比例增加了 2.5X,使得最终模型在编程能力上比其他开源 LLM 更强。

模型变体。 除了 MPT-30B 基础模型外,[2] 中的作者还发布了 MPT-30B 的聊天和指令变体。这些模型遵循类似于 MPT-7B-Instruct 和 MPT-7B-Chat 的训练策略。然而,这些模型用于指令调优的数据显著增加。有趣的是,发现 MPT-30B-Chat 在编程技能上表现出色!

它表现得好吗?

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

(来自 [2])

除了在各种类别上优于 MPT-7B 外,MPT-30B 的表现与顶级开源替代品如 LLaMA-30B 和 Falcon-40B 相当;见上文。总体而言,我们发现 MPT-30B 在解决基于文本的任务时落后于 Falcon 和 LLaMA,但在编程相关问题上通常优于这些模型(可能是因为预训练数据集中代码的比例较高!)。值得注意的是,我们发现 MPT-30B 在各种上下文学习任务上优于 GPT-3;见下文。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

(来自 [2])

结合这些结果来看,像 MPT-30B 这样的模型可能为开源 LLM 应用奠定了基础,能够与专有系统的质量相媲美。我们所需要的只是足够的精细调整和微调!

最终备注

“你可以训练、微调和部署你自己的私人 MPT 模型,可以从我们的检查点之一开始,或从头开始训练” — 来自 [2]

MosaicML 提供的基础模型是开源 LLM 社区的一大进步,因为它们提供了与 LLaMA 和 GPT-3 等流行基础模型相当的商用 LLM。然而,这一开源产品不仅仅限于 MPT 模型本身——它还包括一个 用于训练 LLM 的开源代码库,各种 在线演示 等。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

(来自 [2])

MPT-7B 和 30B 模型配备了一个完整的开源工具生态系统,可用于创建专业化/个性化的 LLMs。鉴于创建基础模型是任何基于 LLM 系统中最昂贵的部分(见上文),这些工具显著降低了使用 LLM 的门槛,并提供了一个解决各种下游应用的起点。记住,当我们有一个特定的任务需要解决时,微调是极其有效的(即,仅通过提示一个更通用的 LLM 难以超越)!

与我联系!

非常感谢您阅读这篇文章。我是 Cameron R. WolfeRebuy 的 AI 总监。我研究深度学习的经验和理论基础。如果您喜欢这个概述,请订阅我的 Deep (Learning) Focus 新闻通讯,在其中我通过从基础上概述相关主题帮助读者理解 AI 研究。您还可以在 XLinkedIn 上关注我,或者查看我在 medium 上的 其他文章

参考文献

[1] “介绍 MPT-7B: 开源商用 LLM 的新标准。” MosaicML,2023 年 5 月 5 日,www.mosaicml.com/blog/mpt-7b.

[2] “MPT-30B: 提升开源基础模型的标准。” MosaicML,2023 年 6 月 22 日,www.mosaicml.com/blog/mpt-30b.

[3] Touvron, Hugo, 等。“Llama: 开放且高效的基础语言模型。” arXiv 预印本 arXiv:2302.13971 (2023)。

[4] Hoffmann, Jordan, 等。“训练计算最优的大型语言模型。” arXiv 预印本 arXiv:2203.15556 (2022)。

[5] Zhang, Susan, 等。“OPT: 开放的预训练变换器语言模型。” arXiv 预印本 arXiv:2205.01068 (2022)。

[6] Press, Ofir, Noah A. Smith, 和 Mike Lewis。“训练短期,测试长期:具有线性偏置的注意力机制实现输入长度外推。” arXiv 预印本 arXiv:2108.12409 (2021)。

[7] Dao, Tri, 等。“Flashattention: 快速且内存高效的准确注意力机制,具有 IO 感知能力。” 神经信息处理系统进展 35 (2022): 16344–16359。

[8] 陈向宁等人。“优化算法的符号发现。” arXiv 预印本 arXiv:2302.06675 (2023)。

[9] 布朗汤姆等人。“语言模型是少样本学习者。” 神经信息处理系统进展 33 (2020): 1877–1901。

[10] 德夫林雅各布等人。“Bert: 深度双向变换器的预训练以实现语言理解。” arXiv 预印本 arXiv:1810.04805 (2018)。

[11] 拉德福德亚历克等人。“语言模型是无监督的多任务学习者。”

[12] 欧阳龙等人。“通过人类反馈训练语言模型以遵循指令。” 神经信息处理系统进展 35 (2022): 27730–27744。

[13] 格莱斯艾米莉亚等人。“通过有针对性的人工评判改进对话代理的对齐。” arXiv 预印本 arXiv:2209.14375 (2022)。

使用 AWS SageMaker AutoML 实现机器学习的民主化

原文:towardsdatascience.com/democratizing-machine-learning-with-aws-sagemaker-automl-150299c70396

概述

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Patrick Brus

·发表于 Towards Data Science ·16 分钟阅读·2023 年 4 月 25 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由 Joshua Sortino 提供,来源于 Unsplash

介绍

AI 目前仍然是最热门的话题之一,尤其是随着 ChatGPT 的崛起。许多公司现在正尝试利用 AI 从数据中提取有用的见解,以优化他们的流程或开发更好的产品。

然而,构建有效的 AI 模型需要在不同领域具备大量专业知识,如数据预处理、模型选择、超参数调优等。所有这些领域都可能耗时且需要专业知识。

这就是 AutoML 发挥作用的地方。AutoML 自动化了构建 AI 模型所需的许多领域。

AutoML 正迅速成为企业和数据科学家们喜爱的解决方案。它使组织能够利用 ML 和 AI 做出明智的决策,而无需成为数据科学专家。随着企业对 ML 需求的增加,AutoML 提供了一种简单高效的方式来创建准确的模型,无论个人的专业知识水平如何。

在本文中,我们将考察市场上一个非常流行的 AutoML 工具——AWS SageMaker AutoML,并展示如何利用它解决复杂的 ML 使用案例。

我将用传统的手动方法训练一个模型,并将其结果与 AWS SageMaker AutoML 产生的结果进行比较。

我将使用 Kaggle 上的信用卡欺诈检测数据集进行对比 [1]。你可以在这里找到数据集。

在本文结束时,你将清楚了解 AutoML 如何帮助利用 ML 驱动有意义的见解并做出明智的决策。

AWS SageMaker AutoML

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 1:AWS SageMaker AutoML 概述(作者提供的图像,基于 [2])。

图 1 展示了 AWS SageMaker AutoML 解决的不同步骤概述。

其包含以下步骤:

  1. 数据准备: 你可以轻松地将数据上传到 Amazon S3。一旦数据上传完毕,SageMaker AutoML 会自动分析数据,以检测任何缺失值、异常值或需要转换的数据类型。

  2. 自动模型创建: AWS SageMaker AutoML 自动训练多个机器学习模型,使用不同的超参数和算法,以确定最适合你数据的模型。它还提供了自动模型调优,调整所选模型的超参数以进一步优化其性能。它还为你创建运行模型选择的笔记本,以便你可以全面了解在此过程中执行了什么。

  3. 模型部署: 一旦选择了最佳模型,AWS SageMaker 提供将模型部署到 SageMaker 端点或批处理转换作业的选项,在那里它可以用于对新数据进行预测。此外,AWS SageMaker Model Monitor 可以用于在出现问题(如数据漂移、概念漂移等)时发出警报。它还提供了重新训练模型的新数据、更新模型的超参数或算法以提升性能的工具。

AWS SageMaker AutoML 提供了一个 Python SDK,可以用于启动 AutoML 作业,并且有一个 GitHub 仓库,其中包含了各种不同的笔记本示例,展示了如何利用 AutoML SDK 处理具体的机器学习用例。

市场上还提供了其他强大且知名的 AutoML 工具,如 Google Cloud AutoML 和 H2O.ai,它们也有各自独特的优势和劣势。

Google Cloud AutoML 以其易用性和直观的界面而闻名,这使其非常适合机器学习新手,并且对编码要求不高。Google Cloud AutoML 支持图像数据、视频数据、文本数据和表格数据。你可以在 这里 了解更多信息。

H2O.ai 以其速度和可扩展性而闻名,使其成为处理大数据集和复杂模型的良好选择。H2O.ai 提供 R、Python 或网页 GUI 的接口。你可以在 这里 了解更多关于其功能的信息。

手动训练方法

在使用 AWS SageMaker AutoML 来生成信用卡数据集的分类器之前,我首先以传统方式训练一个模型:从头开始做所有事情。

这将有助于我建立一个基准,并将我的方法与 AWS 的 AutoML 方法进行比较,期望 AWS SageMaker AutoML 能超越我的手动半优化方法。

对于手动方法,我使用了 Scikit-learn,并将通过下一章中强调的步骤进行操作。

你也可以在我的 GitHub 仓库中找到完整的笔记本 这里

数据准备

我从 CSV 文件中加载数据集,首先检查数据集的分布。这显示数据集高度不平衡,只有**0.17%**的样本是正例。

数据集本身不包含任何缺失值。

然后我将数据集按 80/20 的比例拆分为训练集和测试集,并将数据缩放到 0–1 的范围内,同时标准化器仅在训练集上进行训练,以避免一些过于乐观的结果。

下面可以找到这些步骤的代码。

import sys
import os
import numpy as np
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# step 1: Load the dataset from the csv file. 
# You can download the dataset from Kaggle
filepath = os.path.join("data", "creditcard.csv")
df = pd.read_csv(filepath)

# step 2: check data imbalance on target
count_neg_class = np.sum(df["Class"] == 0)
count_pos_class = np.sum(df["Class"] == 1)

print(f"There are {count_neg_class} negative samples ({np.round(100 * count_neg_class / num_samples, 2)} % of total data).")
print(f"There are {count_pos_class} positive samples ({np.round(100 * count_pos_class / num_samples, 2)} % of total data).")

# step 3: split data into train and test set
X = df.drop(columns="Class").to_numpy()
y = df["Class"].to_numpy()

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

# step 4: scale the data
scaler = StandardScaler()

X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

通常,广泛的探索性数据分析(EDA)也会是数据准备步骤的一部分。但是为了这个实验,我没有进行广泛的 EDA,因为数据集已经为 ML 做好了充分的准备。

但请记住,这部分对于你的 ML 训练的成功也是至关重要的,并且通常需要一些时间。

模型选择

下一步是找出哪种 ML 算法最适合数据。为此,我首先使用逻辑回归训练一个非常简单的基准模型。这是为了有一个简单的模型,我可以将其与更复杂的算法进行比较。

目标应始终是:保持简单!不要从神经网络开始,如果一个简单的算法,比如逻辑回归,能够完成任务,神经网络可能更难以解释。

逻辑回归模型的 F1-Score 为70.6%。我使用 F1-Score 来评估这个数据集,因为它高度不平衡,准确率不会提供有意义的模型评估,因为仅预测所有类别为负例就已经会导致超过**99%**的准确率!

下面可以找到训练基准模型的代码。

from sklearn.metrics import accuracy_score
from sklearn.metrics import f1_score
from sklearn.metrics import precision_score
from sklearn.metrics import recall_score
from sklearn.metrics import confusion_matrix

from sklearn.linear_model import LogisticRegression

log_model = LogisticRegression()
log_model.fit(X_train, y_train)

preds = log_model.predict(X_test)
print(f"Test Acc: {accuracy_score(y_test, preds)}")
print(f"Test F1-Score: {f1_score(y_test, preds)}")
print(f"Test Precision: {precision_score(y_test, preds)}")
print(f"Test Recall: {recall_score(y_test, preds)}")

好的,我们现在有了基准。接下来尝试不同的分类算法及其默认超参数,看看哪个算法在数据上表现最好。

我使用了 5 折交叉验证来训练以下每个模型:

  • 决策树

  • 支持向量机

  • k-近邻

  • 随机森林

  • 自适应提升

from sklearn.tree import DecisionTreeClassifier
from sklearn.svm import SVC
from sklearn.neighbors import KNeighborsClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.ensemble import AdaBoostClassifier
from sklearn.model_selection import cross_val_predict

dict_models = {
    "Decision Tree": DecisionTreeClassifier(),
    "SVM": SVC(),
    "Nearest Neighbor": KNeighborsClassifier(),
    "Random Forest": RandomForestClassifier(),
    "Ada Boost": AdaBoostClassifier()
}

# train all models by using the models dictionary
results_dict = {}
for model_name, model in dict_models.items():
    print(f"Start training {model_name}...")
    preds = cross_val_predict(model, X_train, y_train, cv=5)

    f1 = f1_score(y_train, preds)
    precision = precision_score(y_train, preds)
    recall = recall_score(y_train, preds)

    print(f"F1-Score: {f1}")
    print(f"Precision: {precision}")
    print(f"Recall: {recall}")
    print("\n\n")
    results_dict[model_name] = (f1, precision, recall)

# create a pandas dataframe with the results on sort on f1-score
df_results = (pd.DataFrame.from_dict(results_dict, orient="index", columns=["F1-Score", "Precision", "Recall"])
             .sort_values(by="F1-Score", ascending=False))
df_results

随机森林算法以86.9%的 F1-Score 取得了最佳结果,其次是最近邻算法,F1-Score 为84.8%。不错!

下一步是对获胜者(随机森林)进行微调。

为此,我选择了一些要尝试的超参数值,并使用随机化交叉验证搜索来找到最佳超参数组合,从而获得最佳模型。

这次评估的代码:

from sklearn.model_selection import RandomizedSearchCV

params = {
    "n_estimators": [10, 20, 30, 60, 80, 100],
    "criterion" : ["gini", "entropy"],
    "max_depth" : [4, 5, 10, None],
    "min_samples_split": [2, 4, 6],
    "class_weight": [None, "balanced", "balanced_subsample"]
}

clf_rf = RandomizedSearchCV(RandomForestClassifier(), params, n_iter=50, scoring="f1", cv=5, verbose=1, n_jobs=-1)
clf_rf.fit(X_train, y_train)

# let's print the best score and save the best model
print(f"Best f1-score: {clf_rf.best_score_}")
print(f"Best parameters: {clf_rf.best_params_}")
best_random_forest_model = clf_rf.best_estimator_

最佳模型的得分多多少少与我在没有调整超参数的情况下获得的模型相同。真是浪费时间 😉

模型评估

最后但同样重要的是,我在留出的测试集上评估我的最终模型,以查看它在真实世界数据上的表现。

final_preds = best_random_forest_model.predict(X_test)

f1 = f1_score(y_test, final_preds)
precision = precision_score(y_test, final_preds)
recall = recall_score(y_test, final_preds)

print(f"F1-Score: {f1}")
print(f"Precision: {precision}")
print(f"Recall: {recall}")

这给了我一个**82%**的 F1 分数,接近我们的验证结果,但略低一些。

我知道通过更多地调整模型可以获得更多收益。但本文的目标只是做一些基本的机器学习,并将结果与 AutoML 进行比较,以查看 AutoML 的表现有多好,以及与仅自己做基础训练相比,运行 AutoML 作业的工作量有多大。

使用 AWS SageMaker AutoML 进行训练

好了,现在我有了基准,可以尝试使用 AWS SageMaker AutoML 以更少的努力获得更好的模型。

我将再次指导你完成不同的步骤,并提供所有这些步骤的代码片段。你还可以在我的 GitHub 仓库这里找到完整的笔记本。

数据上传

为了使 SageMaker AutoML 正常工作,数据需要存储在 s3 中。因此,我首先创建一个桶,然后将 CSV 文件上传到这个桶中(Gif 1)。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

Gif 1:创建 s3 桶并将 CSV 文件上传到此桶中(Gif 由作者提供)。

设置 SageMaker 笔记本

下一步是设置我可以在其中运行 AutoML 作业的环境。

为了实现这一点,我首先在 SageMaker 中创建一个笔记本,然后在其中创建和运行代码。

在 AWS 中,其他服务对服务的访问由 IAM 角色处理。SageMaker 笔记本附带了一个带有默认访问权限的 IAM 角色。但为了访问我之前创建的 s3 桶,我首先必须明确调整附加到该角色的策略。

Gif 2 展示了创建笔记本的完整过程以及我如何调整笔记本角色的策略。质量不太好,但我仍然认为看到所采取的操作顺序是有价值的。此外,我还添加了一些创建笔记本时的具体设置截图(图 2),并在我的 GitHub 仓库这里添加了我附加到笔记本角色的完整策略。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

Gif 2:创建笔记本的过程以及如何调整附加 IAM 角色的策略(Gif 由作者提供)。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 2:AWS SageMaker 笔记本创建设置(图片由作者提供)。

运行 AWS SageMaker AutoML 作业

现在我终于可以开始运行一些代码了。

我运行的代码大多是从这个AWS 教程笔记本中复制并调整过来的。

首先,我将数据从 s3 加载到 Pandas 数据框中,并设置一些 SageMaker AutoML 之后所需的一般变量。

import numpy as np 
import pandas as pd
import boto3
import sagemaker
import os, sys

# get some variables required for AutoML later
sess   = sagemaker.Session()
bucket = sess.default_bucket()                     
region = boto3.Session().region_name
prefix = 'sagemaker/fraud-detection-auto-ml'
# Role when working on a notebook instance
role = sagemaker.get_execution_role()

# get some sagemaker clients
sm = boto3.Session().client(service_name='sagemaker',region_name=region)
sm_rt = boto3.Session().client('runtime.sagemaker', region_name=region)

# load data from s3
bucket_data = 'patrick-fraud-detection-ml-kaggle'
filename = 'creditcard.csv'
s3 = boto3.client('s3') 
obj = s3.get_object(Bucket=bucket_data, Key=filename) 
df = pd.read_csv(obj['Body']) # 'Body' is a key word

然后,我将数据集分成训练集和保留测试集。我将使用后者来比较 AutoML 模型和我自己训练的模型。然后,我将数据上传到 SageMaker 创建的 s3 桶中,以便 AutoML 任务可以直接从 s3 访问它。

from sklearn.model_selection import train_test_split

train_data, test_data = train_test_split(df, test_size=0.2)

# Save to CSV files and upload to S3
train_file = "automl-train.csv"
train_data.to_csv(train_file, index=False, header=True, sep=',') # Need to keep column names
train_data_s3_path = sess.upload_data(path=train_file, key_prefix=prefix + "/train")
print("Train data uploaded to: " + train_data_s3_path)

# save test file only to a CSV file 
# -> will be send as POST request to inference endpoint later
test_file = "automl-test.csv"
test_data.to_csv(test_file, index=False, header=False, sep=',')

现在我可以设置 AutoML 任务并启动它。你可以在 SageMaker SDK 文档这里找到有关所需输入参数和设置的更多信息。

from time import gmtime, strftime, sleep
# setup config for input data
input_data_config = [{
      'DataSource': {
        'S3DataSource': {
          'S3DataType': 'S3Prefix',
          'S3Uri': 's3://{}/{}/input'.format(bucket,prefix)
        }
      },
      'TargetAttributeName': 'Class'  # the column we want to predict
    }
]

# setup config for output data
output_data_config = { 'S3OutputPath': 's3://{}/{}/output'.format(bucket,prefix) }

# Optional parameters
problem_type = 'BinaryClassification'
job_objective = { 'MetricName': 'F1' } # using F1 because of highly imbalanced dataset

# launch the AutoML job 
# but: limit to max. 20 candidates to limit overall execution time
timestamp_suffix = strftime('%d-%H-%M-%S', gmtime())

auto_ml_job_name = 'fraud-detection-' + timestamp_suffix

sm.create_auto_ml_job(AutoMLJobName=auto_ml_job_name,
                      InputDataConfig=input_data_config,
                      OutputDataConfig=output_data_config,
                      AutoMLJobConfig={"CompletionCriteria": {"MaxCandidates": 20}},
                      AutoMLJobObjective=job_objective,
                      ProblemType=problem_type,
                      RoleArn=role)

任务现在在后台运行,为你创建所需的 AWS 资源。

然后,你可以运行以下代码来跟踪 AutoML 任务的进度:

job_run_status = sm.describe_auto_ml_job(AutoMLJobName=auto_ml_job_name)['AutoMLJobStatus']

print(job_run_status)

while job_run_status not in ('Failed', 'Completed', 'Stopped'):
    describe_response = sm.describe_auto_ml_job(AutoMLJobName=auto_ml_job_name)
    job_run_status = describe_response['AutoMLJobStatus']

    print (describe_response['AutoMLJobStatus'] + " - " + describe_response['AutoMLJobSecondaryStatus'])
    sleep(60)

任务正在通过以下阶段:

  1. 数据分析

  2. 特征工程

  3. 模型调优

  4. 合并 AutoML 任务报告

AWS SageMaker 为你生成了两个笔记本。一个用于探索数据,另一个用于定义在数据集上进行评估的不同候选模型。如果你对 AWS SageMaker 在这些阶段运行的代码感兴趣,可以查看这些笔记本。

可以列出 AWS SageMaker 执行的所有实验,并且你也可以列出所有探讨过的候选模型。我在这篇文章中没有添加代码,但如果你对其工作原理感兴趣,可以在我的笔记本中找到代码。

在测试集上评估最佳候选模型

现在是时候在保留测试集上测试最佳候选模型了。对我来说,这是最有趣的部分,因为它显示了 AutoML 是否能比我的手动方法得分更高。

首先,我从 AWS SageMaker AutoML 任务中检索最佳候选模型。

best_candidate = sm.describe_auto_ml_job(AutoMLJobName=auto_ml_job_name)['BestCandidate']
best_candidate_name = best_candidate['CandidateName']

接下来,我将在 AWS 中将此模型作为端点进行托管,以便我可以将数据发送给它进行推断。

timestamp_suffix = strftime("%d-%H-%M-%S", gmtime())
model_name = best_candidate_name + timestamp_suffix + "-model"

# create a model in SageMaker that can be hosted as endpoint
model_arn = sm.create_model(
    Containers=best_candidate["InferenceContainers"], ModelName=model_name, ExecutionRoleArn=role
)

# setup config for endpoint (including instance type)
epc_name = best_candidate_name + timestamp_suffix + "-epc"
ep_config = sm.create_endpoint_config(
    EndpointConfigName=epc_name,
    ProductionVariants=[
        {
            "InstanceType": "ml.m5.2xlarge",
            "InitialInstanceCount": 1,
            "ModelName": model_name,
            "VariantName": "main",
        }
    ],
)

# deploy endpoint
ep_name = best_candidate_name + timestamp_suffix + "-ep"
create_endpoint_response = sm.create_endpoint(EndpointName=ep_name, EndpointConfigName=epc_name)

# wait until endpoint is ready for inference
sm.get_waiter("endpoint_in_service").wait(EndpointName=ep_name)

最后但同样重要的是,读取包含测试数据的 CSV 文件,并将数据发送到最终模型进行推断。然后,我将预测结果与真实值进行比较,并使用相当手动的方法来计算真正例、假负例、假正例和真正例。

老实说,我只是懒得自己实现一些东西,而是再次从 AWS SageMaker 教程笔记本中适配了代码,你可以在这里找到。

tp = tn = fp = fn = count = 0

with open('automl-test.csv') as f:
    lines = f.readlines()
    for l in lines[1:]:   # Skip header
        l = l.split(',')  # Split CSV line into features
        label = l[-1]     # Store 0/1 label
        l = l[:-1]        # Remove label
        l = ','.join(l)   # Rebuild CSV line without label

        response = sm_rt.invoke_endpoint(EndpointName=ep_name, ContentType='text/csv', Accept='text/csv', Body=l)

        response = response['Body'].read().decode("utf-8")
        #print ("label %s response %s" %(label,response))

        if '1' in label:
            # Sample is positive
            if '1' in response:
                # True positive
                tp=tp+1
            else:
                # False negative
                fn=fn+1
        else:
            # Sample is negative
            if '0' in response:
                # True negative
                tn=tn+1
            else:
                # False positive
                fp=fp+1
        count = count+1
        if (count % 100 == 0):   
            sys.stdout.write(str(count)+' ')

# get final scores
 # Confusion matrix

accuracy  = (tp+tn)/(tp+tn+fp+fn)
precision = tp/(tp+fp)
recall    = tp/(tp+fn)
f1        = (2*precision*recall)/(precision+recall)

print ("Accuracy: %.4f, Precision: %.4f, Recall: %.4f, F1: %.4f" % (accuracy, precision, recall, f1))

最终模型在保留测试集上达到了**96%**的 F1-Score!

这太棒了!相比之下,我用 Scikit-Learn 训练的模型仅达到了**82%**的 F1-Score。

AutoML 的不足之处

AutoML 确实强大,可以帮助加快 ML 开发周期。但也存在一些使用 AutoML 的不足之处,需要认识到。

AutoML 的主要限制之一是它可能是一种黑箱方法,因为它自动化了构建机器学习模型的大部分过程。这可能使数据科学家很难完全理解模型的工作原理,并可能限制他们微调模型或调试出现的问题的能力。

AWS 通过提供 Jupyter 笔记本来应对这一问题,展示了数据探索或探索不同 ML 候选模型等阶段的代码。这已经有助于获取一些见解,但如果代码或 AutoML 任务的结果出现任何问题,数据科学家无法对背后的代码进行更改,因为一切都是自动生成的。

使用 AutoML 的另一个潜在缺点是,它可能比传统的机器学习模型构建方法灵活性差。AutoML 优化了效率和易用性,但这可能以定制选项或处理专用数据集或模型的能力为代价。

在这篇文章中,我使用了一个非常简单的数据集,AWS SageMaker AutoML 能够训练出一个不错的候选模型。但如果遇到更具挑战性的数据集,AutoML 在这些数据集上的表现如何还不清楚。

个人发现

在这一章中,我想强调我在使用 AWS SageMaker AutoML 时的个人发现。

我首先意识到使用和设置 AWS SageMaker AutoML 相当复杂。我个人在使用 AWS 和编码方面有一定经验,但对于没有太多先前知识的人来说,使用 AWS SageMaker AutoML 的学习曲线可能过于陡峭。

我还认为文档质量不够好。我一开始在文档中很难找到关于端到端用例的内容。后来我找到了视频培训和 GitHub 上的示例笔记本,但我个人更喜欢书面的文档而非视频。

另外,请注意成本。我最初在 SageMaker AutoML 中进行了一些尝试,频繁运行 AutoML 任务,因为我想在其中尝试不同的东西。但这并没有像我预期的那样便宜,结果我在这个实验上的花费比计划的要多。

结论

在这篇文章中,我将 AWS SageMaker AutoML 与使用 Scikit-Learn 手动训练模型进行了比较。

作为数据集,我决定使用欺诈检测数据集,因为它由于高度不平衡而存在一些困难。

然后我手动尝试找出一个好的分类器,并在 AWS SageMaker AutoML 中做了同样的尝试。

我最终将两种方法的结果在一个保留测试集上进行了比较,其中 AWS SageMaker AutoML 的得分高于我的手动方法,达到了96%的 F1 分数,而我的手动方法为82%

这表明,使用 AWS SageMaker AutoML 在你的数据集上训练 ML 模型以快速生成一个可用的分类器是完全合理的。

甚至不需要在机器学习领域有专门的知识就可以使用 AWS SageMaker AutoML。

当然,我的手动方法也需要谨慎对待。

我没有花很多时间来优化我的最终分类器,我相当确定,如果再多投入一些时间,我会在保留测试集上获得更好的结果。

至少我希望如此 😉

但这篇文章的目的就是展示如何轻松地利用 AutoML 库为你的机器学习数据集创建一个分类器。

最终,这不一定是你投入到产品中的最终版本,但你可以至少做一个快速的初步概念验证,以查看你是否能从数据中获得有用的信息。

展望

迄今为止,我只深入了解了 AWS SageMaker AutoML,但肯定也很有兴趣查看其他服务,例如 Google Cloud AutoML。

我计划在不久的将来对 Google Cloud 的 Vertex AI 进行全面评估,然后会在 Medium 上写出我的发现。

然后我还可以具体讨论 Sagemaker 与 Google Cloud 的表现相比如何。

所以如果你不想错过这些内容,请关注我!

感谢你读到我的文章的最后!我希望你喜欢这篇文章。如果你想将来阅读更多类似的文章,关注我以保持更新。

如果你想了解更多关于机器学习和云计算的信息,请加入我的邮件列表。

联系方式

LinkedIn | GitHub

参考文献

[1]: 机器学习组 — ULB,“信用卡欺诈检测”,Kaggle,2018 年,数据库内容许可证 (DbCL) v1.0

[2]: AWS,Amazon SageMaker Autopilot(访问于 2023 年 4 月 1 日)

解密数据回填

原文:towardsdatascience.com/demystify-data-backfilling-cf1713d7f7a3

让我们谈谈数据工程师的噩梦

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Xiaoxu Gao

·发表于 Towards Data Science ·10 分钟阅读·2023 年 11 月 20 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者创建

作为数据工程师,我们每天都会遇到独特的挑战。但如果有一项令人畏惧的任务,那一定是回填。回填不当意味着处理时间过长、数据污染以及高额的云账单。对了,这也意味着你需要另一个回填任务来修复它。

完成第一次成功的数据回填是数据工程师的一项重要经历。— Dagster

回填任务需要一系列数据工程技能才能有效完成,例如验证结果的领域知识、运行回填任务的工具专长以及优化流程的数据库扎实理解。当这些元素交织在一个任务中时,可能会出现问题。

在本文中,我们将深入探讨数据回填的概念、其必要性以及高效实施的方法。无论你是回填新手还是经常对这类任务感到恐慌的人,这篇文章都会让你平静下来,帮助你重拾信心。

什么是回填?

回填是将过去缺失的数据填充到之前不存在的新表中,或用新记录替换旧数据的过程。它通常不是一个定期任务,仅在数据管道增量更新表时才需要。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

常规任务与回填任务的区别(作者创建)

例如,一个表按date列分区。一个常规的每日任务只更新最新的 2 个分区。相比之下,一个回填任务可以更新表中最初的所有分区。如果常规任务每次都更新整个表,那么回填任务就不必要了,因为历史数据会通过常规任务自然更新。

那么,我们什么时候需要回填呢?

通常,有一些常见的场景。让我们看看你是否觉得这些场景熟悉。

  • 创建新表并希望填补缺失的历史数据

你刚刚开发了一个新的表来分析每月的电子商务销售业绩。数据管道只选择发生在给定月份的交易。在部署数据管道时,它仅为当前月份生成报告。要生成历史月度报告,需要一个回填作业。回填作业中更新的分区数量取决于业务需求和源表中的数据可用性。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

使用回填作业开发新表(由作者创建)

  • 修复数据管道中的错误并希望更新整个历史数据

哎呀,你发现了联接逻辑中的错误。它应该是左联接而不是内联接。你迅速修复了这个问题,以确保数据质量,但过去的数据呢?它仍在使用左联接。这里的回填作业是纠正历史数据。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

使用回填作业修复错误(由作者创建)

  • 数据管道停机并希望追赶

数据管道可能会经历几天的停机,从而导致数据缺口。一旦管道恢复,它需要赶上其计划的运行。幸运的是,大多数现代数据编排工具提供自动追赶功能,因此相比其他情况,需要的人工干预较少。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

使用回填作业追赶缺失的数据(由作者创建)

你有其他场景吗?欢迎在评论中与我们分享。

从图表中可以看出,回填是一个耗时的工作,因为它涉及许多分区。为了防止不必要的复杂性,最佳做法是首先询问团队和利益相关者关于回填数据的预期用途以及是否确实需要回填。回填的时间范围是什么?利益相关者是否能从回填中获得长期收益?随着公司的增长,表格也会增长。最终,我们可能会达到回填整个表格变得不可行的地步,需要决定在哪里截断。

另一个需要考虑的重要问题是我们是否有权限修改历史数据。在某些情况下,修改像财务数据这样的历史数据对公司可能意义重大,特别是当这些数据经过审计时。更新历史记录前理解业务影响至关重要,因为没有人愿意涉及法律问题。

表分区

现在,让我们看看回填的技术方面。如果不提及表分区,就无法讨论回填。分区是增量表更新的一种方法。分区列将表划分为一组分区。回填作业会一个接一个地更新这些分区。分区也是一种并行单元,允许多个回填作业同时执行。

每个分区的大小与分区数量之间存在权衡。更多的分区会导致数据更为细粒度地划分,回填作业会更加有针对性和具体。然而,通常建议不要创建过小的分区(例如,小于 1G),因为这样可能无法有效利用资源。想象一下比较打开一个 1GB 的文件和打开 10 个分别为 100MB 的文件——前者通常更高效。另一方面,过大的分区可能导致回填作业时间过长,因为这可能涉及到超出范围的数据。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

大分区和小分区之间的权衡(由作者创建)

每个数据仓库根据元数据的访问方式和资源的最佳利用情况有不同的推荐分区大小。此外,还需要考虑哪种粒度对表最有意义。例如,一种常见的分区策略是在date列上进行分区,其中我们期望数据在各天之间均匀分布。然而,如果每一天的数据量很小,那么另一种方法可以是按月份或年份进行分区。

回填策略

如前所述,回填作业需要数据工程技能的综合运用。成功的回填作业的大部分关键因素实际上是在作业本身开始之前就已经确定了。

将数据回填纳入初始表设计讨论

早期涉及回填讨论是一种良好的实践,因为这可能会影响表的设计,例如前面提到的分区策略。特别是当常规作业并不每次都更新整个表时,必须有一个回填计划,以便在需要时更新历史数据。

对于某些表,例如小型维度表,每次进行全表刷新可能是一个更为理想的选择,以避免需要回填。另一方面,对于具有一致线性增长的事实表,增量表更新更为可取,因为我们不希望云账单或服务器成本随着数据增长而线性增长。

使数据管道具有幂等性

幂等性指的是多次运行相同操作而不改变结果。这是每个数据工程师应该了解的基本数据管道设计原则。在代码更改之前,使用相同输入重新运行相同的 Airflow 任务应始终产生相同的输出。你不希望看到任何重复或不同的输出。因此,对于增量表更新,使用 replace 而非 append 模式以避免重复。此外,在转换逻辑中使用 Airflow 变量如 data_interval_end,而不是像 current_date() 这样的时间敏感函数,因为 current_date() 的输出会根据作业的执行时间而不同。

幂等性是成功回填的关键前提,它确保数据只根据预期的变化进行更改,而不受其他因素影响。在这个例子中,数据框中的 date 列始终表示预期的计划时间,而不是与任务的实际执行时间相关联。

def transform_data(data_interval_end):
    # do not use date.today()
    return pd.DataFrame(data={'date': [data_interval_end], 'quantity': [10]})

transform_task = PythonOperator(
    task_id='transform_data',
    python_callable=transform_data,
    op_args=[{{ data_interval_end}}],  
    provide_context=False,
    dag=dag,
)

对回填范围做出明智决策并执行

如果范围选择不当,回填作业可能会非常繁重。一个分区所花费的时间和金钱在回填作业中会被放大。你可以利用历史运行来提前估算成本和时间。如果估算超出预算,那么首先了解用例。这是为了进行一次性的分析吗?那么考虑在现有表上创建一个临时视图。范围对于用例来说是否太大?那么将分区大小减少到更可管理的水平。如果仍然太多,那么可能需要考虑使用更高效或更具成本效益的技术。

另一个非常重要的点是评估下游影响。在回填源表时,可能需要将回填扩展到下游表。我知道揭示所有隐藏连接可能很具挑战性。但如果这是你团队面临的重大挑战,考虑利用数据血缘工具系统性地识别所有下游依赖关系。

一旦范围定义好,就该采取行动了。幸运的是,许多数据工具原生支持回填。在 Airflow 中,你可以通过 Airflow UI 重新运行任务,也可以使用命令 airflow dags backfill。在 dbt 中,你可以使用命令 dbt run --full-refresh 或传递自定义变量,例如 dbt run -s my_model --vars '{"start":"2023-11-01"}'。像 DagsterMage 等其他工具也有自己运行回填作业的方法。

对于模式更改要小心。对于兼容的更改,如添加新列,许多数据工具会在回填作业中为第一个分区之前的记录填充空值。对于不兼容的更改,如删除列或更改数据类型,你需要重新创建整个表。

使用 DDL 或 DML 回填表格

好消息是,有替代方法可以回填表格,从而无需执行许多耗时的 Airflow 运行。实际上,我们往往只希望回填特定的列,而不是所有列。因此,针对无关列进行计算是资源的低效使用。

一个捷径是使用 DDL 或 DML 更新表格。例如,在quantity的变换从quantity = amount * price改为quantity = amount * price * exchange_rate的情况下。我们可以简单地使用UPDATE语句回填表格:

UPDATE my_table
SET quantity = amount * price * exchange_rate
WHERE date >= '2023-11-01'

在大多数情况下,这比在 Airflow 中运行回填作业更高效。对于不兼容的模式更改,如果重新创建整个表非常昂贵,可以考虑使用DDL删除列或更改数据类型。

并行回填作业

另一个优化技巧是将回填作业并行化。如果 10 个 Airflow 回填作业更新 10 个分区,它们可以在这些配置到位的情况下并行运行:

depend_on_past = False
max_active_runs = X # The maximum number of active DAG runs allowed for the DAG.
max_active_tasks = X # The total number of tasks that can run at the same time for a given DAG run.
concurrency = X # The maximum number of task instances allowed to run concurrently across all active DAG runs for a given DAG.
max_active_tis_per_dag = X # The maximum number of times that the same task can run concurrently across all DAG runs.

这种方法允许同时更新多个分区,消除了顺序等待的需要。然而,我们需要确保数据仓库支持并发写入并检查其并发级别。此外,分区之间不应有任何依赖关系,例如今天的分区不能基于昨天的分区进行计算。

常规运行中的回填

有时我们也希望在正常运行中自动“回填”表格。这是什么意思?例如,当常规批处理包含一些迟到的记录,需要对历史数据进行回溯更新时,就会发生这种情况。由于这种情况非常频繁,因此应将其纳入常规运行中,而不是手动触发它。

一个例子是统计电子商务中的累计总购买订单。现在,假设一种情况,客户在 11 月 1 日下了订单,但由于系统延迟,订单信息直到 11 月 3 日才被处理。当 11 月 3 日收到订单信息时,应该更新 11 月 1 日和 11 月 2 日的数据。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

延迟订单的交易数据(由作者创建)

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

回填前后的汇总表(由作者创建)

在这种情况下,“内部回填”是由输入数据的更新触发的,而不是由转换逻辑触发的。根据记录的延迟情况,作业可能会更新多个分区。延迟越大,需要调整的分区就越多。因此,监控性能至关重要,可能需要实施另一种流程以防止常规作业过载。

# pseudo code
earliest_order_date = find_earliest_date_in_batch(new_batch)
partitions_to_be_updated = f"select * from summary where date >={earliest_order_date}"
# can be heavy
updated_partitions = update_historical_data(partitions_to_be_updated, new_batch)
update_table(updated_partitions)

回填后

哇,到目前为止有很多阅读内容。我很高兴你能看到这里。触发回填作业并不是过程的结束。我们必须积极监控性能,因为问题可能在过程中任何时候出现。逐个分区回填表的一个关键好处是,如果过程中出现问题,你可以灵活地从失败的分区恢复,而不是从头开始。

沟通是任何数据变更的关键。确保利益相关者参与过程。考虑在作业完成后创建脚本,自动发送通知并请求对回填表的所有用户进行验证。

结论

就是这样了!希望你喜欢,并以某种方式获得启发。回填作业是具有挑战性的,但它不应该是一个黑箱或让你感到害怕的东西。下次,不必在按下按钮前深呼吸并祈祷:))

对于已经熟悉回填的人。我希望你仍然从这篇文章中获得了一些见解。如果你有额外的技巧或窍门,请随时分享——我们很想听到你的声音!干杯!

参考

[## 数据与机器学习中的回填:入门 | Dagster 博客]

从糟糕的回填中恢复对任何数据工程师来说都是一个痛苦的经历。

dagster.io

揭示贝叶斯模型的奥秘:通过 SHAP 值揭示可解释性

原文:towardsdatascience.com/demystifying-bayesian-models-unveiling-explanability-through-shap-values-8405f618f4e0?source=collection_archive---------14-----------------------#2023-05-12

通过一个引人入胜的玩具示例探索 PyMC 的见解与 SHAP 框架

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Shuyang Xiang

·

关注 发表在 Towards Data Science · 6 分钟阅读 · 2023 年 5 月 12 日

贝叶斯模型与可解释性之间的差距

SHAP 值(SHapley Additive exPlanations)是一种基于博弈论的方法,用于提高机器学习模型的透明度和可解释性。然而,这种方法以及其他机器学习可解释性框架,鲜有应用于贝叶斯模型,而贝叶斯模型提供了捕捉参数估计不确定性的后验分布,而不是经典机器学习模型使用的点估计。

虽然贝叶斯模型提供了一个灵活的框架来整合先验知识、调整数据限制和进行预测,但遗憾的是,使用 SHAP 对其进行解释是困难的。SHAP 将模型视为一个游戏,将每个特征视为该游戏中的一个玩家,但贝叶斯模型不是一个游戏。它更像是一个包含来自后验分布的参数的游戏集合。当模型不仅仅是一个游戏时,我们该如何解释它?

本文尝试通过玩具示例使用 SHAP 框架解释贝叶斯模型。该模型建立在 PyMC 上,PyMC 是一个用于 Python 的概率编程库,允许用户通过简单的 Python API 构建贝叶斯模型,并使用马尔可夫链蒙特卡罗方法对其进行拟合。

主要思想是将 SHAP 应用于从贝叶斯网络生成的确定性模型的集合。对于每个特征,我们将从生成的确定性模型中获得一个 SHAP 值样本。可解释性将由所有获得的 SHAP 值样本提供。我们将通过一个简单的示例来说明这种方法。

所有实现都可以在这个笔记本中找到。

使用 PyMC 进行贝叶斯建模

数据集

考虑以下由作者创建的数据集,其中包含 250 个点:变量 y 依赖于 x1 和 x2,两个变量都在 0 到 5 之间变化。下图说明了数据集:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片作者提供:数据集

让我们使用配对图快速探索数据。从中我们可以观察到以下几点:

  1. 变量 x1 和 x2 不相关。

  2. 两个变量在某种程度上都对输出 y 有贡献。也就是说,单一变量不足以获得 y。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片作者提供:数据的配对图

使用 PyMC 进行建模

让我们使用 PyMC 构建一个贝叶斯模型。在不深入讨论任何统计学书籍中可以找到的细节的情况下,我们只需回顾一下贝叶斯机器学习模型的训练过程涉及根据观察到的数据和先验知识使用贝叶斯规则来更新模型的参数。

我们将模型的结构定义如下:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片作者提供:模型结构

在定义了先验和似然之后,我们将使用 PyMC 标准采样算法 NUTS,该算法旨在自动调整其参数,例如步长和 leapfrog 步数,以实现对目标分布的有效探索。它通过树探索重复模拟点在参数空间中的轨迹,并确定是否接受或拒绝样本。此类迭代在达到最大迭代次数或达到收敛水平时停止。

你可以在下面的代码中看到,我们设置了先验,定义了似然,然后使用 PyMC 运行了采样算法。

让我们使用 PyMC 构建一个贝叶斯模型。贝叶斯机器学习模型训练涉及基于观察数据和先验知识更新模型参数,使用贝叶斯规则。我们不会在这里详细介绍,因为你可以在任何统计学书籍中找到。

我们可以定义如下的模型结构:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图像:模型结构

对于上述定义的先验和似然,我们将使用 PyMC 标准采样算法 NUTS。该算法旨在自动调整其参数,如步长和跳跃步数,以实现对目标分布的高效探索。它重复进行树形探索,以模拟点在参数空间中的轨迹,并决定是否接受或拒绝样本。迭代在达到最大迭代次数或实现收敛水平时停止。

在下面的代码中,我们设置先验,定义似然,然后使用 PyMC 运行采样算法。

with pm.Model() as model:

    # Set priors.
    intercept=pm.Uniform(name="intercept",lower=-10, upper=10)
    x1_slope=pm.Uniform(name="x1_slope",lower=-5, upper=5)
    x2_slope=pm.Uniform(name="x2_slope",lower=-5, upper=5)
    interaction_slope=pm.Uniform(name="interaction_slope",lower=-5, upper=5)
    sigma=pm.Uniform(name="sigma", lower=1, upper=5)

    # Set likelhood.
    likelihood = pm.Normal(name="y", mu=intercept + x1_slope*x1+x2_slope*x2+interaction_slope*x1*x2, \
                           sigma=sigma, observed=y)
    # Configure sampler.
    trace = pm.sample(5000, chains=5, tune=1000, target_accept=0.87, random_seed=SEED)

下面的踪迹图展示了模型中参数的后验分布。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图像:模型的后验

使用 SHAP 解释模型

我们现在希望在上述模型上实现 SHAP。注意,对于给定的输入(x1, x2),模型的输出 y 是条件概率。因此,通过从获得的后验中绘制一个样本,我们可以获得一个确定性的模型及其所有特征的 SHAP 值。或者,如果我们绘制一个参数样本的集合,我们将得到一个确定性模型的集合,因此,所有特征的 SHAP 值样本。

可以使用以下代码获得后验分布,我们每条链绘制 200 个样本:

with model: 
    idata = pm.sample_prior_predictive(samples=200, random_seed=SEED)
    idata.extend(pm.sample(200, tune=2000, random_seed=SEED)here

以下是后验数据变量的表格:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图像:后验样本

接下来,我们为每个绘制的模型参数样本计算一对 SHAP 值。下面的代码对参数进行循环,为每个参数样本定义一个模型,并计算感兴趣的 x_test=(2,3)的 SHAP 值。

background=np.hstack((x1.reshape((250,1)),x2.reshape((250,1))))
shap_values_list=[]
x_test=np.array([2,3]).reshape((-1,2))
for i in range(len(pos_intercept)): 
  model=SimpleModel(intercept=pos_intercept[i],
                    x1_slope=pos_x1_slope[i], 
                    x2_slope=pos_x2_slope[i], 
                    interaction_slope=pos_interaction_slope[i],
                    sigma=pos_sigma[i])
  explainer = shap.Explainer(model.predict, background)
  shap_values = explainer(x_test)
  shap_values_list.append(shap_values.values)

输入的二维 SHAP 值的结果集如下所示:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图像:SHAP 值样本

从上面的图表中,我们可以推断出以下内容:

  1. 两个维度的 SHAP 值大致形成一个正态分布。

  2. 第一个维度对模型有正贡献(中位数为-1.75),而第二个维度有负贡献(中位数为 3.45)。不过,第二个维度的贡献绝对值更大。

结论

本文探讨了 SHAP 值的使用,这是一种基于博弈论的方法,用于提高机器学习模型的透明度和可解释性,应用于贝叶斯模型。通过一个玩具示例演示了 SHAP 如何应用于贝叶斯网络。

请注意,SHAP 是模型无关的。因此,随着其实现方式的变化,未来可能可以直接将 SHAP 应用于贝叶斯模型本身。

揭示依赖关系及其在因果推断和因果验证中的重要性

原文:towardsdatascience.com/demystifying-dependence-and-why-it-is-important-in-causal-inference-and-causal-validation-4263b18d5f04

一步步了解依赖关系的概念及如何使用 Python 应用于验证有向无环图

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Graham Harrison

·发表于 Towards Data Science ·阅读时间 16 分钟·2023 年 11 月 11 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

照片由 Ana Municio 提供,来源于 Unsplash

介绍

因果推断是数据科学的一个新兴分支,关注事件和结果之间的因果关系,它具有显著提升机器学习为组织创造的价值的潜力。

例如,传统的机器学习算法可以预测哪些贷款客户可能违约,从而实现对客户的主动干预。然而,尽管这个算法有助于减少贷款违约,但它并不了解违约发生的原因,而了解违约原因能够解决根本问题。在这种情况下,主动干预可能不再必要,因为导致违约的因素已经被彻底解决。

这就是因果推断的承诺,它具有为能够利用这一潜力的组织带来显著影响和成果的潜力。

有多种不同的方法,但最常见的方法通常是通过增加“有向无环图”来开始,这种图能够封装和可视化数据中的因果关系,然后使用因果推断技术提出“如果如何”的问题。

问题

封装数据中因果关系的有向无环图(DAG)通常由数据科学家和领域专家一起手动(或半手动)构建。因此,DAG 可能是错误的,这将使任何因果计算无效,导致错误的结论和可能的错误决策。

机会

存在一系列用于“因果验证”的技术(验证 DAG 是否与数据一致的过程),如果这些技术有效,它们可以最小化或消除 DAG 中的错误,从而确保计算和结论是无误的。

前进的道路

随机变量之间的统计学依赖概念可以用来确定 DAG 中存在的关系是否也存在于数据中;如果存在,则 DAG 更有可能是正确的,如果不存在,则更可能是错误的。

入门

我们需要一个示例 DAG 来解决问题,这个 DAG 具有足够的节点和链接,以便深入探索因果验证……

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

本文中将使用的示例 DAG — 作者图片

DAG 中的每个节点要么对其他节点产生因果影响,要么其他节点对其产生因果影响,箭头的方向表示因果影响的方向。例如,“B”的一个原因是“C”,而“C”的一个原因是“F”。

示例 DAG 是虚构的,因此节点的字母/名称并不重要,不过“X”意图代表“处理”,而“Y”代表“效果”,所有其他节点对 X 对 Y 的真实效果在实际例子中会产生一些因果影响,从而掩盖 X 对 Y 的真实效果。

请注意,浅蓝色节点没有输入(在因果术语中称为外生),而深蓝色节点有一个或多个输入(在术语中称为内生)。

为了开始,我们还需要一些与 DAG 匹配的数据。下面的数据集完全是合成的,由作者生成。它完全封装并匹配 DAG 所建议的结构,并且没有错误或故障关系……

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

与 DAG 相关的合成虚构数据集 — 作者图片

我们开始之前还需要一种扩展 pandas DataFrameSeries 类以添加自定义方法的方式,以便我们编写的代码既简洁又易于理解。

这里有一个我之前文章的链接,提供了一个关于如何扩展数据框以及为什么这样做很有用的端到端教程……

## 如何通过自定义方法扩展 Pandas DataFrames 以增强代码功能性和可读性

一个逐步指南,介绍如何通过自定义方法扩展 pandas DataFrames,包括实施的完整示例 …

[towardsdatascience.com

理解依赖性

依赖性的一个定义如下 …

两个随机变量之间的依赖性意味着一个变量的发生或值会影响另一个变量的发生或值。如果一个变量的发生或值提供了关于另一个变量的发生或值的信息,则这两个变量被认为是相关的。

为了深入了解这一点,让我们再看一下我们的示例 DAG,并考虑影响节点 Y 的因果因素 …

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

突出显示影响 Y 的因果因素的 DAG — 作者提供的图片

在这个可视化中,我们可以看到节点 Y 是由 5 个不同因素造成的(因此也依赖于这些因素) — C、E、F、G 和 X。

现在让我们再看一下 DAG 所表示的数据 …

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

df_causal DataFrame 中前 5 行数据的回顾 — 作者提供的图片

这个合成数据集是作者为方便本文创建的,所以我知道节点 Y 与这些依赖因素之间的关系如下 …

Y = 3C + 3E + 2F + 2.5G + 1.5X + ε

(注:ε 代表误差项)

… 这一点可以通过选择一行(在这种情况下,我选择了第 3 行)并将该公式应用于数据来测试和验证 …

Y = -422.1827393983049, error term = 48.75941612372628

现在我们可以看到 Y 如何以及为何依赖于 C、E、F、G 和 X。如果这些依赖变量中的一个值发生变化,Y 的值也会变化。我们还可以从 DAG 中看到 Y 不应依赖(例如)节点 D,因为 D 和 Y 之间没有连接。

“Y 依赖于 C、E、F、G 和 X” 的表述可以用数学公式表示如下 …

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图片

… 以及“Y 与 D 独立” 的表述如下 …

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图片

⫫ 符号被称为“双向交叉符号”,但 ⫫̸ 符号没有一个普遍接受的名称,所以我个人习惯称之为“斜杠双向交叉符号”。

一些文章和文本使用单向交叉符号(⊥ 和 ⊥̸)代替双向交叉符号,但双向交叉符号更为常见,因此这是我在本文及相关 Python 代码中采用的标准。

回顾一下,两个随机变量之间的统计依赖意味着“一个变量的发生或数值影响另一个变量的发生或数值”,我们现在知道这在 DAG 中是如何可视化的,如何用数学公式表示(例如Y = 3C + 3E + 2F + 2.5G + 1.5X + ε),以及如何用斜杠双向箭头符号表示(例如 Y ⫫̸ C、E、F、G、X)。

从依赖关系到因果验证

因果推断通常从一组数据开始,然后用 DAG 扩充这些数据。虽然有些新兴技术可以从数据中反向工程生成 DAG,但它们并不准确或一致,因此开发 DAG 的最常见方法是询问领域专家他们认为的因果关系,然后验证或测试该 DAG 是否符合数据,并在验证失败时进行必要的修正。

DAG 已经提出 Y 依赖于 C、E、F、G 和 X,如果这种依赖在数据中存在,那么可以确信指向节点 Y 的因果链接是有效和正确的,并且可以用如下数学符号表示……

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由作者提供

这个看起来吓人的公式其实非常容易理解。第一个斜杠双向箭头符号的“G”下标表示“在图中”(即 DAG),而第二个“D”下标表示“在数据中”(注意,我见过一些文献中使用“P”下标,但“D”对我来说更有意义,因此我采用了“D”)。

具备这些知识后,整个公式可以被解读为“如果图中的 Y 依赖于 C、E、F、G 和 X,那么 Y 在数据中也应该依赖于 C、E、F、G 和 X”。

因此,我们只需要一个 Python 机制来检测数据中的依赖关系。然后可以使用该机制检查 DAG 中具有传入连接的每个节点,如果在数据中检测到的依赖关系与 DAG 中的匹配,我们可以合理地确信没有虚假的连接(因果链接),并且 DAG 在这方面是数据的有效表示。

观察数据中的依赖关系

让我们开始可视化数据中 C、E、F、G 和 X 与我们关注的节点 Y 之间的关系……

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由作者提供

右侧的图表将 Y 绘制在 x 轴上,将 C、E、F、G 和 X 分别绘制在 y 轴上。如果 Y 依赖于这些其他变量,那么改变其中一个变量的值应该会改变 Y 的值。这意味着应该存在一个正或负的系数,并且这些线应该表现出明显的斜率(向上或向下)。

由于存在明确的斜率,我们可以看到𝑌⫫̸ 𝐶,𝐸,𝐹,𝐺,𝑋是正确的,即 Y 在数据中依赖于 C、E、F、G 和 X**。

但是,如果没有依赖关系,那么改变变量的值对 Y 的影响很小或没有影响,系数应该接近零,且直线应该没有斜率,即应为平坦的。

通过将 Y 和 D 之间的关系添加到图表中可以证明这一点,记住在 DAG 中从 D 到 Y 没有因果联系,因此数据中 Y 和 D 之间也不应该有关系……

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图像

这正是我们期望的结果。C、E、F、G 和 X 都有明显的斜率,并且具有负或正的系数,清楚地表明如果这些变量的值发生变化,Y 的值也会发生变化,因此 Y 依赖于这些变量。

然而 D 的斜率平坦,系数非常小(仅为-0.029),因此改变 D 的值对 Y 的值几乎没有影响,因此因果关系𝑌⫫𝐷(Y 与 D 无关)在数据中存在。

在 Python 中实现数据中的依赖关系

检测数据中依赖关系的提议方法使用了来自 statsmodels.formula.api 库的 ols 类,以执行普通最小二乘(OLS)回归。

可以将 ols 类拟合到数据集中,然后提取和解释数据中存在的系数或斜率。以下是操作方法……

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图像

总结中的关键数据是中间的表格,该表格对变量 C、E、F、G 和 X 与 Y 之间的关系进行了一些分析。例如,ols 分析提出了以下内容——

𝑌=2.03𝐶+3.02𝐸+1.84𝐹+6.33𝐺+1.54𝑋−25.2

这与我用来创建数据集的公式相距不远,公式是……

𝑌=3𝐶+3𝐸+2𝐹+2.5𝐺+1.5𝑋+ε

最大的差异在于节点 G,但出于验证目的,系数的大小并不重要,只要系数存在且斜率不是平坦的即可。

除了coef列外,另一个值得关注的项目是P>|t|或 p 值列,其工作方式如下……

  1. 零假设是变量(例如 E)与因变量(例如 Y)之间没有关系。

  2. 如果 p 值大于 alpha(通常设定为 0.05),则拒绝零假设,即存在关系,即存在依赖性。

例如,E、G 和 X 的 p 值都低于 0.05,因此可以拒绝零假设并假定存在依赖关系。

那么 C 和 F 呢?C 的 p 值为 0.076,略高于 alpha,而 F 的值为 0.275,明显高于我们选择的 alpha(0.05)。

我们可以简单地增加 alpha,直到我们得出所有变量都是依赖的结论,但这种方法从长远来看效果不好,因为它会开始得出不存在的依赖结论。

当我进行最初开发时,我几乎在这一点上放弃了,认为 ols 不能作为检测我的 DAG 和数据中依赖关系的可靠方法,但随后我重新审视了 ols 分析。

对所有 5 个变量可以观察到系数,但 p 值仅在 5 个中的 3 个上具有决定性。我随后转而使用coef,但在后续的过程中,我发现 p 值有效而coef无效的情况。

在经历了许多令人沮丧的小时和大量的反复试验后,我建立了一种方法,它结合了两个值,并且展现出高度的准确性,经严格测试对比了大量不同的数据和 DAG。

这是我用来检测依赖关系的方法…

VALIDATION SUCCESS: Y is dependent on C in the data
VALIDATION SUCCESS: Y is dependent on E in the data
VALIDATION SUCCESS: Y is dependent on F in the data
VALIDATION SUCCESS: Y is dependent on G in the data
VALIDATION SUCCESS: Y is dependent on X in the data

我通过反复试验采用的测试方法如下…

如果 p 值大于 0.05 且系数小于或等于 1.0,则假定没有依赖关系,否则假定存在依赖关系。

这种方法并不遵循统计方法,仅仅考虑 p 值,但大量测试表明它非常可靠。

优化 Python 代码

上述方法的一个缺点是公式嵌入在代码中,即在ols_formula = "Y ~ C + E + F + G + X"以及dependent_variablevariables的声明中,这会在实际示例中导致代码重复。

如果能找到一种方法来扩展DataFrame类,以便能在任何数据集上通用地进行依赖性测试,那将会更好。

幸运的是,通过使用一种称为“猴子补丁”的技术,向DataFrame类添加自定义方法非常简单。如果你想要逐步教程,请查看我的教程文章…

## 如何扩展 Pandas DataFrames 以增强代码功能性和可读性

一步步扩展 pandas DataFrames 的自定义方法的指南,包括实现的完整示例…

[towardsdatascience.com

这是优化后的代码,它能够在任何数据集上执行任何依赖性测试…

一旦DataFrame类扩展了dependence方法,测试任何依赖性测试将变得非常简单。

例如,我们可以尝试𝑌⫫̸𝐶,𝐸,𝐹,𝐺,𝑋,这应该验证为True且没有错误…

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图片

我们可以尝试𝑌⫫̸𝐶,𝐸,𝐹,𝐺,𝑋,𝐷,这应该验证为False,表示"D"是一个错误,因为 Y 不依赖于它…

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图片

这些测试都通过了,并且在我尝试过的所有 DAG 和数据集中成功率非常高,以验证这种方法的准确性。

汇总所有内容

总结来说,上述相对较小的代码库实现了令人印象深刻的结果,即能够对任何数据集进行任何依赖性测试,以指示该测试是否通过,并在测试失败时具体突出显示错误。

但这还不够。让我们假设在咨询我们的领域专家时,他们产生的 DAG 包含一个错误,而这些专家假设节点 D 到节点 Y 之间存在因果链(或依赖关系)。

拟议的 DAG 现在看起来是这样的 …

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图片

凭借我们的新能力,我们可以轻松地对节点 Y 进行 DAG 测试,方法如下 …

… 正如我们在上面的结果中看到的,节点“D”将被准确识别为“错误”。因此,我们已经识别出一个“虚假边”,即在 DAG 中存在但在数据中不存在的链,这告诉我们 DAG 必须进行调整,以移除那个虚假边以确保准确性。

因此,以下条件必须成立 …

  1. 从一个拟议的 DAG 开始。

  2. 遍历所有节点。

  3. 对所有输入连接执行依赖性测试。

  4. 收集所有错误的列表。

积累的错误列表将立即指示所有虚假边/连接/依赖关系,这些都必须从拟议的 DAG 中移除,以生成一个没有所有虚假边的新 DAG(即在 DAG 中存在但数据中不存在的依赖关系)。

实现这一点的代码如下 …

测试完整算法以检测 DAG 中的虚假边

使用这几行代码,现在可以测试任何 DAG(由一组边表示)与任何数据(由 pandas DataFrame表示)以查看 DAG 中是否存在数据中不存在的“虚假”边。

让我们从测试 DAG 正确表示数据中所有因果链的情况开始(记住df_causal正确表示 DAG,因为它是作者合成创建的准确表示)…

A ⫫̸ D
B ⫫̸ A, C
C ⫫̸ D, F
E ⫫̸ C
X ⫫̸ A, B, E, F, G
Y ⫫̸ C, E, F, G, X
[]

在 DAG 与数据匹配的情况下没有检测到错误。

现在,让我们在 DAG 中添加一个不存在的因果链 D => Y,并重新运行代码 …

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图片

A ⫫̸ D
B ⫫̸ A, C
C ⫫̸ D, F
E ⫫̸ C
X ⫫̸ A, B, E, F, G
Y ⫫̸ C, D, E, F, G, X
[('D', 'Y')]

“虚假”边在 DAG 中被正确识别!但是当 DAG 中存在多个在数据中不存在的虚假因果关系时,我们的算法还会有效吗?

为了测试这一点,在 DAG 中添加一个第二个不存在的因果链 A => E …

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图片

A ⫫̸ D
B ⫫̸ A, C
C ⫫̸ D, F
E ⫫̸ A, C
X ⫫̸ A, B, E, F, G
Y ⫫̸ C, D, E, F, G, X
[('A', 'E'), ('D', 'Y')]

这个测试也通过了。如果向 DAG 中添加两个在数据中不存在的虚假因果关系,它们都能被正确检测到并识别为错误。

对算法进行破坏性测试

这些有希望的结果引发了一个问题:“那么,这种方法到底有多准确?”即在 DAG 中可以继续添加多少虚假因果关系而不被正确检测到。

为了回答这些问题,作者设计了一个具有挑战性的测试,首先识别 DAG 中每一个可能存在但实际上不存在的有效因果链接。对于这个特定的 DAG,所有可能链接的完整集合如下……

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图片

随后,使用测试工具随机选择任何 3 个可能缺失的链接,同时重复测试不同的集合,以确定验证算法的准确性。

结果令人震惊。这里提出的简单算法能够以 100% 的准确率检测任何 3 个虚假链接的组合(使用示例 DAG 和数据)。即便将测试改为选择任何 12 个可能的虚假链接,一样可以达到 90% 的准确率!

附录部分:分开与组合依赖测试

在整篇文章中,通过查看所有的“父”节点来建立给定节点的依赖集合,然后创建一个依赖声明,例如……

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图片

你可能会想知道相同的测试集是否等效……

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图片

作者面临的挑战之一是假设这些单独的测试等同于检测虚假边的单个整体测试,但测试中的试错过程得出了明确结论,即情况并非如此。

在寻找虚假边 Y => D 时,实现 𝑌⫫̸𝐶,𝐸,𝐹,𝐺,𝑋,𝐷 测试是 100% 可靠的,但单独测试 𝑌⫫̸𝐷 不起作用,这通过执行多轮自动化测试来比较这两种方法的准确性得到了证明。

假设是因为封装这些变量之间关系的公式是𝑌 = 3𝐶 + 3𝐸 + 2𝐹 + 2.5𝐺 + 1.5𝑋 + ε,实现依赖的 OLS 测试需要考虑所有变量在一起,这也验证了因果推断中的另一个真理……

从数据中逆向工程一个 DAG 是非常困难甚至可能不可能,但当进行“初步尝试”并接近目标时,任务变得可实现

本节的要点是:在测试依赖关系时,考虑每个节点的所有输入关系,因为如果单独测试,它根本不起作用。

连接并保持联系……

如果你喜欢这篇文章,你可以通过成为 Medium 会员,每月仅需 5 美元,即可无限访问更多内容,点击我的推荐链接(如果你通过此链接注册,我将获得费用的一部分,且对你没有额外费用)。

[## 通过我的推荐链接加入 Medium - Graham Harrison

作为 Medium 会员,你的会员费用的一部分将会流向你阅读的作者,你可以全面访问所有故事…

grahamharrison-86487.medium.com](https://grahamharrison-86487.medium.com/membership?source=post_page-----4263b18d5f04--------------------------------)

… 或者通过 … 连接

订阅我的免费电子邮件,以便在我发布新故事时及时获取

快速浏览我的上一篇文章

下载我的免费战略数据驱动决策框架

访问我的数据科学网站 — 数据博客

揭秘 DreamBooth:一种个性化文本到图像生成的新工具

原文:towardsdatascience.com/demystifying-dreambooth-a-new-tool-for-personalizing-text-to-image-generation-70f8bb0cfa30

探索将无聊图像转化为创意杰作的技术

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Mario Larcher

·发表于数据科学前沿 ·阅读时间 13 分钟·2023 年 6 月 13 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

Dougie 和他的新个性由作者使用 DreamBooth 创建。你能猜出提示是什么吗?

介绍

想象一下,你轻松地生成一张你心爱的幼犬在雅典卫城背景下的新图像的喜悦。如果还不满足,你还想看看梵高会如何绘制你的好友,或者他如果被狮子所构思会是什么样子😱!感谢 DreamBooth,这一切都变为现实,如今可以让任何动物、物体或我们自己从一小堆图像中旅行于幻想世界。

尽管我们许多人已经在社交媒体上看到了利用这项技术可以取得的令人瞩目的成果,而且有大量教程可以让我们在自己的照片上进行尝试,但很少有人尝试回答这样一个问题:是的,那么它到底是如何工作的呢?

在本文中,我将尽力解析 Ruiz 等人发表的科学论文DreamBooth: 针对主题驱动生成的文本到图像扩散模型微调,这篇论文是所有这一切的起点。但别担心,我会简化复杂的部分,并在需要一些先验知识的地方进行解释。现在,请注意,这是一项高级话题,因此我假设你已经掌握了深度学习及相关内容的基础知识。如果你想深入了解扩散模型或其他有趣的话题,我会在过程中提供一些参考。让我们开始吧!

相关工作

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 7 来自DreamBooth: 针对主题驱动生成的文本到图像扩散模型微调

在我们深入探讨 DreamBooth 的方法之前,让我们先仔细了解一下与该技术相关的工作和任务。

图像合成

在日常生活的喧嚣中,你心爱的背包已经很久没有踏上环球之旅。现在是给它注入刺激冒险的时刻,同时你也在规划下一次假期。通过图像合成,将你的背包无缝融入新的背景,让它在几秒钟内从大峡谷到波士顿。

如果简单地复制粘贴主题不能满足你对新视角的渴望,可以尝试探索 3D 重建技术的应用。然而,需要注意的是,这些技术主要针对刚性物体,并且通常需要大量的起始视图。

DreamBooth 引入了一项卓越的能力,可以在新的背景中生成新姿势,同时顺畅地融入关键元素,如光线、阴影和其他与场景相关的方面。实现这种一致性在以往的方法中一直是一个挑战。在论文中,这项任务也被称为重新背景化。

文本到图像编辑与合成

基于文本输入的图像编辑是许多照片编辑软件爱好者的一个秘密梦想。早期的方法,例如使用 GANs 的方法,展示了令人印象深刻的结果,但仅限于像编辑人脸这样结构良好的场景。

即使是利用扩散模型的新方法也有其局限性,通常仅限于全局编辑。直到最近,像Text2LIVE这样的进展才允许局部编辑。然而,这些技术都无法在新的背景中生成特定的主题。

尽管像ImagenDALL·E 2Stable Diffusion这样的文本图像合成模型取得了显著进展,但在合成图像中实现精细控制并保留主题身份仍然面临重大挑战。

可控生成模型

为了避免对主题进行修改,许多方法依赖于用户提供的掩码来限制修改的区域。逆转技术,如DALL·E 2使用的技术,提供了一个有效的解决方案,可以在修改背景的同时保留主题。

Prompt-to-Prompt使得本地和全局编辑成为可能,无需输入掩码。

然而,这些方法在生成新样本时无法充分保留主题的身份。

尽管一些基于 GAN 的方法专注于生成实例变体,但它们往往有局限性。例如,它们主要设计用于面部领域,需要大量的输入样本,难以处理独特的主题,并且无法保留重要的主题细节。

最后,最近 Gal 等人提出了文本反演,这是一种具有 DreamBooth 共同特征的方法论,但正如我们将看到的,它受到基于其的冻结扩散模型表现力的限制。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 2 来自图像胜于千言:使用文本反演个性化文本到图像生成

由于这是作者用来与 DreamBooth 进行比较的工作,值得提供一个简要的描述。

文本反演从一个预训练的扩散模型开始,如潜在扩散模型,并定义一个新的占位符字符串 S*,以表示需要学习的新概念。在此阶段,保持扩散模型冻结,新的嵌入从仅 3-5 张图像中进行微调,类似于 DreamBooth。如果这个简要描述不够清楚,请等到你阅读更详细的 DreamBooth 描述时,它与这项工作有许多共同点。

方法

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 3 来自DreamBooth:为主题驱动生成微调文本到图像扩散模型

在详细描述DreamBooth的组件之前,让我们简要了解一下这项技术的工作原理:

  1. 选择 3-5 张你喜欢的主题图像,可以是动物、物体,甚至是像艺术风格这样的抽象概念。

  2. 将这个概念与一个稀有词汇关联,该词汇对应一个唯一的标记,将从现在开始表示它,在科学论文中,作者称这个词为[V]。

  3. 使用兴趣主题的图像,通过简单的提示如“一个[V] [类别名]”来微调模型,例如,如果输入图像是你的狗的照片,则为“一个[V] 狗”。

  4. 由于我们正在微调模型的所有参数,因此有风险在这个阶段所有的狗(或我们主题的任何类别)都会变成与我们的输入图像相同。为了避免模型的这种退化,我们从冻结的模型生成图像,使用像“狗”(或“[类别名]”)这样的提示,并添加一个损失函数,当我们为这个提示微调的模型生成的图像偏离冻结模型生成的图像时,会受到惩罚。

好的,现在我们对过程有了一个高层次的了解,让我们详细讨论各种组件。

文本到图像扩散模型

你真的想了解扩散模型的工作原理,尤其是像稳定扩散这样的潜在扩散模型吗?请阅读我之前的文章,当你读完后,我会在这里等你!

论文解读——基于潜在扩散模型的高分辨率图像合成

虽然 OpenAI 凭借其生成文本模型主导了自然语言处理领域,但他们的图像…

towardsdatascience.com

好吧,也许你不需要完整的解释,在这种情况下,我将提供扩散模型背后的直观理解,这非常简单。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 2. 来自 Denoising Diffusion Probabilistic Models

  1. 取一个图像 x0,并添加一定量的噪声(例如,高斯噪声),噪声量与某个时间步* t 成比例。如果t为零,则添加的噪声为零,如果t* > 0,则添加的噪声将与t的大小一样,直到你得到一个仅由噪声组成的图像。

  2. 训练一个模型,如 U-Net,通过将时间步* t *和受损图像作为输入来预测无噪声图像(或添加的噪声)。

  3. 此时,经过训练一个可以去除图像噪声的模型,我们可以采样一个仅由噪声组成的图像,并逐渐去除噪声(一次性完成效果不好),可以通过预测无噪声的图像或预测噪声并从图像中减去来实现。

  4. 前三点描述了无条件扩散模型。为了根据文本提示生成条件输出,文本使用像 CLIP 的模型进行编码,或者使用如 BERTT5 等语言模型。这个编码步骤允许集成额外的信息,然后将其与受损图像和时间步* t *一起输入模型。

论文中的作者使用了两个扩散模型:Google 的 Imagen(DreamBooth 也来自 Google Research)和 Stable Diffusion来自 Stability AI,这是主要的开源文本到图像模型。

Imagen 采用多分辨率策略来提高生成图像的质量。最初,使用低分辨率 64x64 图像训练扩散模型。然后,低分辨率模型的输出通过两个额外的扩散模型进行放大,这些模型在更高分辨率下操作,分别为 256x256 和 1024x1024。第一个模型专注于捕捉宏观细节,而随后的模型则通过利用较低分辨率模型生成图像的条件效应来精细化输出。这种迭代优化有助于生成高分辨率的图像,具有更好的质量和保真度。

Stable Diffusion 作为一种潜在扩散模型,采用三步法来提高训练和生成高分辨率图像的效率。最初,训练一个变分自编码器(VAE)以压缩高分辨率图像。从此之后,过程与标准扩散模型非常相似,一个关键区别在于:不是使用原始图像作为输入,而是使用由 VAE 编码器生成的潜在表示。随后,逆扩散过程的输出通过 VAE 解码器恢复到原始分辨率。为了更全面地理解整个过程,我在上述文章中进行了更详细的探讨。

文本到图像模型的个性化

DreamBooth 旨在将主题实例(例如你的狗)置于模型的输出领域内,使模型能够在查询时生成主题的新图像。扩散模型的一个优势是,与 GANs 相比,它们能够有效地将新信息纳入其领域,同时保留对先前数据的知识,并避免对有限的训练图像集的过拟合。

为少样本个性化设计提示

如前所述,该模型通过使用“一个 [identifier] [class noun]”结构的简单提示进行训练。这里,[identifier] 代表与主题相关的独特标识符,而 [class noun] 作为主题类别的一般描述(如猫、狗、手表等)。作者将类名纳入提示中,以建立通用类别与我们个体主题之间的联系,观察到使用不正确或缺失的类名会导致更长的训练时间和语言漂移,最终影响性能。本质上,主要目的是利用特定类别与我们的主题之间的关系,利用模型对该类别已有的知识。这使我们能够在各种上下文中生成新颖的姿势和变体。

稀有标记标识符

论文强调,普通的英语单词在这种情况下并不理想,因为模型需要将它们与原始含义脱离,并重新整合以指代我们的主题。

为了解决这个问题,作者提出使用在语言和扩散模型中都有弱先验的标识符。虽然选择像“xxy5syt00”这样的随机字符可能最初看起来很吸引人,但这也存在潜在风险。需要考虑的是,分词器可能会将每个字母单独分词。那么解决方案是什么?最有效的方法是识别词汇表中不常见的标记,然后在文本空间中反转这些标记。这可以最小化标识符具有强先验的可能性。

有趣的是,大多数教程使用“sks”来实现这一目的,但正如其中一位作者指出的,这个看似无害的词可能会产生副作用……

类别特定先验保持损失

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

DreamBooth: Fine Tuning Text-to-Image Diffusion Models for Subject-Driven Generation中的图 6。

与文本反演不同,DreamBooth 微调模型的所有层以最大化性能。不幸的是,这样做会遇到众所周知的语言漂移问题,即当一个模型最初在一个广泛的文本语料库上进行预训练,然后再针对特定任务进行微调时,它会逐渐减少对语言语法和语义的理解。

另一个问题是输出多样性的潜在减少。这可以从图 6 的第二行中观察到,在该图中,模型,除非进一步调整,否则有倾向仅复制输入图像中找到的姿势。当模型训练时间较长时,这种效果变得更加明显。

为了减轻这些问题,作者引入了类别特定先验保持损失,让我们先看看其整体损失公式,然后再解释其组成部分。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

DreamBooth: Fine Tuning Text-to-Image Diffusion Models for Subject-Driven Generation中的公式 2。

第一部分是标准的 L2 去噪误差,这是任何扩散模型的典型特征。α_t 将初始图像x 缩放,然后添加高斯噪声 εN (0, I),乘以 σ_t。随机变量 z_t := α_tx** + σ_tε** 的分布为 N(α_t*x**, σ_t²)。此时,模型 xˆ 将尝试从 z_tt 和条件向量 c = Γ(P) 预测原始图像,其中在 DreamBooth 的情况下,Γ 是 T5,而提示 P 的形式为“一个 [标识符] [类别名词]”。

第二部分是 先验保留损失,在这里,x 被替换为 xpr,即由模型生成的图像,模型的权重被冻结(在微调之前),从随机初始噪声 z1N (0, I) 和条件向量 c_pr = Γ(“一个 [类别名词]”)。这一部分促使模型从其损坏版本中重新获取 x_pr,从而促使模型生成类似于在微调过程之前生成的图像。

最后,w_tw_t’ 是与噪声调度相关的术语,λ 定义了两个损失之间的相对权重。

实验

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 4 来自 DreamBooth: Fine Tuning Text-to-Image Diffusion Models for Subject-Driven Generation

数据集

实验使用的数据集由作者生成,包含 30 个主题,包括独特的物品,如背包或太阳镜,以及动物,如狗、猫等。在这 30 个主题中,21 个是物体,9 个是活体主题/宠物。

作者定义了 25 个提示:20 个重新背景化提示和 5 个物体属性修改提示;10 个重新背景化提示,10 个配件化提示和 5 个活体主题/宠物属性修改提示。

评估指标

为了评估,每个主题和每个提示生成四张图像,共计 3,000 张图像。

为了测量 主题一致性,使用 CLIP-I 和 DINO。

CLIP-I,在之前的工作中已经使用,计算生成图像和真实图像的CLIP嵌入的平均成对余弦相似度。CLIP 的训练目标是使文本描述的嵌入与其所指的图像具有相同的嵌入,因此如果两个图像表示相同的文本,它们将具有相似的嵌入。

DINO,由作者引入的新指标,类似于 CLIP-I,但生成嵌入的方式是使用ViT-S/16 DINO,一个自监督训练的模型。

论文中观察到,由于 CLIP 的训练方式,CLIP-I 不区分可能具有非常相似文本描述的不同主题。另一方面,DINO(指的是模型,而不是指标)以自监督的方式进行训练,这有助于区分主题或图像中的独特特征。因此,他们将 DINO 视为主要指标。

最后,引入了第三个度量指标 CLIP-T,用来衡量另一个重要方面:提示一致性,即生成的图像与输入提示的接近程度。

CLIP-T 与之前的指标类似,测量从 CLIP 中获得的两个嵌入之间的平均余弦相似度:一个来自提示,另一个来自图像。值得注意的是,CLIP 特别训练以生成与对应图像的文本嵌入相似的嵌入。

比较

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

表 1 来自于 DreamBooth: Fine Tuning Text-to-Image Diffusion Models for Subject-Driven Generation

从表 1 可以看出,当使用 DINO 和 CLIP-T 指标测量时,DreamBooth 明显优于 Textual Inversion,而在使用 CLIP-I 测量时差距较小,但如前所述,CLIP-I 并不是一个好的衡量特定主题保真度的指标。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

表 2 来自于 DreamBooth: Fine Tuning Text-to-Image Diffusion Models for Subject-Driven Generation

很难找到一个与个人判断结果好坏完全一致的指标。因此,作者们还测量了一组 72 名用户的偏好。结果突显出,对于主题保真度和提示保真度,偏好 DreamBooth 的用户百分比要高于仅凭之前的指标所能得出的结论。我们可以通过查看论文中的图 4 来判断,两种方法之间的显著差异在这个特定例子中是显而易见的。

结论

图像生成和生成式 AI 领域近年来获得了显著关注。特别是通过扩散模型的使用,图像合成的进展推动了这一领域的发展。

在本文中,我们深入探讨了 DreamBooth 的科学论文——这是一种令人印象深刻的解决方案,能够生成具有不同姿势和背景的新图像,同时保持对期望主题的忠实。这种创新的方法展示了图像合成领域取得的显著进展,并对未来的发展具有巨大潜力。

感谢您抽出时间阅读本文,如有任何意见或问题,请随时留言或与我联系。要了解我的最新文章,您可以关注我在 MediumLinkedInTwitter 上的动态。

[## 通过我的推荐链接加入 Medium - Mario Namtao Shianti Larcher

阅读 Mario Namtao Shianti Larcher 的每一个故事(以及 Medium 上的其他成千上万位作家的故事)。您的会员费……

medium.com](https://medium.com/@mnslarcher/membership?source=post_page-----70f8bb0cfa30--------------------------------)

解密 GQA — 高效 LLM 预训练的分组查询注意力

原文:towardsdatascience.com/demystifying-gqa-grouped-query-attention-3fb97b678e4a?source=collection_archive---------0-----------------------#2023-12-27

驱动像 LLaMA-2、Mistral7B 等大语言模型的多头注意力变体

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Bhavin Jawade

·

关注 发布在 Towards Data Science ·6 分钟阅读·2023 年 12 月 27 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

一组“驼鹿”(来源 — 作者使用 Dalle-3 创建的图像)

在前一篇关于训练大规模模型的文章中,我们探讨了 LoRA。在这篇文章中,我们将研究另一种被不同大语言模型采用的高效训练策略——分组查询注意力(GQA)。简而言之,分组查询注意力(GQA)是多头注意力(MHA)和多查询注意力(MQA)的推广——它们都是 GQA 的特例。因此,在深入探讨分组查询注意力之前,我们先回顾一下 Vaswani 等人提出的经典“Attention is All You Need”论文中的传统多头注意力。接下来,我们将探索多查询注意力及其如何解决 MHA 面临的挑战。最后,我们将回答“什么是 GQA?”和“它如何使我们兼得两全?”的问题。

多头注意力是变换器模型的关键组件,使其能够高效地处理和理解复杂序列任务,如语言翻译、摘要生成等。要掌握其复杂性,我们必须深入研究其数学基础,并理解注意力机制中多个头的功能。

基本的注意力机制计算值的加权和,加权依赖于查询和一组键。在数学上,这表示为:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

这被称为缩放点积注意力。在这个方程中,Q(查询)和 K(键)是表示查询和键的矩阵。V(值)是值的矩阵。“d_k”是键的维度,用于缩放。

扩展到多头注意力(MHA)

多头注意力使用多个‘头’的注意力层,使模型能够关注来自不同表示子空间的信息。在每个头中,有一组独立的线性层(投影矩阵)用于查询、键和值(这是一个重要的点,我们将在 GQA 中重述)。对于每个头(编号为 h):

headʰ = Attention(Q.Wqʰ,K.Wkʰ,V.Wvʰ)

连接头输出

各个头的输出被连接起来,然后进行线性变换。

MultiHead(Q,K,V) = Concat(head¹,head²,…,headʰ) .W

Wᵒ是另一个权重矩阵,用于将连接向量线性变换为最终输出维度。

多头注意力的直观理解是,通过并行地多次应用注意力机制,模型可以捕捉数据中不同类型的关系。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

插图描绘了缩放点积注意力、多头注意力在变换器编码器块中的应用。(来源——图示部分来自《Attention is All You Need》论文 arxiv.org/abs/1706.03762,作者编排)

然而,MHA 使得对输入的不同部分之间关系的理解更加细致。尽管如此,这种复杂性也带来了代价 — 对内存带宽的巨大需求,尤其是在解码器推理期间。

多头注意力中的内存带宽挑战

问题的关键在于内存开销。像 Transformers 这样的自回归模型中的每个解码步骤都需要加载解码器权重以及所有注意力键和值。这一过程不仅计算密集,而且内存带宽密集。随着模型大小的增长,这种开销也会增加,使得扩展变得越来越困难。

多查询注意力(MQA)的出现

多查询注意力(MQA)作为缓解这一瓶颈的解决方案出现。这个想法简单却有效:使用多个查询头,但只有一个键和值头。这种方法显著减少了内存负担,提高了推理速度。它已经在多个大规模模型中得到应用,如 PaLM、StarCoder 和 Falcon。

在多查询注意力中,我们对键和值的头进行平均,以便所有查询头共享相同的键和值头。这是通过将平均池化的“头”复制 H 次来实现的,其中 H 是查询头的数量。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

左侧 — 多头注意力,中间 — 多查询注意力,右侧 — 将现有的 MHA 检查点转换为 MQA(来源 — arxiv.org/pdf/2305.13245.pdf

这里一个有趣的问题是 — 如何将现有的预训练多头注意力模型转换为多查询注意力模型(MQA)?从现有的多头模型创建一个多查询注意力模型涉及两个步骤:模型结构的转换和随后的预训练。[1]

检查点转换:这一步将多头模型的结构转换为多查询模型。通过将原始模型中多个头的键和值的投影矩阵(线性层)合并(均值池化)为单个键和值的投影矩阵来实现。这种均值池化的方法被发现比选择现有的一个键和值头或从头初始化新的键和值头更有效。结果结构具有整合的键和值投影,具有多查询模型的特征。

预训练转换后的模型:在结构转换后,模型经历额外的训练。这次训练没有原始模型训练那么广泛;它是原始模型训练步骤的一部分(记作 α)。这个预训练阶段的目的是让模型根据其新的简化注意力机制调整和优化性能。训练遵循与原始模型相同的步骤,以确保学习动态的一致性。

然而,MQA 也并非没有缺点。减少的复杂性可能导致质量下降和训练不稳定。

分组查询注意力

分组查询注意力(GQA)是一种简单的方法,它结合了多头注意力(MHA)和多查询注意力(MQA)的元素,以创建一个更高效的注意力机制。GQA 的数学框架可以理解如下:

分组:在 GQA 中,传统多头模型中的查询头(Q)被分成 G 组。每组分配一个单独的键(K)和值(V)头。这种配置被称为 GQA-G,其中 G 代表组的数量。

GQA 的特殊情况:

  • GQA-1 = MQA:当只有一个组(G = 1)时,GQA 等同于 MQA,因为所有查询头只有一个键和值头。

  • GQA-H = MHA:当组的数量等于头的数量(G = H)时,GQA 的行为类似于传统的 MHA,每个查询头都有其独特的键和值头。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

MHA、GQA 和 MQA 的区别(来源 — arxiv.org/pdf/2305.13245.pdf

我们对每组内原始头的键和值投影矩阵进行均值池化,将多头模型转换为 GQA 模型。这种技术对每组中每个头的投影矩阵进行平均,从而为该组生成一个单独的键和值投影。

通过利用 GQA,模型在 MHA 的质量和 MQA 的速度之间保持平衡。由于键值对较少,内存带宽和数据加载需求被最小化。G 的选择呈现一种折中:更多的组(接近 MHA)会导致更高的质量但较慢的性能,而较少的组(接近 MQA)则提高了速度,但可能牺牲质量。此外,随着模型规模的增长,GQA 允许内存带宽和模型容量按比例减少,与模型的规模相匹配。相比之下,对于更大的模型,在 MQA 中将其减少到一个单独的键和值头可能会过于严苛。

结论

在这篇文章中,我们首先介绍了传统的多头注意力(MHA)及其变体多查询注意力。然后我们探讨了一个更通用的公式 GQA,它被许多 LLM 模型用于有效的预训练。GQA 将多头注意力(MHA)与多查询注意力(MQA)结合起来,在质量和速度之间提供了一个公平的折中。GQA 通过将查询头分组来最小化内存带宽需求,使其适合于模型的扩展。GQA 已被用于取代典型的多头注意力,在最近的模型中如 LLaMA-2 和 Mistral7B。

参考文献: [1] GQA:从多头检查点训练通用化的多查询 Transformer 模型 — arxiv.org/pdf/2305.13245.pdf

[2] MQA:快速 Transformer 解码:一个写头就足够了 — arxiv.org/abs/1911.02150

[3] MHA: 注意力即一切: arxiv.org/abs/1706.03762

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值