TowardsDataScience 2023 博客中文翻译(二百四十三)

原文:TowardsDataScience

协议:CC BY-NC-SA 4.0

使用 RetinaNet 和 KerasCV 的目标检测

原文:towardsdatascience.com/object-detection-using-retinanet-and-kerascv-b07940327b6c?source=collection_archive---------3-----------------------#2023-12-06

使用 KerasCV 库的力量和简便性进行目标检测。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Ed Izaguirre

·

关注 发表在 Towards Data Science · 21 分钟阅读 · 2023 年 12 月 6 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

一张植物叶子的图像。创建于 DALL·E 2

目录

  1. 等等,什么是 KerasCV?

  2. 检查数据

  3. 图像预处理

  4. RetinaNet 模型背景

  5. 训练 RetinaNet

  6. 做出预测

  7. 结论

  8. 参考文献

相关链接

  • Kaggle 实验笔记: 随意复制笔记本,试验代码,并使用免费的 GPU。

  • PlantDoc 数据集:这是本笔记本中使用的数据集,托管在 Roboflow 上。该数据集在 CC BY 4.0 DEED 许可证下发布,这意味着你可以在任何媒介或格式中复制和重新分发该材料,甚至用于商业目的。

等等,什么是 KerasCV?

在完成基于图像分割的小项目后(参见这里),我准备转入计算机视觉领域下另一个常见任务:物体检测。物体检测指的是对图像进行处理,产生围绕感兴趣对象的框,并分类这些框中的对象。作为一个简单的例子,看看下面的图片:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

物体检测的示例。请注意边界框和类标签。图片由作者提供。

蓝色的框被称为边界框类名放置在其正上方。因此,物体检测可以分解为两个小问题:

  1. 一个回归问题,模型必须预测盒子左上角和右下角的xy坐标。

  2. 一个分类问题,模型必须预测盒子正在观察的物体类别。

在这个例子中,边界框是由人类创建和标记的。我们希望自动化这个过程,而一个训练良好的物体检测模型正可以做到这一点。

我坐下来回顾我关于物体检测的学习资料,很快就感到失望。不幸的是,大多数介绍性的资料几乎没有提到物体检测。François Chollet 在Python 深度学习 [1] 中提到:

请注意,我们不会涵盖物体检测,因为它对于介绍性书籍来说过于专业和复杂。

Aurélion Géron [2] 提供了许多关于物体检测背后思想的文本内容,但只提供了几行代码来处理带有虚拟边界框的物体检测任务,远未达到我所期望的端到端流水线。Andrew Ng [3] 的著名深度学习专项课程在物体检测方面涵盖最深入,但甚至他在编码实验室中也只是加载了一个预训练的物体检测模型进行推理。

想要更深入地研究,我开始勾勒出一个物体检测流水线的大纲。仅仅为了为 RetinaNet 模型进行预处理,你需要执行以下步骤(注:其他物体检测模型如 YOLO 需要不同的步骤):

  • 将输入图片都调整为相同的大小,并进行填充以防止长宽比混乱。哦,不要忘记边界框;这些也需要适当地重新调整形状,否则你会破坏你的数据。

  • 根据训练集中的真实边界框生成不同尺度和纵横比的锚框。这些锚框在训练过程中作为模型的参考点。

  • 根据与真实框的重叠情况为锚框分配标签。重叠度高的锚框标记为正例,而重叠度低的锚框标记为负例。

  • 描述相同的边界框有多种方法。你需要实现函数来在这些不同格式之间进行转换。稍后会详细介绍。

  • 实现数据增强时,不仅要增强图像,还要增强框。理论上你可以省略这一步,但在实践中这是必要的,以帮助我们的模型更好地泛化。

看看这个例子 在 Keras 网站上。哎呀。我们模型预测的后处理将需要更多工作。借用 Keras 团队的话:这是一个技术上复杂的问题。

当我开始绝望时,我开始急切地浏览互联网,偶然发现了一个我从未听说过的库:KerasCV。当我阅读文档时,我开始意识到这是TensorFlow/Keras 计算机视觉的未来。根据他们的介绍:

KerasCV 可以被理解为 Keras API 的横向扩展:这些组件是新的第一方 Keras 对象,过于专业化而无法添加到核心 Keras 中。它们与核心 Keras API 享有相同级别的打磨和向后兼容保证,并由 Keras 团队维护。

“但为什么我的学习材料中没有提到这个?” 我想。答案很简单:这是一个相当新的库。GitHub 上的第一次提交是在 2022 年 4 月 13 日,太新了,甚至还未出现在我教科书的最新版本中。事实上,该库的 1.0 版本尚未发布(截至 2023 年 11 月 10 日,它是 0.6.4)。我预计 KerasCV 会在我教科书的下一版和在线课程中详细讨论(公平地说,Gèron 确实提到过“新的 Keras NLP 项目”和 Keras CV 项目,读者可能会感兴趣)。

KerasCV 刚刚推出,除了 Keras 团队自己发布的教程外,还没有很多教程(见这里)。在本教程中,我将演示一个端到端的目标检测流程,使用受官方 Keras 指南启发但又不同于这些指南的技术来识别健康和病变叶片。有了 KerasCV,即使是初学者也可以利用标记数据集来构建有效的目标检测管道。

在我们开始之前需要注意几点。KerasCV 是一个快速变化的库,其代码库和文档会定期更新。这里展示的实现将适用于 KerasCV 版本 0.6.4。Keras 团队已声明:“在 KerasCV 达到 v1.0.0 之前,没有向后兼容的承诺。” 这意味着无法保证本教程中使用的方法在 KerasCV 更新时仍然有效。我已在链接的 Kaggle notebook 中硬编码了 KerasCV 版本号,以防止这些问题。

KerasCV 有很多已知的错误,可以在 GitHub 的问题标签页 中查看。此外,文档在一些领域也有所欠缺(我看着你,MultiClassNonMaxSuppression)。在使用 KerasCV 时,尽量不要被这些问题气馁。事实上,这是一个成为 KerasCV 代码库贡献者的绝佳机会!

本教程将重点介绍 KerasCV 的实现细节。我将简要回顾一些目标检测的高级概念,但假设读者对如 RetinaNet 架构等概念有一定背景知识。这里展示的代码已进行编辑和调整以提高清晰度,完整代码请参见上面链接的 Kaggle notebook。

最后,关于安全的提示。这里创建的模型并非最先进的技术;请将其视为一个高层次的教程。在将此植物疾病检测模型投入生产之前,需要进一步的微调和数据清理。最好将模型做出的任何预测交由人工专家确认诊断。

检查数据

PlantDoc 数据集包含 2,569 张图像,涵盖 13 种植物和 30 个类别。数据集的目标在 Singh 等人撰写的论文 PlantDoc: A Dataset for Visual Plant Disease Detection 的摘要中进行了阐述 [4]。

印度由于植物疾病每年损失 35% 的作物产量。由于缺乏实验室基础设施和专业知识,植物疾病的早期检测仍然很困难。本文探讨了计算机视觉方法在可扩展和早期植物疾病检测中的可能性。

这是一个崇高的目标,也是计算机视觉可以为农民做出很多贡献的领域。

Roboflow 允许我们以多种不同格式下载数据集。由于我们使用 TensorFlow,建议将数据集下载为 TFRecord 格式。TFRecord 是 TensorFlow 中一种特定格式,旨在高效地存储大量数据。数据由一系列记录表示,每个记录是一个键值对。每个键称为 feature。下载的压缩文件包含四个文件,其中两个用于训练,两个用于验证:

  • leaves_label_map.pbtxt : 这是一个 Protocol Buffers 文本格式文件,用于描述数据的结构。打开文件时,我看到有三十个类别。既有健康叶子如 Apple leaf,也有不健康叶子如 Apple Scab Leaf

  • leaves.tfrecord : 这是包含我们所有数据的 TFRecord 文件。

我们的第一步是检查 leaves.tfrecord。我们的记录包含哪些特征?不幸的是,Roboflow 并未指定这一点。

train_tfrecord_file = '/kaggle/input/plants-dataset/leaves.tfrecord'
val_tfrecord_file = '/kaggle/input/plants-dataset/test_leaves.tfrecord'

# Create a TFRecordDataset
train_dataset = tf.data.TFRecordDataset([train_tfrecord_file])
val_dataset = tf.data.TFRecordDataset([val_tfrecord_file])

# Iterate over a few entries and print their content. Uncomment this to look at the raw data
for record in train_dataset.take(1):
  example = tf.train.Example()
  example.ParseFromString(record.numpy())
  print(example)

我看到以下打印的特征:

  • image/encoded : 这是图像的编码二进制表示。在这个数据集中,图像是以 jpeg 格式编码的。

  • image/height : 这是每个图像的高度。

  • image/width : 这是每个图像的宽度。

  • image/object/bbox/xmin : 这是边界框左上角的 x 坐标。

  • image/object/bbox/xmax : 这是边界框右下角的 x 坐标。

  • image/object/bbox/ymin : 这是边界框左上角的 y 坐标。

  • image/object/bbox/ymax : 这是边界框右下角的 y 坐标。

  • image/object/class/label : 这些是与每个边界框关联的标签。

现在我们想把所有图像及其关联的边界框整合到一个 TensorFlow Dataset 对象中。Dataset 对象允许你存储大量数据而不会使系统内存超载。这是通过延迟加载批处理等功能实现的。延迟加载意味着数据不会被加载到内存中,直到它被显式请求(例如在执行转换或训练时)。批处理意味着一次只加载选择数量的图像(通常为 8、16、32 等)。简而言之,我建议你始终将数据转换为 Dataset 对象,特别是在处理大量数据时(在目标检测中很常见)。

要将 TFRecord 转换为 TensorFlow 中的 Dataset 对象,你可以使用 tf.data.TFRecordDataset 类从 TFRecord 文件创建数据集,然后使用 map 方法应用解析函数来提取和预处理特征。解析代码如下所示。

def parse_tfrecord_fn(example):
    feature_description = {
        'image/encoded': tf.io.FixedLenFeature([], tf.string),
        'image/height': tf.io.FixedLenFeature([], tf.int64),
        'image/width': tf.io.FixedLenFeature([], tf.int64),
        'image/object/bbox/xmin': tf.io.VarLenFeature(tf.float32),
        'image/object/bbox/xmax': tf.io.VarLenFeature(tf.float32),
        'image/object/bbox/ymin': tf.io.VarLenFeature(tf.float32),
        'image/object/bbox/ymax': tf.io.VarLenFeature(tf.float32),
        'image/object/class/label': tf.io.VarLenFeature(tf.int64),
    }

    parsed_example = tf.io.parse_single_example(example, feature_description)

    # Decode the JPEG image and normalize the pixel values to the [0, 255] range.
    img = tf.image.decode_jpeg(parsed_example['image/encoded'], channels=3) # Returned as uint8

    # Get the bounding box coordinates and class labels.
    xmin = tf.sparse.to_dense(parsed_example['image/object/bbox/xmin'])
    xmax = tf.sparse.to_dense(parsed_example['image/object/bbox/xmax'])
    ymin = tf.sparse.to_dense(parsed_example['image/object/bbox/ymin'])
    ymax = tf.sparse.to_dense(parsed_example['image/object/bbox/ymax'])
    labels = tf.sparse.to_dense(parsed_example['image/object/class/label'])

    # Stack the bounding box coordinates to create a [num_boxes, 4] tensor.
    rel_boxes = tf.stack([xmin, ymin, xmax, ymax], axis=-1)
    boxes = keras_cv.bounding_box.convert_format(rel_boxes, source='rel_xyxy', target='xyxy', images=img)

    # Create the final dictionary.
    image_dataset = {
        'images': img,
        'bounding_boxes': {
            'classes': labels,
            'boxes': boxes
        }
    }
    return image_dataset

让我们详细拆解一下:

  • feature_description : 这是一个描述每个特征预期格式的字典。当特征在数据集中所有示例中的长度是固定时,我们使用 tf.io.FixedLenFeature,当长度存在某些变动时,我们使用 tf.io.VarLenFeature。由于边界框的数量在数据集中并不固定(有些图像有更多框,有些则较少),因此我们对所有与边界框相关的内容使用 tf.io.VarLenFeature

  • 我们使用 tf.image.decode_jpeg 解码图像文件,因为我们的图像是以 JPEG 格式编码的。

  • 请注意用于边界框坐标和标签的 tf.sparse.to_dense 的使用。当我们使用 tf.io.VarLenFeature 时,信息会以稀疏矩阵的形式返回。稀疏矩阵是大多数元素为零的矩阵,结果是一个只有效存储非零值及其索引的数据结构。不幸的是,TensorFlow 中的许多预处理函数要求使用稠密矩阵。这包括 tf.stack,我们用来水平堆叠来自多个边界框的信息。为了解决这个问题,我们使用 tf.sparse.to_dense 将稀疏矩阵转换为稠密矩阵。

  • 在堆叠框之后,我们使用 KerasCV 的 keras_cv.bounding_box.convert_format 函数。检查数据时,我注意到边界框坐标被归一化在 0 和 1 之间。这意味着这些数字表示图像总宽度/高度的百分比。例如,值为 0.5 表示 50% * image_width。这是一种 相对格式,Keras 称之为 REL_XYXY,而不是 绝对格式 XYXY。理论上,转换为绝对格式不是必要的,但当我使用相对坐标训练模型时遇到了错误。有关其他支持的边界框格式,请参见 KerasCV 文档

  • 最后,我们将图像和边界框转换为 KerasCV 所需的格式:字典。Python 字典是一种包含键值对的数据类型。具体来说,KerasCV 期望以下格式:

image_dataset = {
  "images": [width, height, channels],
  bounding_boxes = {
    "classes": [num_boxes],
    "boxes": [num_boxes, 4]
  }
}

这实际上是一个“字典中的字典”,因为 bounding_boxes 也是一个字典。

最后使用 .map 函数将解析函数应用于我们的 TFRecord。然后可以检查 Dataset 对象。一切正常。

train_dataset = train_dataset.map(parse_tfrecord_fn)
val_dataset = val_dataset.map(parse_tfrecord_fn)

# Inspecting the data
for data in train_dataset.take(1):
    print(data)

恭喜,最困难的部分现在已经完成了。 在我看来,创建 KerasCV 所需的“字典中的字典”是最具挑战性的任务。其余部分更为直接。

图像预处理

我们的数据已经分为训练集和验证集。所以我们将开始对数据集进行批处理。

# Batching
BATCH_SIZE = 32
# Adding autotune for pre-fetching
AUTOTUNE = tf.data.experimental.AUTOTUNE

train_dataset = train_dataset.ragged_batch(BATCH_SIZE).prefetch(buffer_size=AUTOTUNE)
val_dataset = val_dataset.ragged_batch(BATCH_SIZE).prefetch(buffer_size=AUTOTUNE)

NUM_ROWS = 4
NUM_COLS = 8
IMG_SIZE = 416
BBOX_FORMAT = "xyxy"

一些说明:

  • 我们使用 ragged_batch 是因为我们不知道每个图像将有多少个边界框。如果所有图像都有相同数量的边界框,那么我们可以直接使用 batch

  • 我们设置了 BBOX_FORMAT=“xyxy” 。回忆一下,之前在加载数据时,我们将边界框格式从相对的 XYXY 格式转换为绝对的 XYXY 格式。

现在我们可以实现 数据增强。数据增强是计算机视觉问题中的一种常见技术。它对训练图像进行轻微的修改,例如轻微旋转、水平翻转图像等。这有助于解决数据不足的问题,并且有助于正则化。在这里,我们将引入以下增强方法:

  • KerasCV 的JitteredResize函数。这个函数旨在用于目标检测管道,实现了一种图像增强技术,涉及随机缩放、调整大小、裁剪和填充图像及相应的边界框。这一过程引入了尺度和局部特征的变异,提高了训练数据的多样性,从而改善了模型的泛化能力。

  • 然后我们添加了水平和垂直的RandomFlips以及RandomRotation。这里的factor是一个表示 2π分数的浮点数。我们使用 0.25,这意味着我们的增强器会将图像旋转-25%到 25%π之间的某个角度。以度数表示,这意味着旋转范围在-45°到 45°之间。

  • 最后,我们添加了RandomSaturationRandomHue。饱和度为 0.0 会留下灰度图像,而 1.0 则完全饱和。0.5 的因子不会造成任何变化,因此选择 0.4–0.6 的范围会产生细微的变化。色调因子为 0.0 不会产生变化。设置factor=0.2表示范围为 0.0–0.2,这是另一种细微变化。

augmenter = keras.Sequential(
    [
        keras_cv.layers.JitteredResize(
            target_size=(IMG_SIZE, IMG_SIZE), scale_factor=(0.8, 1.25), bounding_box_format=BBOX_FORMAT
        ),
        keras_cv.layers.RandomFlip(mode="horizontal_and_vertical", bounding_box_format=BBOX_FORMAT),
        keras_cv.layers.RandomRotation(factor=0.25, bounding_box_format=BBOX_FORMAT),
        keras_cv.layers.RandomSaturation(factor=(0.4, 0.6)),
        keras_cv.layers.RandomHue(factor=0.2, value_range=[0,255])
    ]
)

train_dataset = train_dataset.map(augmenter, num_parallel_calls=tf.data.AUTOTUNE)

我们通常只对训练集进行增强,因为我们希望模型避免“记忆”模式,而是确保模型学习到在现实世界中会遇到的通用模式。这增加了模型在训练过程中看到的多样性。

我们还希望将验证图像调整为相同的大小(带有填充)。这些图像将在不失真的情况下调整大小。边界框也必须相应地重新调整。KerasCV 可以轻松处理这一困难任务:

# Resize and pad images
inference_resizing = keras_cv.layers.Resizing(
    IMG_SIZE, IMG_SIZE, pad_to_aspect_ratio=True, bounding_box_format=BBOX_FORMAT
)

val_dataset = val_dataset.map(inference_resizing, num_parallel_calls=tf.data.AUTOTUNE)

最后,我们可以可视化我们的图像和包含预处理的边界框:

class_mapping = {
    1: 'Apple Scab Leaf',
    2: 'Apple leaf',
    3: 'Apple rust leaf',
    4: 'Bell_pepper leaf',
    5: 'Bell_pepper leaf spot',
    6: 'Blueberry leaf',
    7: 'Cherry leaf',
    8: 'Corn Gray leaf spot',
    9: 'Corn leaf blight',
    10: 'Corn rust leaf',
    11: 'Peach leaf',
    12: 'Potato leaf',
    13: 'Potato leaf early blight',
    14: 'Potato leaf late blight',
    15: 'Raspberry leaf',
    16: 'Soyabean leaf',
    17: 'Soybean leaf',
    18: 'Squash Powdery mildew leaf',
    19: 'Strawberry leaf',
    20: 'Tomato Early blight leaf',
    21: 'Tomato Septoria leaf spot',
    22: 'Tomato leaf',
    23: 'Tomato leaf bacterial spot',
    24: 'Tomato leaf late blight',
    25: 'Tomato leaf mosaic virus',
    26: 'Tomato leaf yellow virus',
    27: 'Tomato mold leaf',
    28: 'Tomato two spotted spider mites leaf',
    29: 'grape leaf',
    30: 'grape leaf black rot'
}

def visualize_dataset(inputs, value_range, rows, cols, bounding_box_format):
    inputs = next(iter(inputs.take(1)))
    images, bounding_boxes = inputs["images"], inputs["bounding_boxes"]
    visualization.plot_bounding_box_gallery(
        images,
        value_range=value_range,
        rows=rows,
        cols=cols,
        y_true=bounding_boxes,
        scale=5,
        font_scale=0.7,
        bounding_box_format=bounding_box_format,
        class_mapping=class_mapping,
    )

# Visualize training set
visualize_dataset(
    train_dataset, bounding_box_format=BBOX_FORMAT, value_range=(0, 255), rows=NUM_ROWS, cols=NUM_COLS
)

# Visualize validation set
visualize_dataset(
    val_dataset, bounding_box_format=BBOX_FORMAT, value_range=(0, 255), rows=NUM_ROWS, cols=NUM_COLS
)

这种类型的可视化函数在 KerasCV 中很常见。它绘制了一组图像和框,行和列由参数指定。我们看到我们的训练图像有些被轻微旋转,有些被水平或垂直翻转,可能还进行了放大或缩小,并且色调/饱和度的细微变化也可以看到。在 KerasCV 中,所有增强层也会在必要时增强边界框。 请注意,class_mapping是一个简单的字典。我从之前提到的leaves_label_map.pbtxt文本文件中获得了键和标签。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

左侧是原始图像(验证集)的示例,右侧是增强图像(训练集)。图片由作者提供。

在查看 RetinaNet 模型之前最后要说的一件事是,之前我们需要创建“字典中的字典”以将数据转换为与 KerasCV 预处理兼容的格式,但现在我们需要将其转换为数字元组以供模型训练。这相当直接:

def dict_to_tuple(inputs):
    return inputs["images"], bounding_box.to_dense(
        inputs["bounding_boxes"], max_boxes=32
    )

train_dataset = train_dataset.map(dict_to_tuple, num_parallel_calls=tf.data.AUTOTUNE)
validation_dataset = val_dataset.map(dict_to_tuple, num_parallel_calls=tf.data.AUTOTUNE)

RetinaNet 模型背景

一个用于进行目标检测的流行模型叫做RetinaNet。该模型的详细描述超出了本文的范围。简而言之,RetinaNet 是一个单阶段检测器,意味着它在预测边界框和类别之前只查看一次图像。这类似于著名的 YOLO(You Only Look Once)模型,但有一些重要的不同之处。我在这里要强调的是使用的创新分类损失函数:focal loss。它解决了图像中的类别不平衡问题。

为了理解这点的重要性,可以考虑以下类比:假设你是一名教室里有 100 个学生的老师。95 个学生吵闹且喧哗,总是喊叫和举手。5 个学生安静,不怎么说话。作为老师,你需要平等关注每个人,但吵闹的学生正在挤走安静的学生。这里你遇到了类别不平衡的问题。为了解决这个问题,你开发了一种特殊的助听器,它增强了安静学生的声音并弱化了吵闹学生的声音。在这个类比中,吵闹的学生是我们图像中不包含叶子的背景像素的大多数,而安静的学生是那些包含叶子的少量区域。这个“助听器”就是 focal loss,它使我们可以将模型集中在包含叶子的像素上,而不会过多关注那些不包含叶子的像素。

RetinaNet 模型有三个重要组件:

  • 一个 骨干网络。这构成了模型的基础。我们也称之为特征提取器。顾名思义,它接收图像并扫描特征。低层提取低级特征(例如线条和曲线),而高层提取高级特征(例如嘴唇和眼睛)。在这个项目中,骨干网络将是一个在COCO 数据集上进行过预训练的 YOLOv8 模型。我们只将 YOLO 用作特征提取器,而不是作为目标检测器。

  • 特征金字塔网络(FPN)。这是一种模型架构,在不同的尺度上生成“金字塔”特征图,以检测各种大小的对象。它通过通过自上而下的路径和横向连接将低分辨率的语义强特征与高分辨率的语义弱特征结合起来。查看这个视频以获取详细解释,或查看这篇论文 [5],该论文介绍了 FPN。

  • 两个任务特定的子网络。 这些子网络处理金字塔的每一层,并检测每层中的对象。一个子网络用于识别类别(分类),另一个用于识别边界框(回归)。这些子网络尚未训练。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

简化的 RetinaNet 架构。图片由作者提供。

之前我们将图像调整为 416x416 的大小。这是一个有点随意的选择,尽管你选择的目标检测模型通常会指定一个所需的最小大小。对于我们使用的 YOLOv8 主干,图像大小应该是 32 的倍数。这是因为主干的最大步幅是 32,而且它是一个完全卷积网络。对于你自己项目中使用的任何模型,请进行调研以找出这个因素。

训练 RetinaNet

让我们从设置一些基本参数开始,比如优化器和我们将使用的指标。这里我们将使用 Adam 作为优化器。请注意global_clip_norm参数。根据KerasCV 目标检测指南

在训练目标检测模型时,你总是希望包含global_clipnorm。这是为了修复在训练目标检测模型时经常出现的梯度爆炸问题。

base_lr = 0.0001
# including a global_clipnorm is extremely important in object detection tasks
optimizer_Adam = tf.keras.optimizers.Adam(
    learning_rate=base_lr,
    global_clipnorm=10.0
)

我们将遵循他们的建议。对于我们的指标,我们将使用BoxCOCOMetrics。这些是目标检测中流行的指标。它们基本上包括平均精度 (mAP)平均召回率 (mAR)。总的来说,mAP 通过测量正确对象检测的平均面积与模型预测覆盖的总面积的比率来量化模型定位和识别对象的有效性。mAR 是一个不同的分数,通过计算正确识别的对象区域与实际对象区域的平均比例来评估模型捕获对象全部范围的能力。有关指标的详细信息,请参见这篇文章这段视频 对精度和召回率的基本知识进行了很好的解释。

coco_metrics = keras_cv.metrics.BoxCOCOMetrics(
    bounding_box_format=BBOX_FORMAT, evaluate_freq=5
)

由于框的指标计算开销很大,我们传递evaluate_freq=5参数,以告知我们的模型在每五个批次后计算指标,而不是在训练期间每个批次后计算。我注意到,当数字设置得过高时,验证指标根本没有被打印出来。

让我们继续查看我们将使用的回调:

class VisualizeDetections(keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs):
        if (epoch+1)%5==0:
            visualize_detections(
                self.model, bounding_box_format=BBOX_FORMAT, dataset=val_dataset, rows=NUM_ROWS, cols=NUM_COLS
            )

checkpoint_path="best-custom-model"

callbacks_list = [
    # Conducting early stopping to stop after 6 epochs of non-improving validation loss
    keras.callbacks.EarlyStopping(
        monitor="val_loss",
        patience=6,
    ),

    # Saving the best model
    keras.callbacks.ModelCheckpoint(
        filepath=checkpoint_path,
        monitor="val_loss",
        save_best_only=True,
        save_weights_only=True
    ),

    # Custom metrics printing after each epoch
    tf.keras.callbacks.LambdaCallback(
    on_epoch_end=lambda epoch, logs: 
        print(f"\nEpoch #{epoch+1} \n" +
              f"Loss: {logs['loss']:.4f} \n" + 
              f"mAP: {logs['MaP']:.4f} \n" + 
              f"Validation Loss: {logs['val_loss']:.4f} \n" + 
              f"Validation mAP: {logs['val_MaP']:.4f} \n") 
    ),

    # Visualizing results after each five epochs
    VisualizeDetections()
]
  • 早停。如果验证损失在六个周期后没有改善,我们将停止训练。

  • 模型检查点。我们将在每个周期后检查val_loss,如果它优于早期的周期,将保存模型权重。

  • Lambda 回调。Lambda 回调是一个自定义回调,允许你在训练过程中于每个周期的不同点定义并执行任意 Python 函数。在这种情况下,我们用它来在每个周期后打印自定义指标。如果直接打印 COCOMetrics,会是一堆杂乱的数字。为了简化,我们只打印训练和验证的损失和 mAP。

  • 检测的可视化。 这将在每五个周期后打印出一个 4x8 的图像网格以及预测的边界框。这将使我们洞察我们的模型有多好(或多糟)。如果一切顺利,这些可视化效果应该随着训练的进行而变得更好。

最终我们创建了我们的模型。回顾一下,主干是一个 YOLOv8 模型。我们必须传递我们将使用的 num_classes,以及 bounding_box_format

# Building a RetinaNet model with a backbone trained on coco datset
def create_model():        
    model = keras_cv.models.RetinaNet.from_preset(
        "yolo_v8_m_backbone_coco",
        num_classes=len(class_mapping),
        bounding_box_format=BBOX_FORMAT
    )
    return model

model = create_model()

我们还必须自定义模型的 非极大值抑制 参数。非极大值抑制用于目标检测中,以过滤掉多个重叠的预测边界框,这些框对应于同一对象。它只保留置信度分数最高的框,并删除冗余的框,确保每个对象只被检测一次。它包含两个参数:iou_thresholdconfidence_threshold

  1. IoU,或 交并比,是一个介于 0 和 1 之间的数字,衡量一个预测框与另一个预测框之间的重叠程度。如果重叠高于 iou_threshold,则置信度较低的预测框会被丢弃。

  2. 置信度分数反映了模型对其预测的边界框的信心。如果预测框的置信度分数低于 confidence_threshold,则该框会被丢弃。

尽管这些参数不会影响训练,但它们仍需根据您的特定应用进行调整以用于预测。设置 iou_threshold=0.5confidence_threshold=0.5 是一个好的起点。

在开始训练之前有一点需要注意:我们讨论了为什么将 分类损失 设置为焦点损失是有帮助的,但我们还没有讨论定义预测边界框坐标误差的合适 回归损失。一种流行的回归损失(或 box_loss)是 平滑 L1 损失。我认为平滑 L1 是一种“兼顾两全”的损失。它结合了 L1 损失(绝对误差)和 L2 损失(均方误差)。当误差值较小时,损失是二次的,当误差值较大时,损失是线性的(查看此链接)。KerasCV 为我们的便利提供了内置的平滑 L1 损失。训练期间显示的损失将是 box_lossclassification_loss 的总和。

# Using focal classification loss and smoothl1 box loss with coco metrics
model.compile(
    classification_loss="focal",
    box_loss="smoothl1",
    optimizer=optimizer_Adam,
    metrics=[coco_metrics]
)

history = model.fit(
    train_dataset,
    validation_data=validation_dataset,
    epochs=40,
    callbacks=callbacks_list,
    verbose=0,
)

在 NVIDIA Tesla P100 GPU 上训练大约需要一个小时 12 分钟。

进行预测

# Create model with the weights of the best model
model = create_model()
model.load_weights(checkpoint_path)

# Customizing non-max supression of model prediction. I found these numbers to work fairly well
model.prediction_decoder = keras_cv.layers.MultiClassNonMaxSuppression(
    bounding_box_format=BBOX_FORMAT,
    from_logits=True,
    iou_threshold=0.2,
    confidence_threshold=0.6,
)

# Visuaize on validation set
visualize_detections(model, dataset=val_dataset, bounding_box_format=BBOX_FORMAT, rows=NUM_ROWS, cols=NUM_COLS)

现在我们可以加载在训练过程中看到的最佳模型,并用它对验证集进行一些预测:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

验证集预测的样本视觉效果。图片由作者提供。

我们最佳模型的指标是:

  • 损失: 0.4185

  • mAP: 0.2182

  • 验证损失: 0.4584

  • 验证集 mAP: 0.2916

值得尊敬,但还有改进的空间。更多内容将在结论中讨论。(注意:我发现MultiClassNonMaxSuppression似乎没有正常工作。上面显示的左下角图像明显有超过 20%重叠的框,但较低置信度的框没有被抑制。这是我需要进一步研究的问题。)

这里是我们每个训练周期和验证周期的损失图。可以看到有些过拟合现象。此外,增加一个学习率调度器以逐渐降低学习率可能是明智的。这可能有助于解决在训练结束时出现的大幅跳跃问题。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

每个训练周期和验证周期的损失图。我们看到了一些过拟合的迹象。图片由作者提供。

结论

如果你已经做到这一步,给自己一个赞美吧!目标检测是计算机视觉中较为困难的任务之一。幸运的是,我们有新的 KerasCV 库来简化我们的工作。总结一下创建目标检测管道的工作流程:

  • 开始时可视化你的数据集。问自己一些问题:“我的边界框格式是什么?是xyxyRelxyxy?我处理多少个类别?”确保创建一个类似于visualize_dataset的函数来查看你的图像和边界框。

  • 将你拥有的任何格式的数据转换为 KerasCV 所需的“字典中的字典”格式。使用 TensorFlow Dataset 对象来存储数据特别有帮助。

  • 进行一些基本的预处理,例如图像缩放和数据增强。KerasCV 使这些操作相对简单。请注意查阅你选择的模型的文献,以确保图像尺寸适当。

  • 将字典转换回元组以用于训练。

  • 选择一个优化器Adam是一个简单的选择),两个损失函数focal用于类别损失,L1 smooth用于框损失是简单的选择),以及指标COCO metrics是一个简单的选择)。

  • 在训练期间可视化你的检测结果可以帮助了解你的模型遗漏了哪些对象。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

数据集中问题标签的示例。图片由作者提供。

主要的下一步之一是清理数据集。例如,查看上面的图像。标注者正确地识别了马铃薯叶晚疫病,但其他所有健康的马铃薯叶子呢?为什么这些没有标注为马铃薯叶?查看 Roboflow 网站上的健康检查标签,你可以看到某些类别在数据集中严重不足:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

显示类别不平衡的图表。来自 Roboflow 的网站

在调整任何超参数之前,先尝试修复这些问题。祝你在目标检测任务中好运!

参考文献

[1] F. Chollet, 用 Python 进行深度学习(2021), Manning Publications Co.

[2] A. Géron, 动手实践机器学习:使用 Scikit-Learn、Keras 和 TensorFlow (2022), O’Reily Media Inc.

[3] A. Ng, 深度学习专项课程, DeepLearning.AI

[4] D. Singh, N. Jain, P. Jain, P. Kayal, S. Kumawat, N. Batra, PlantDoc:用于视觉植物疾病检测的数据集 (2019), CoDS COMAD 2020

[5] T. Lin, P. Dollár, R. Girshick, K. He, B. Hariharan, S. Belongie, 用于目标检测的特征金字塔网络(2017), CVPR 2017

[6] T. Lin, P. Goyal, R. Girshick, K. He, P. Dollar, 目标检测中的焦点损失(2020), IEEE 模式分析与机器智能学报

面向对象的数据科学:重构代码

原文:towardsdatascience.com/object-oriented-data-science-refactoring-code-5bcb4ae7ce72

提升机器学习模型和数据科学产品的效率代码和 Python 类。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Molly Ruby

·发表于 Towards Data Science ·阅读时间 7 分钟·2023 年 8 月 24 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由作者创建。

对于数据科学家来说,代码是分析和决策的核心。随着数据科学应用的复杂性增加,从嵌入在软件中的机器学习模型到协调大量信息的复杂数据管道,开发干净、组织良好且易于维护的代码变得至关重要。面向对象编程(OOP)解锁了灵活性和效率,使数据科学家能够敏捷地应对不断变化的需求。OOP 引入了类的概念,这些类作为创建对象的蓝图,这些对象封装了数据及其操作。这种范式转变使数据科学家能够超越传统的函数方法,促进模块化设计和代码重用。

在本文中,我们将探讨通过创建类和部署面向对象技术来重构数据科学代码的好处,以及这种方法如何增强模块化和可重用性。

数据科学中的类的力量

在传统的数据科学工作流中,函数是封装逻辑的方法。这通常足够,因为函数允许开发人员减少重复代码。然而,随着项目的发展,维护大量函数可能会导致代码难以导航、调试和扩展。

这时类发挥了作用。类是创建对象的蓝图,这些对象将数据和操作数据的函数(称为方法)捆绑在一起。通过将代码组织成类,开发人员可以实现以下几个优势:

  1. 模块化和封装:类通过将相关功能组合在一起来促进模块化。每个类封装了自己的属性(数据)和方法(函数),减少了全局变量污染的风险和命名冲突的可能性。这有助于保持关注点的清晰分离,使代码更容易理解和修改。

  2. 可重用性:类通过为项目的不同部分提供一致的接口来鼓励重用。一旦定义了一个类,它可以在需要时实例化,并且其方法可以用来实现一致的结果。

  3. 4. 继承和多态:继承允许开发人员创建子类,从父类继承属性和方法。这促进了代码重用,同时使得特定任务的定制成为可能。多态性,另一个面向对象编程的概念,让开发人员可以在不同类中使用相同的方法名称,根据具体实现调整行为。

  4. 5. 测试和调试:类促进了单元测试,因为测试用例可以针对类中的单独方法,这使得识别和修复问题变得更加容易,从而提高了代码库的整体健壮性。

将代码重构为类:一个理论例子

假设你正在进行一个涉及数据预处理、模型训练和评估的机器学习项目。最初,你可能会有一组用于每个步骤的函数:

# Example: Using functions for data preprocessing

def load_data(file_path):
    # Load and preprocess data
    ...

def preprocess_data(data):
    # Clean, transform, and encode data
    ...

def train_model(preprocessed_data):
    # Train a machine learning model
    ...

def evaluate_model(trained_model, test_data):
    # Evaluate model performance
    ...

尽管功能分解有效,但随着时间的推移,预处理、训练和评估中可能会有许多步骤。这可能会使管理这些函数变得具有挑战性。

将代码重构为类:

def load_data(self, file_path):
        # Load data
        ...

class DataPreprocessor:
    def __init__(self, data):
        self.raw_data = data
        self.cleaned_data = self.clean_data(data)

    def clean_data(self):
        # imputation, outlier treatment
        ...

    def transform_data(self):
        # transformations and encode data
        ...

class ModelTrainer:
    def __init__(self, preprocessed_data):
        self.model = self.train_model(preprocessed_data)

    def fit(self, preprocessed_data):
        # Train a machine learning model
        ...

    def predict(self, preprocessed_data):
        # Predict using the machine learning model
        ...

class ModelEvaluator:
    def __init__(self, predictions, actuals):
        self.performance_metrics = self.evaluate_model(predictions, actuals)

    def evaluate_model(self, predictions, actuals):
        # Evaluate model performance
        ...

    def calculate_rmse(self, predictions, actuals):
        # Evaluate root mean squared error

    def calculate_r_squared(self, predictions, actuals):
        # Evaluate r_squared of the model

通过将工作流程拆分为类,组织性得到了提升,结构也更易于阅读和维护。每个类处理过程的特定方面。它们可以被实例化为:

data_preprocessor = DataPreprocessor('data.csv')
model_trainer = ModelTrainer(data_preprocessor.preprocessed_data)
model_evaluator = ModelEvaluator(model_trainer.model, test_data)

在这种情况下,类的引入提供了额外的结构和灵活性,改善了代码的工作流程和可用性。通过利用类的强大功能,这个示例创建了一个更为健壮和可扩展的代码库。

将代码重构为类:一个实际的例子

作为一个实际的例子,我最近将这个代码库三年前开发的代码重构到一个新代码库中,以展示重构前后的代码差异。

在初始代码库中,许多函数涵盖了建模任务,因为训练和测试了多个不同的模型。在重构版本中,有一个名为 SalesForecasting 的模型类涵盖了所有建模任务。这更易于阅读,并且使得作为 SalesForecasting 部署包更为高效,并可以用不同的输入多次实例化。作为预览,这个类的样子如下:

class SalesForecasting:
    """
    SalesForecasting class to train and predict sales using a variety of models. 
    """

    def __init__(self, model_list):
        """
        Initialize the SalesForecasting class with a list of models to train and predict.

        Args:
            model_list (list): list of models to train and predict. Options include:
                - LinearRegression
                - RandomForest
                - XGBoost
                - LSTM
                - ARIMA

        Returns:
            None
        """

        ...

    def fit(self, X_train, y_train):
        """
        Fit the models in model_dict to the training data.

        Args:
            X_train (pd.DataFrame): training data exogonous features for the model
            y_train (pd.Series): training data target for the model

        Returns:
            None
        """

        ...

    def __fit_regression_model(self, model):
        """
        Fit a regression model to the training data.

        Args:
            model (sklearn model): sklearn model to fit to the training data

        Returns:
            model (sklearn model): fitted sklearn model
        """
        ...

    def __fit_lstm_model(self, model):
        """
        Fit an LSTM model to the training data.

        Args:
            model (keras model): keras model to fit to the training data

        Returns:
            model (keras model): fitted keras model
        """

        ...

    def __fit_arima_model(self, model_name):
        """
        Fit an ARIMA model to the training data.

        Args:
            model_name (str): name of the model to fit to the training data

        Returns:
            model (pmdarima model): fitted pmdarima model
        """
        ...

    def predict(self, x_values, y_values=None, scaler=None, print_scores=False):
        """
        Predict values using the models in model_dict.

        Args:
            x_values (pd.DataFrame): exogenous features to predict on
            y_values (pd.Series): target values to compare predictions against
            scaler (sklearn scaler): scaler used to scale the data
            print_scores (bool): whether to print the scores for each model

        Returns:
            self (SalesForecasting): self with updated predictions
        """

        ...

    def __predict_regression_model(self, model):
        """
        Predict values using a regression model.

        Args:
            model (sklearn model): sklearn model to predict with

        Returns:
            predictions (np.array): array of predictions
        """
        ...

    def __predict_lstm_model(self, model):
        """
        Predict values using an LSTM model.

        Args:
            model (keras model): keras model to predict with

        Returns:
            predictions (np.array): array of predictions
        """
        ...

    def __predict_arima_model(self, model):
        """
        Predict values using an ARIMA model.

        Args:
            model (pmdarima model): pmdarima model to predict with
        Returns: 
            predictions (np.array): array of predictions
        """
        ...

    def __undo_scaling(self, values, scaler):
        """
        Undo scaling on a set of values.

        Args:
            values (np.array): array of values to unscale
            scaler (sklearn scaler): scaler to use to unscale the values

        Returns:
            unscaled_values (np.array): array of unscaled values
        """
        ...

    def get_scores(self, y_pred, y_true, model_name=None, print_scores=False):
        """
        Get the scores for a model. Scores include RMSE, MAE, and R2.

        Args:
            y_pred (np.array): array of predicted values
            y_true (np.array): array of true values
            model_name (str): name of the model to get scores for
            print_scores (bool): whether to print the scores for the model

        Returns:
            rmse (float): root mean squared error
            mae (float): mean absolute error
            r2 (float): r squared
        """
        ...

    def plot_results(self, model_list=None, figsize=p.FIG_SIZE, xlabel="Date", ylabel="Sales", title="Sales Forecasting Predictions"):
        """
        Plot the results of the predictions against the actual values.
        Generates a timeseries for predictions from each model in model_dict.

        Args:
            model_list (list): list of models to plot. If None, plots all models in model_dict
            figsize (tuple): tuple of figure size
            xlabel (str): label for x axis
            ylabel (str): label for y axis
            title (str): title for the plot

        Returns:
            fig (matplotlib figure): figure with the plot
        """

        ...

    def plot_errs(self, figsize=(13,3)):
        """
        Plot the errors for each model in model_dict. Errors include RMSE, MAE, and R2.

        Args:
            figsize (tuple): tuple of figure size

        Returns:
            fig (matplotlib figure): figure with the plot
        """
        ...

“SalesForecasting”类作为一个全面的蓝图,帮助数据驱动的企业通过应用各种预测模型来预测未来的销售趋势。在这个类中,数据科学家可以利用不同的建模技术,包括线性回归、随机森林、XGBoost、LSTM(长短期记忆)和 ARIMA(自回归积分滑动平均)。通过将预测工作流封装在这个类中,模型拟合、预测和评估的过程变得更加简化和一致。通过“SalesForecasting”类,数据科学家可以高效地实验不同的算法,并轻松维护代码库。

面向对象编程是数据科学家用来构建反映其分析的真实世界系统复杂性的代码的工具,使他们能够在最大化灵活性的同时提取有价值的洞察。虽然 python 旨在将类用于实例化和继承,上面的例子展示了一个初步步骤,其中类被用于模块化代码。随着数据科学能力的扩展和团队的增长,维护高效的代码是至关重要的。

在这里查看完整的重构代码库。

无需 OCR 的文档数据提取与变换器 (1/2)

原文:towardsdatascience.com/ocr-free-document-data-extraction-with-transformers-1-2-b5a826bc2ac3?source=collection_archive---------3-----------------------#2023-04-28

Donut 与 Pix2Struct 在自定义数据上的对比

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Toon Beerten

·

关注 发表在 Towards Data Science ·10 分钟阅读·2023 年 4 月 28 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图像 ()

DonutPix2Struct 是图像到文本模型,将纯像素输入的简单性与视觉语言理解任务相结合。简单来说:输入一张图像,提取的索引以 JSON 格式输出。

最近我发布了一个在发票上微调的 Donut 模型。我经常收到如何使用自定义数据集进行训练的问题。此外,还发布了一个类似的模型:Pix2Struct,它声称性能显著更好。但真的是这样吗?

该是卷起袖子的时候了。我将展示给你:

  • 如何为 Donut 和 Pix2Struct 微调准备数据

  • 两种模型的训练过程

  • 实际数据集上的比较结果

当然,我也会提供 colab 笔记本,以便于你的实验和/或复制。

数据集

要进行此比较,我需要一个公开的数据集。我想避免使用通常用于文档理解任务的数据集,例如CORD,浏览了一下,发现了Ghega 数据集。它相当小(约 250 个文档),由 2 种类型的文档组成:专利申请和数据表。通过不同类型,我们可以模拟一个分类问题。每种类型我们都有多个索引需要提取。这些索引对于每种类型都是唯一的。正是我所需要的。来自的 Trieste 大学机器学习实验室的Medvet教授慷慨批准了这些文章的使用。

数据集似乎比较旧,所以需要调查它是否仍然适合我们的目标。

初步探索

当你获得一组新的数据时,你首先需要熟悉其结构。幸运的是,网站的详细描述对我们很有帮助。这是数据集的文件结构:

ghega-dataset
    datasheets
        central-zener-1
        central-zener-2
        diodes-zener
            document-000-123542.blocks.csv
            document-000-123542.groundtruth.csv
            document-000-123542.in.000.png
            document-000-123542.out.000.png
            document-001-123663.blocks.csv
            document-001-123663.groundtruth.csv
            document-001-123663.in.000.png
            document-001-123663.out.000.png
            ...
        mcc-zener
        ...
    patents
        ...

我们可以看到两个主要的子文件夹对应两个文档类型:数据表专利。在更下一级,我们有一些子文件夹,这些子文件夹本身不重要,但它们包含以某个前缀开头的文件。我们可以看到一个唯一的标识符,例如document-000–123542。对于每个这些标识符,我们有 4 种数据:

  • blocks.csv 文件包含有关边界框的信息。由于 Donut 或 Pix2Struct 不使用这些信息,我们可以忽略这些文件。

  • out.000.png 文件是后处理(去倾斜)的图像文件。由于我更愿意测试未处理的文件,我也会忽略这些。

  • 原始的、未处理的文档图像有一个 in.000.png 后缀。这是我们感兴趣的。

  • 最后是相应的groundtruth.csv文件。这包含我们认为是实际标注的图像索引。

这里是一个示例 groundtruth csv 文件以及列描述:

Case,-1,0.0,0.0,0.0,0.0,,0,1.28,2.78,0.79,0.10,MELF CASE
StorageTemperature,0,0.35,3.40,2.03,0.11,Operating and Storage Temperature,0,4.13,3.41,0.63,0.09,-65 to +200
 1\. element type
 2\. page of the label block (-1 if absent)
 3\. x of the label block
 4\. y of the label block
 5\. w of the label block
 6\. h of the label block
 7\. text of the label block
 8\. page of the value block (never absent!)
 9\. x of the value block
10\. y of the value block
11\. w of the value block
12\. h of the value block
13\. text of the label block

这意味着我们只对第一列和最后一列感兴趣。第一列是,最后一列是。在这种情况下:

KEY                   VALUE
Case                  MELF CASE
StorageTemperature    -65 to +200

这意味着对于该文档,我们将微调模型以查找‘Case’的值为‘MELF CASE’,并且提取一个‘StorageTemperature’,其值为‘-65 to +200’。

索引

在 groundtruth 元数据中存在以下索引:

  • 数据表:型号、类型、外壳、功耗、储存温度、电压、重量、热阻

  • 专利:标题、申请人、发明人、代表、申请日期、出版日期、申请编号、出版编号、优先权、分类、摘要第一行

观察到地面真实值的质量和可行性,我选择保留以下索引:

elements_to_extract = ['FilingDate', 'RepresentiveFL', 'Classification', 'PublicationDate','ApplicationNumber','Model','Voltage','StorageTemperature']

质量

对于图像转换为文本,使用了 ocropus 版本 0.2。这意味着它大约在 2014 年底发布。在数据科学领域这已经很古老了,那么地面真实度的质量是否符合我们的任务要求呢?

为此,我查看了一些随机图像,并将地面真实值与实际在文档上写的内容进行了比较。以下是两个 OCR 不正确的示例:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

来自 Ghega 数据集的 document-001–109381.in.000.png

Classification 被设置为 BGSD 81/00 作为地面真实值。它应该是 B65D 81/100

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

来自 Ghega 数据集的 document-003–112107.in.000.png

StorageTemperature 显示 I -65 {O + 150 作为地面真实值,而我们可以看到它应该是 -65 to + 150

数据集中有许多此类错误。一种方法是纠正这些错误。另一种是忽略这些错误。由于我将使用相同的数据来比较两个模型,我选择了后者。如果数据用于生产,你可能需要选择前一种方法以获得最佳结果。

(还要注意,这些特殊字符可能会搞乱 JSON 格式,稍后我会回到这个话题)

Donut 数据集结构

我们需要的数据格式是什么样的?

对于微调 Donut 模型,我们需要将数据组织在一个文件夹中,所有文档作为单独的图像文件和一个元数据文件,结构为 JSON lines 文件。

donut-dataset
    document-000-123542.in.000.png
    document-001-123663.in.000.png
    ...
    metadata.jsonl

JSONL 文件包含每个图像文件一行,格式如下:

{"file_name": "document-010-100333.in.000.png", "ground_truth": "{\"gt_parse\": { \"DocType\": \"patent\", \"FilingDate\": \"06.12.1999\", \"RepresentiveFL\": \"Manresa Val, Manuel\", \"Classification\": \"A47l. 5/28\", \"PublicationDate\": \"1139845\", \"ApplicationNumber\": \"99959528 .3\" } }"}

让我们分解这行 JSON。在上层我们有一个包含两个元素的字典:file_nameground_truth。在 ground_truth 键下,我们有一个包含 gt_parse 键的字典。其值本身是一个字典,包含我们在文档中知道的键值对。或者更好:assign。记住,文档中不一定会出现文档类型。术语 datasheet 并没有作为文本出现在这些文档中。

幸运的是,pix2struct 使用相同的格式进行微调,因此我们可以一举两得。一旦我们将其转换为这种结构,我们还可以用来微调 Pix2Struct。

转换

对于转换本身,我在 colab 上创建了一个 Jupyter notebook。我决定在这个阶段将数据拆分为训练集和验证集,而不是在微调之前。这种方式,两个模型将使用相同的验证图像,结果会更具可比性。五个文档中会有一个用于验证。

利用上述 Ghega 数据集的结构知识,我们可以将转换过程概括如下:

对于每个以 in.000.png 结尾的文件名,我们取对应的 groundtruth 文件并创建一个临时的数据框对象。

注意,groundtruth 可能为空或完全不存在。(例如,对于 datasheets/taiwan-switching

接下来,我们从子文件夹中扣除类:patentdatasheet 。现在我们需要构建 JSON 行。对于每个我们想提取的元素/索引,我们检查它是否在数据框中并进行收集。然后复制图像本身。

对所有图像执行此操作,最后我们就有一个 JSONL 文件可以写出。

在 Python 中,它看起来是这样的:

json_lines_train = ''
json_lines_val = ''

for dirpath, dirnames, filenames in os.walk('/content/ghega-dataset/'):
    for filename in filenames:
        if filename.endswith('in.000.png'):
          gt_filename = filename.replace('in.000.png','groundtruth.csv')
          gt_filename_path = os.path.join(dirpath, gt_filename)
          if not os.path.exists(gt_filename_path):    #ignore files in /ghega-dataset/datasheets/taiwan-switching/ because no groundtruth exists
            continue
          if os.path.getsize(gt_filename_path) == 0:  #ignore empty groundtruth files
            print(f'skipped {gt_filename_path} because no info in metadata')
            continue
          doc_df = pd.read_csv(gt_filename_path, header=None)
          #find the doctype, based on path
          if 'patent' in dirpath:
            type = 'patent'
          else:
            type = 'datasheet'
          #create json line
          #eg:
          #{"file_name": "document-034-127420.in.000.png", "ground_truth": "{\"gt_parse\": { \"DocType\": \"datasheet\", \"Model\": \"ZMM5221 B - ZMM5267B\", \"Voltage\": \"1.5\", \"StorageTemperature\": \"-65 to 175\" } }"}
          p2 = ''
          #add always first element: DocType
          p2 += '\\"' + 'DocType' + '\\": '
          p2 += '\\"' + type + '\\"'
          new_row = {'ImagePath': os.path.join(dirpath, filename), 'DocType' :type}
          ghega_df = pd.concat([ghega_df, pd.DataFrame([new_row])], ignore_index=True)
          #fill other elements if available
          for element in elements_to_extract:
            value = doc_df[doc_df[0] == element][12].tolist()
            if len(value) > 0:
              p2 += ', '
              p2 += '\\"' + element + '\\": '
              value = re.sub(r'[^A-Za-z0-9 ,.()/-]+', '', value[0])   #get rid of \ of ” and " in json
              p2 += '\\"' + value + '\\"'
              new_row = {'ImagePath': os.path.join(dirpath, filename), element :value}
              ghega_df = pd.concat([ghega_df, pd.DataFrame([new_row])], ignore_index=True)

          p3 = ' } }"}'

          json_line = p1 + p2 + p3
          print(json_line)

          #take ~20% to validation
          #copy image file and append json line
          if random.randint(1, 100) < 20:
            output_path = '/content/dataset/validation/'
            json_lines_val += json_line + '\r\n'
            shutil.copy(os.path.join(dirpath, filename), '/content/dataset/validation/')  
          else:
            output_path = '/content/dataset/train/'
            json_lines_train += json_line + '\r\n'
            shutil.copy(os.path.join(dirpath, filename), '/content/dataset/train/')  

#write jsonl files
text_file = open('/content/dataset/train/metadata.jsonl', "w")
text_file.write(json_lines_train)
text_file.close()
text_file = open('/content/dataset/validation/metadata.jsonl', "w")
text_file.write(json_lines_val)
text_file.close()

ghega_df 是一个数据框,用于进行一些合理性检查或统计分析(如有需要)。我用它来检查随机样本,验证我的转换数据是否正确。

问题

转换完成后,一切看起来都很顺利。但我想摆脱那种通常第一次尝试就能成功的想法。总是会有一些小的意外问题发生。谈论我遇到的错误并展示解决方案对任何模拟整个过程并使用自己数据集的人都是有用的。

例如,在转换数据集后,我想训练 Donut 模型。在此之前,我需要创建一个训练数据集,如下所示:

train_dataset = DonutDataset("/content/dataset", max_length=max_length,
                             split="train", task_start_token="<s_cord-v2>", prompt_end_token="<s_cord-v2>",
                             sort_json_key=False, # dataset is preprocessed, so no need for this
                             )

并且出现了这个错误:

---------------------------------------------------------------------------
ArrowInvalid                              Traceback (most recent call last)
<ipython-input-13-7726ec2b0341> in <cell line: 7>()
      5 processor.feature_extractor.do_align_long_axis = False
      6 
----> 7 train_dataset = DonutDataset("/content/dataset", max_length=max_length,
      8                              split="train", task_start_token="<s_cord-v2>", prompt_end_token="<s_cord-v2>",
      9                              sort_json_key=False, # cord dataset is preprocessed, so no need for this

ArrowInvalid: JSON parse error: Missing a comma or '}' after an object member. in row 7

看起来第 7 行的 JSON 格式有问题。我复制了那一行并将其粘贴到一个 在线 JSON 验证器 中:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图像

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图像

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图像

然而,它表示这是一个有效的 JSON 行。让我们更深入地看看:

{
   "file_name":"document-012-108498.in.000.png",
   "ground_truth":"{\"gt_parse\": {\"DocType\": \"patent\"\"FilingDate\": \"15\. Januar 2004 (15.01.2004)\",\"Classification\": \"BOZC 18/08,\",\"PublicationDate\": \"5\. August 2004 (05.08.2004)\",\"ApplicationNumber\": \"PCT/AT2004/000006\"} }"
}

你发现错误了吗?经过一段时间,我注意到 DocTypeFilingDate 之间缺少逗号。然而,这在所有行中都是缺失的,所以我不清楚为什么第 7 行会出现问题。当我修复了这个问题后,我再次尝试,现在它声称第 17 行有问题:

ArrowInvalid: JSON parse error: Missing a comma or '}' after an object member. in row 17

这是第 17 行,你发现了问题吗?

{"file_name": "document-007-103668.in.000.png", "ground_truth": "{\"gt_parse\": {\"DocType\": \"patent\",\"FilingDate\": \"18.12.2008\",\"RepresentiveFL\": \"Schubert, Siegmar\",\"Classification\": \"A47J 31/42 (2""6·"')\",\"PublicationDate\": \"12.08.2009\",\"ApplicationNumber\": \"08021980.1\"} }"}

这是Classification 元素的未转义引号。为了解决这个问题,我决定所有值只能包含字母数字字符和一些特殊字符,并使用了这个正则表达式:

[^A-Za-z0-9 ,.()/-]+

这可能会严重影响真实性能,但从我所见,其他字符都是由于 OCR 错误引起的。我认为,对于模型之间的相对比较,忽略这些字符影响不大。

数据准备:完成

数据准备的重要性常被忽视且被低估。通过上述步骤,我展示了如何调整自己的数据,以便 Donut 和 Pix2Struct 用于文档的关键索引提取。常见的陷阱也得到了修正。包含所有步骤的 Jupyter 笔记本可以在这里找到。我们已经完成了一半。下一步是用这个数据集训练这两个模型。我非常好奇它们的表现如何,但比较和训练将留到下一篇文章中。

你可能还喜欢:

[## 实战:使用🍩变换器进行文档数据提取

我使用 Donut 变换器模型提取发票索引的经验。

toon-beerten.medium.com](https://toon-beerten.medium.com/hands-on-document-data-extraction-with-transformer-7130df3b6132?source=post_page-----b5a826bc2ac3--------------------------------)

参考文献:

[## 无 OCR 文档理解变换器

理解文档图像(例如,发票)是一项核心但具有挑战性的任务,因为它需要复杂的功能…

arxiv.org](https://arxiv.org/abs/2111.15664?source=post_page-----b5a826bc2ac3--------------------------------) [## Pix2Struct:作为视觉语言理解预训练的截图解析

视觉位置语言无处不在——来源包括带有图表的教科书到包含图像的网页…

arxiv.org](https://arxiv.org/abs/2210.03347?source=post_page-----b5a826bc2ac3--------------------------------) [## 机器学习实验室 - Ghega 数据集

Ghega 数据集:用于文档理解和分类的数据集,我们提供了一个标注数据集,可以…

machinelearning.inginf.units.it](https://machinelearning.inginf.units.it/data-and-tools/ghega-dataset?source=post_page-----b5a826bc2ac3--------------------------------) [## to-be/donut-base-finetuned-invoices · Hugging Face

编辑模型卡 基于 Donut 基础模型(在论文《无 OCR 文档理解变换器》中介绍)…

huggingface.co](https://huggingface.co/to-be/donut-base-finetuned-invoices?source=post_page-----b5a826bc2ac3--------------------------------)

无 OCR 文档数据提取与变换器(2/2)

原文:towardsdatascience.com/ocr-free-document-data-extraction-with-transformers-2-2-38ce26f41951?source=collection_archive---------1-----------------------#2023-08-10

Donut 与 Pix2Struct 在自定义数据上的对比

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Toon Beerten

·

关注 发表在 Towards Data Science ·7 分钟阅读·2023 年 8 月 10 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由作者提供(使用

这两种变换器模型对文档的理解如何?在第二部分中,我将展示如何训练它们并比较它们在关键索引提取任务中的结果。

调整 Donut 模型

所以让我们从 第一部分 开始,在那里我解释了如何准备自定义数据。我将数据集的两个文件夹打包并上传到一个新的 huggingface 数据集 这里。我使用的 Colab 笔记本可以在 这里 找到。它将下载数据集,设置环境,加载 Donut 模型并进行训练。

在微调了 75 分钟后,我在验证指标(即编辑距离)达到 0.116 时停止了:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图像

在字段级别,我得到这些验证集结果:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图像

当我们查看Doctype时,我们发现 Donut 总是正确地将文档识别为专利数据表。因此,我们可以说分类达到了 100% 的准确率。同样需要注意的是,即使我们有一个类别数据表,它也不需要文档上出现这个确切的词来进行分类。对于 Donut 来说,这并不重要,因为它经过微调以这样识别。

其他领域的得分也相当不错,但仅凭这张图表很难了解内部情况。我想看看模型在特定情况下的正确与错误之处。因此,我在我的笔记本中创建了一个例行程序来生成 HTML 格式的报告表。对于我的验证集中的每个文档,我都有这样的行条目:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图像

左侧是识别(推断)数据及其真实值。右侧是图像。我还使用了颜色代码以便快速概览:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图像

理想情况下,一切都应该用绿色突出显示。如果你想查看验证集的完整报告,可以在 这里 查看,或者本地下载这个 zip 文件

有了这些信息,我们可以发现常见的 OCR 错误,如Dczcmbci(应为December)或GL420(应为GL420,0 和 O 难以区分),这些错误会导致假阳性。

现在让我们关注表现最差的字段:电压。以下是推断数据、真实值和实际相关文档片段的一些样本。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图像

问题在于真实值大多是错误的。是否包括单位(Volt 或 V)没有标准。有时会包含无关文本,有时只是一个(错误的!)数字。我现在明白为什么 Donut 会对此感到困难。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图像

上面是一些 Donut 实际上给出最佳答案的样本,而实际情况是不完整或错误的。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图像

上面是另一个糟糕训练数据混淆 Donut 的好例子。地面真实值中的‘I’字母是 OCR 读取信息前的垂直线的伪影。有时它存在,有时不存在。如果你对数据进行预处理,使其在这方面一致,Donut 将会学习并遵循这种结构。

微调 Pix2Struct

Donut 的结果保持稳定,Pix2Struct 的呢?我用来训练的 Colab 笔记本可以在这里找到。

经过 75 分钟的训练,我得到的编辑距离分数为 0.197,而 Donut 的为 0.116。这显然收敛速度较慢。

另一个观察结果是,到目前为止,每个返回的值都以一个空格开头。这可能是 ImageCaptioningDataset 类中的错误,但我没有进一步调查根本原因。不过,我在生成验证结果时会去掉这个空格。

Prediction: <s_DocType> datasheet</s_DocType></s_DocType> TSZU52C2 – TSZUZUZC39<s_DocType>
    Answer: <s_DocType>datasheet</s_DocType><s_Model>Tszuszcz</s_Model><s_Voltage>O9</s_Voltage>

在 2 小时后我停止了微调过程,因为验证指标再次上升:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图像

那么这对验证集的字段级别意味着什么呢?

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图像

这看起来比 Donut 的结果差得多!如果你想查看完整的 HTML 报告,可以在这里查看,或者在本地下载这个 zip 文件

只有在数据表专利之间的分类似乎还不错(但不如 Donut)。其他字段则完全不佳。我们能推断发生了什么吗?

对于专利文档,我看到很多橙色线条,这意味着 Pix2Struct 根本没有返回这些字段。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图像

即使在专利中返回字段,它们也完全是虚构的。而 Donut 的错误源于从文档的其他区域提取或有轻微的 OCR 错误,Pix2Struct 在这里则是出现了幻觉。

对 Pix2Struct 的表现感到失望,我尝试了几次新的训练以期获得更好的结果:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图像

我尝试将 accumulate_grad_batches 逐渐从 8 降到 1。但这样学习率过高,会导致超调。将其降低到 1e-5 会使模型无法收敛。其他组合则导致模型崩溃。即使在一些特定的超参数下,验证指标看起来相当不错,但它给出了很多不正确或无法解析的行,例如:

<s_DocType> datasheet</s_DocType><s_Model> CMPZSM</s_Model><s_StorageTemperature> -0.9</s_Voltage><s_StorageTemperature> -051c 150</s_StorageTemperature>

这些尝试都没有给我带来实质性的更好结果,所以我就此停止了。

直到我看到 huggingface 实现中的交叉注意力 bug被修复。因此,我决定再试一次。训练了两个半小时,停在 0.1416 的验证指标上。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图片

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图片

这显然比所有之前的结果都要好。查看 HTML 报告,现在似乎幻觉更少。总体来说,它的表现仍不如 Donut。

至于原因,我有两个理论。首先,Pix2Struct 主要在 HTML 网页图像上训练(预测掩码图像部分后面的内容),并且在切换到另一个领域,即原始文本时,遇到了困难。其次,使用的数据集非常具有挑战性。它包含了许多 OCR 错误和不一致(如包含单位、长度、负号)。在我的其他实验中,我真的发现数据集的质量和一致性比数量更重要。在这个数据集中,数据质量真的很差。也许这就是我无法复制论文中声称 Pix2Struct 超越 Donut 表现的原因。

推断速度

这两种模型在速度方面如何比较?所有训练都在相同的 T4 架构上进行,因此时间可以直接比较。我们已经看到 Pix2Struct 收敛所需的时间要长得多。那么推断时间呢?我们可以比较推断验证集所需的时间:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图片

Donut 每个文档提取的平均时间为 1.3 秒,而 Pix2Struct 则超过两倍。

要点

  • 对我来说,最终的赢家是 Donut。在易用性、性能、训练稳定性和速度方面。

  • Pix2Struct 训练具有挑战性,因为它对训练超参数非常敏感。它收敛较慢,并且在这个数据集上没有达到 Donut 的结果。可能值得重新考虑使用更高质量的数据集来尝试 Pix2Struct。

  • 由于 Ghega 数据集包含太多不一致性,我将避免在进一步实验中使用它。

是否有其他替代模型?

  • Dessurt,似乎与 Donut 有相似的架构,应该表现相当。

  • DocParser,论文称其表现甚至更好。不幸的是,目前没有计划将该模型发布到未来。

  • mPLUG-DocOwl将很快发布,这是另一个有前景的无 OCR LLM 文档理解工具。

你可能还会喜欢:

[## 实战:使用🍩变压器进行文档数据提取

我使用甜甜圈转换器模型来提取发票索引的经验。

toon-beerten.medium.com

参考文献:

[## Pix2Struct: 截图解析作为视觉语言理解的预训练

视觉定位语言无处不在——来源从带有图表的教科书到带有图像的网页等。

arxiv.org [## 无 OCR 文档理解转换器

理解文档图像(例如发票)是一项核心但具有挑战性的任务,因为它需要复杂的功能,比如……

arxiv.org [## GitHub - Toon-nooT/notebooks

通过在 GitHub 上创建帐户来为 Toon-nooT/notebooks 的开发做贡献。

github.com

哦,你是说“管理变革”?

原文:towardsdatascience.com/oh-you-meant-manage-change-bc9639affab5?source=collection_archive---------7-----------------------#2023-10-20

数据组织中对变革的不同视角

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Marc Delbaere

·

关注 发表在 Towards Data Science ·7 分钟阅读·2023 年 10 月 20 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者在布鲁塞尔的 Menssa 餐厅拍摄

变革的不同酿制方式

[场景:现代办公室的休息室。咖啡机的嗡嗡声是唯一的声音,空气中弥漫着新鲜咖啡的香气。CDO Alex 站在咖啡机旁,倒了一杯咖啡。数据工程师 Jamie 走了进来,看起来有些疲惫。]

  • Jamie:“又一天,又一个挑战。你知道,Alex,管理变革开始让我感到疲惫不堪。”

  • Alex(点头):“绝对是这样,Jamie。变革管理现在是我最优先的任务。我们必须确保自己在适应并保持领先。”

  • Jamie(扬起眉毛):“保持领先?我只是尽力让每次变化时事情不会崩溃。”

  • 亚历克斯:“确切地说,这就是预测这些变化并保持领先的关键。我们必须保持团队的动力和一致性。”

  • 杰米(困惑,但尝试认同):“是的,对齐并且……不落后。明白了。”

[谈话流转到其他话题,但他们对“变革管理”的观点差异依然未被说出或承认。]

那么,让我们来分析一下亚历克斯和杰米之间刚刚发生了什么。他们都提到了变革管理这个词,但他们就像在说不同的语言。

我们的首席数据官亚历克斯有远大的目标。她在监测市场变化、新兴技术,并设想公司在未来几年的发展方向。然而,制定战略是简单的,复杂的是让每个人达成共识。

引入一个新工具?她得准备好接受翻白眼和“又一个要学习的软件”的抱怨。一个新流程?准备好接受“但我们一直这么做”的合唱。对亚历克斯来说,变革管理就像走钢丝——平衡公司需要走的方向,同时确保每个人都支持,并且不对他们的工作安全感到恐慌。

然后是杰米。他的变革管理并不是关于未来几年的,而是现在。那个刚刚坏掉的管道?这是他的问题。销售报告中的差异?他的责任。最难的往往不是技术细节,而是人际因素。比如有人忘记告诉他一个微小的“无关紧要”的数据变化,导致一切陷入混乱。或者当任务出现问题时,指责游戏开始。对杰米来说,变革管理就是让事情今天顺利进行,并处理任何突发的问题。

战略视角:首席数据官的愿景

我经常与首席数据官互动,这种对话的多样性是我工作中真正让我欣赏的一方面。每一次对话都是不同的,带来独特的视角。然而,不知为何——也许是因为这些话题我非常关注,或者也许这里才是真正的行动所在——某些共同的主题不可避免地浮现出来。

首先,强调的是推动业务价值。这不仅仅是收集数据或实施最新的技术;关键在于将数据转化为可操作的洞察。确保每一个数据驱动的举措都与公司的目标相一致,无论是增加销售、提升客户满意度还是优化运营。

接下来是对效率的追求。首席数据官(CDO)经常面临改善运营、消除重复工作以及确保数据及时到达所需地点的任务。这不是轻松的工作;它涉及拆除旧有障碍、鼓励团队合作以及跟上新的技术解决方案。

许多首席数据官(CDO)正在倾向于去中心化数据网格的概念。这是从传统的中心化数据团队转变为一个模型的重大变化,在这个模型中,领域团队拥有、生成并提供他们的数据作为产品。这里的思维过程既简单又具有革命性:那些对数据最了解的人应该负责打包和维护数据。这不仅能确保更好的数据质量,还能培养自我消费的文化,赋予组织不同部分更多的自主权。

达成这些目标非常困难。每一个战略目标都带来了变更管理的问题,像亚历克斯这样的 CDO 必须直接面对这些问题。

比如以业务价值为首的议程。对于那些已经在技术任务中工作了多年,甚至几十年的数据专业人士来说,转移关注到业务成果上可能会让人感到不适。他们已经被训练成以数据准确性、系统集成和代码优化为思考的方式。要求他们“以业务价值为思考方式”常常会遇到困惑的目光!

还有去中心化的趋势,这在纸面上无疑是一个好主意:赋予团队权力,让他们承担责任,组织变得更加敏捷和高效。实际上,这意味着大量的变化需要被管理。去中心化带来了明确角色和责任的挑战。当每个人都是利益相关者时,任务很容易被忽视。谁负责数据质量?谁确保数据对需要的人是可访问的?没有明确的界定,问题会被遗漏,责备游戏就会开始。

从本质上讲,对于每一个战略转变,都存在着一个隐含的变更管理复杂网络。这不仅仅是绘制路线图,更要确保每个人都理解自己的角色,具备执行任务的能力,并且致力于前进的旅程。

实际情况:日常挑战

虽然亚历克斯作为 CDO 的角色主要关注大局,驾驭广泛而不可预测的情境,但变更管理还有另一面。这体现在像杰米这样的数据工程师面临的日常详细挑战中。在他们的领域中,变更管理不是关于长期战略或总体业务目标。相反,它关注于确保数据在不断变化的背景下保持一致和可访问的持续、每时每刻的障碍。

首先,组织中的大部分数据是作为副产品产生的。随着各种业务活动的展开,数据自然地积累,就像机器的废气一样。然而,虽然这些数据对于生成它的人来说可能只是副产品,但对于数据团队及其内部和外部的下游客户来说,这些数据成为了他们日常运营的核心。讽刺的是,在源头,这种‘废气’往往被忽视,尽管它对于链条上的这些利益相关者来说是不可或缺的。

可以将其视为在不稳定的地面上为高楼大厦奠定基础。地球下方总在移动,但你的任务是确保上面的庞大结构保持稳定。这是许多数据工程师和商业智能(BI)分析师的世界。他们站在前线,每天处理数据的异常行为。

他们面临的一个重大问题是交织在数据世界中的复杂依赖网络。数据从一个平台移动到另一个平台,经历转化,与其他数据集合并,最后到达预期的位置。这个过程的每个阶段都可能出现故障。在一个平台上进行的小调整可能会产生连锁反应,导致其他地方的干扰。而最具挑战性的一点是?通常,做出这些更改的人对他们可能引发的连锁反应毫无察觉。

面对不断变化的数据链带来的日常问题,思想领袖们提出了新的概念。他们首先引入了数据产品,这些数据产品打包数据以便于消费,类似于商店老板将商品展示给顾客。但随着更多人开始使用这些数据产品,出现了对正式承诺的需求——一种确保数据产品所有者可靠服务其用户的方式。这一认识促使了数据合同的制定,以确保这些义务得到履行。

数据合同作为打包数据以供使用者使用的人员与他们服务的消费者之间的桥梁,记录和执行明确的承诺:数据模式的不可变性、质量和可用性的标准等。它们优雅地解决了在数据依赖链中管理变化的问题。

弥合分歧:统一的变革视角

战略转型的挑战和管理不稳定依赖的挑战乍一看似乎大相径庭。像 Alex 这样的首席数据官(CDO)负责指导整体战略,并使整个组织朝着共同愿景对齐。同时,Jamie 是处理日常数据挑战的“消防员”。然而,在这两种观点的核心,有一些统一的原则可以弥合分歧。

透明度至关重要。无论是 Alex 沟通战略举措的广泛目标,还是 Jamie 标记架构变更可能带来的下游影响,清晰和开放的沟通可以预防许多问题。

协作确保一致性。数据组织中的每个人都需要保持同步。在日常层面,这意味着有效沟通以防止意外的麻烦。在战略层面,则是确保每个人对广泛目标有清晰认识,确保日常任务和整体计划朝着相同的方向推进。

标准化提供了稳定性。引入像数据合同这样的实践不仅解决了 Jamie 面临的细节挑战,还巩固了 Alex 战略愿景建立的基础。通过建立明确的标准,我们消除了模糊性,使得既有大局观的思考者也有注重细节的执行者能够协同朝同一方向前进。

最终的讽刺在于,为了解决 Jamie 的日常问题(即不透明依赖链中的意外后果),你需要将这个话题提升为战略优先事项。

如果你想取得成功,鉴于过程、人力和技术的变革,你需要应用所有由 Alex 倡导的变革管理良好原则。当然,Jamie 在这里扮演着至关重要的角色,他是最接近问题及其后果的人,因此他可以成为变革推动者,让他的同事和管理层参与其中。

所以,最初的互惠互利实际上可能是战略转型的开始:明确的行动理由、路线图、合适的领导者和工具。

即使是管理变革,你也需要变革管理!

好的,你已经训练了最好的机器学习模型。接下来做什么?

原文:towardsdatascience.com/okay-youve-trained-the-best-machine-learning-model-what-s-next-e7b8f167006e

数据科学

一个超越 Jupyter Notebook 建模的 MLOps 项目

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Albers Uzila

·发表于 Towards Data Science ·阅读时间 18 分钟·2023 年 6 月 4 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片来源:Elena MozhviloUnsplash

***Table of Contents***
**·** **Initialize a Repository**
**·** **Migrate Your Codebase**
  ∘ config/config.py
  ∘ config/args.json
  ∘ tagolym/utils.py
  ∘ tagolym/data.py
  ∘ tagolym/train.py
  ∘ tagolym/predict.py
  ∘ tagolym/evaluate.py
  ∘ tagolym/main.py
**·** **Package Your Codebase** **·** **Setup Data Source Credential** **·** **Run Your Pipeline** **·** **Miscellaneous** **·** **Push Your Project to GitHub** **·** **Wrapping Up**

假设你正在构建一个数据科学项目,可能是为了工作、大学、作品集、爱好或其他任何目的。你已经花费了很多时间来解决问题陈述,并在 Jupyter notebooks 中进行实验。现在,你在想,“我怎么将我的工作部署成一个有用的产品?”。

具体来说,假设你有一个托管论坛的网站。用户可以给论坛中的线程添加标签,以方便在不同主题的论坛之间导航。你希望通过建议预定义的标签来改善用户体验,从而为讨论提供背景。

论坛可以是任何形式的,因此让我们更具体一点;它通常以一个 帖子 开始,解释一个数学问题,接着是围绕这个问题的想法、问题、提示或答案。以下是一个线程的样子及其三个标签,即 inductioncombinatorics unsolvedcombinatorics

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

论坛中的一个帖子示例 | 图片由 author 提供

此时,你已经在你的 notebooks 中完成了所有工作,从理解问题陈述、定义指标、查询数据、清理数据、预处理、EDA、构建模型到评估和优化模型。

你会注意到有很多帖子有着大量的标签。为了简化,你只筛选了 10 个标签。你开发的模型是简单的线性分类器(SVM、逻辑回归等),前面经过 TF-IDF 向量化,并用随机梯度下降(SGD)进行训练。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

前 30 个频繁标签计数 | 图片由作者提供

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

最终的标签分布。注意几何标签是最常见的 | 图片由作者提供

虽然笔记本非常好,并且可以帮助你非常快速地进行实验,但它们并不适合生产环境,并且有时很难维护。因此,你需要将代码迁移到独立的 Python 文件中,然后逐步添加其他工具,同时与团队成员合作。

这个故事将引导你通过简明的步骤完成这项工作。在此之前,你可能想要刷新一下关于线性模型、TF-IDF 和 SGD 的知识:

## 线性回归、逻辑回归和 SVM 在 10 分钟内

线性回归与逻辑回归和支持向量机有什么关系?

towardsdatascience.com ## 你需要了解的词袋模型和 Word2Vec — 文本特征提取

为什么 Word2Vec 更好,但为什么它还不够好

towardsdatascience.com ## 从头开始的完整梯度下降算法步骤

以及其对常数学习率和线性搜索的实现

towardsdatascience.com

初始化一个仓库

首先,让我们在GitHub上创建一个名为tagolym-ml的新仓库,并配有README.md.gitignoreLICENSE

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

创建新的 GitHub 仓库 | 图片来源 author

要使用这个代码库,请执行以下步骤:

  1. 克隆代码库,将创建一个名为tagolym-ml的文件夹。

  2. 将工作目录更改为此文件夹。

  3. 创建一个名为venv的虚拟环境。

  4. 激活环境。

  5. 升级pip

  6. 可选地,你可以使用pip list检查当前环境中已安装的包,其中会有pipsetuptools

  7. 创建一个名为code_migration的新 git 分支并切换到它。

  8. 创建一个setup.py文件。

  9. 创建一些名为configtagolymcredentials的新文件夹。

  10. config文件夹内创建config.pyargs.json文件。

  11. tagolym文件夹内创建main.pyutils.pydata.pytrain.pyevaluate.pypredict.py文件。

如果你不知道如何做这些,不用担心。这里是你可以在喜欢的终端上运行的所有命令:

$ git clone https://github.com/dwiuzila/tagolym-ml.git
$ cd tagolym-ml
$ python3 -m venv venv
$ source venv/bin/activate
$ python3 -m pip install --upgrade pip
$ pip list
Package    Version
---------- -------
pip        23.1.2
setuptools 58.0.4
$ git checkout -b code_migration
$ touch setup.py
$ mkdir config tagolym credentials
$ touch config/config.py config/args.json
$ cd tagolym
$ touch main.py utils.py data.py train.py evaluate.py predict.py
$ cd ..

你现在有一个本地 git 仓库,已连接到 GitHub 上的远程仓库。当地仓库的目录将如下所示。

config/
├── args.json        - preprocessing/training parameters
└── config.py        - configuration setup
credentials/         - keys and passwords
tagolym/
├── data.py          - data processing components
├── evaluate.py      - evaluation components
├── main.py          - training/optimization pipelines
├── predict.py       - inference components
├── train.py         - training components
└── utils.py         - supplementary utilities
venv/                - virtual environment
.gitignore           - files/folders that git will ignore
LICENSE              - project license
README.md            - longform description of the project
setup.py             - code packaging

目前几乎所有这些文件都是空的。你将一个一个地填写它们,从config文件夹开始。

迁移你的代码库

你的项目中有两个主要文件夹,即configtagolym。你需要将笔记本中的必要代码复制到这些文件夹中的文件中。我们来做吧。

config/config.py

在这里,你定义了与种子、目录、实验跟踪、预处理和标签名称相关的全局变量。

当这个文件在你的代码中被导入时,如果尚未创建,它将创建两个新文件夹:

  1. data,用于存储项目的标记数据,

  2. stores/model,用于存储模型注册,

然后将stores/model连接到用于实验跟踪的 MLflow 跟踪 URI。

你还在这里定义了停用词和额外的命令词。停用词将默认为nltk包中的词汇,而命令词为["prove", "let", "find", "show", "given"],这些词经常出现在帖子中,但不提供任何有用的信号。

正则表达式用于预处理。它们看起来很吓人,但你不需要理解它们。它们的基本功能是捕捉任何数学表达式渐近线语法的 LaTeX,这些在数学问题的帖子中是基础和核心。

最后,记住你只选择了 10 个入围标签进行处理?你在这个文件中列出了所有这些标签。一些标签与您的标签有类似的含义(例如“inequalities” → “inequality”),因此你也有 10 个部分标签来捕获这些标签并用适当的标签替换它们。请参见下面的tagolym/data.py,了解如何操作。

config/args.json

这是你存储整个过程的初始参数的地方。它们来自管道的不同部分。

它们是什么意思?

  1. nocommandstem —— 处理帖子时的布尔值,是否排除命令词和实现词干提取

  2. ngram_max_range —— 在TF-IDF 向量化过程中提取不同n-gram 的n值范围的上限。

  3. lossl1_ratioalphalearning_rateeta0power_t —— 用于SGD 分类器的模型的超参数。

tagolym/utils.py

流水线有些复杂,因此你需要一些实用函数和 Python 类来简化代码。这个文件包含了这些:

  1. load_dictsave_dict —— 从 JSON 文件中加载字典,或将字典转储到 JSON 文件中。

  2. NumpyEncoder —— 将包含 Numpy 实例的对象编码为 Python 内置实例,用于save_dict

  3. IterativeStratification —— 当你处理像这个项目这样的多标签分类时,普通的训练-测试划分方法对于数据并不理想。相反,你需要我们所说的迭代分层,它旨在提供在给定阶数下标签关系证据的良好平衡分布。在这个项目中,阶数设置为 2。

tagolym/data.py

与数据相关的所有函数都写在这个文件中,包括数据分割、预处理和转换。

  1. preprocess —— 从包含部分标签的标签创建映射到config/config.py中定义的 10 个标签之一,然后对所有帖子和标签进行文本处理。这个函数还会在文本处理后删除所有空帖子样本。

  2. binarize — 根据模型要求,如果你正在处理多标签分类问题,你可能需要对标签进行二值化。此函数将标签转换为一个大小为(#样本 × #标签)的二进制矩阵,指示标签中标签的存在。例如,包含两个标签["algebra", "inequality"]的标签将被转换为[1, 0, 0, 0, 0, 1, 0, 0, 0, 0]。除了返回转换后的标签,它还返回稍后使用的[MultiLabelBinarizer](https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.MultiLabelBinarizer.html)对象,特别是在将矩阵转换回标签时。

  3. split_data — 使用tagolym/utils.py中的IterativeStratification,此函数将帖子和标签拆分为 3 部分,比例为 70/15/15,分别用于模型训练、验证和测试。

tagolym/train.py

最佳实践是将模型训练、验证和测试放在不同的文件中。正如文件名所示,你在这里进行所有的训练。由于你希望用户能自信地使用模型的标签推荐,你需要降低假阳性率。

另一方面,目前假阴性并不是你的首要任务。为了说明这一点,我们来看看一个极端的例子:模型将所有 10 个标签预测为负值,因此没有推荐标签,你会有大量的假阴性。但用户可以毫不犹豫地创建自己的标签。这没什么大不了的。

所以,你的目标是拥有一个高精度的模型。

现在,让我们讨论一下这个文件的内容:

  1. train — 预处理数据,将标签二值化,并使用tagolym/data.py中的函数拆分数据。然后,初始化一个模型,训练它,使用训练好的模型对所有三个数据拆分进行标签预测,并评估预测结果。这个函数接受args,其中包含config/args.json中的所有参数,返回时可能会添加一个额外的参数threshold。基本上,threshold是通过tune_threshold计算出的每个标签的最佳阈值列表。

  2. objective — f1 分数是超参数调整中选择的优化指标。使用试验中选择的args,此函数训练模型并返回验证集的 f1 分数。它还为试验设置了额外的属性,包括所有三个数据拆分的精确度、召回率和 f1 分数。

  3. tune_threshold — 二分类问题的默认决策边界是 0.5,但这可能并不是最优的,具体取决于问题。因此,除了调整args,你还需要在优化 f1 分数时调整每个标签的阈值。它的作用是尝试从 0 到 1 的网格中所有可能的阈值,并选择具有最大 f1 分数的阈值。

tagolym/predict.py

模型训练之后该做什么?预测!这个文件中有两个函数:

  1. custom_predict — 如果模型具有 predict_proba 属性,则此函数将预测每个标签作为标签的概率。否则,使用 0.5 阈值直接预测标签。在前一种情况下,如果提供了真实标签,函数将使用 tagolym/train.py 中的 tune_threshold 来调整阈值。

  2. predict — 加载 args、标签二值化器和训练好的模型。然后,预处理给定的文本,并使用 custom_predict 对其进行预测。之后,将预测矩阵转换回标签。

tagolym/evaluate.py

给定预测和真实标签矩阵,本文件的目的是计算精度、召回率、F1 分数和样本数量。性能是根据总体样本、每类样本和每个切片样本计算的。你考虑了 8 个切片:

  1. 短帖,即经过预处理后少于 5 个单词的帖子,

  2. 六个切片,其中帖子被标记为子主题但未标记为覆盖子主题的更大主题,以及

  3. 不包含频繁出现的四字或更多字的帖子。

tagolym/main.py

这是运行所有任务的主要文件。这里有 5 个函数和你需要在其中编写的指令:

  1. elt_data — 查询标记数据并以 JSON 格式保存到data文件夹中。

  2. train_model — 从data文件夹加载标记数据并训练模型。不要忘记使用 MLflow 记录指标、工件和参数。还要将 MLflow run_id和指标保存到config文件夹中。

  3. optimize — 从data文件夹加载标记数据并优化给定的参数。为了提高搜索效率,优化分为两个步骤:a) 预处理、向量化和建模中的超参数;b) 学习算法中的超参数。还要根据目标将最佳参数保存到config文件夹中,命名为args_opt.json

  4. load_artifacts — 将特定 run_id 的工件加载到内存中,包括参数、指标、模型和标签二值化器。

  5. predict_tag — 给定特定的 run_id,使用预加载的工件预测接收到的每个文本的标签。

唷!你刚刚完成了所有迁移工作。现在,如何使用这些代码?

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

Jason Strull 提供的照片,来源于 Unsplash

打包你的代码库

当你使用笔记本时,你有一个预加载的包集合用于实验。为了在本地重现并部署到生产环境,你希望明确地定义你的环境。

你在这个项目中导入了许多开源包,但你的环境中只有pipsetuptools。因此,在运行管道之前,你需要安装这些包。

下面是一个方便的命令来实现这一点。注意,我在最后添加了 [pip-chill](https://pypi.org/project/pip-chill/) 以便于后续清理生成的包要求文件。

$ pip install mlflow nltk regex scikit-learn snorkel joblib optuna pandas google-cloud-bigquery google-auth numpy scipy pip-chill

pip-chill 的一个很酷的特点是,它可以生成一个不包含文件中依赖于其他包的包的要求文件,使得要求文件干净且准确。让我们运行一下。

$ pip-chill --no-chill > requirements.txt

这将创建一个 requirements.txt 文件,包含你实际需要的所有包。请注意,因为这些包是已经列在文件中的包的依赖项,所以文件中没有 pandasscikit-learnregex 以及其他几个包。

现在你将使用 setup.py 打包你的代码库,将所有依赖项打包在一起。在这个文件中,加载你在 requirements.txt 中的所有库,并使用 setuptools 中的 setup 函数定义你的包。

你的包名将是 tagolym。你可以在下面的代码中看到其他细节,如版本和描述。你从 requirements.txt 中加载的库将用于 install_requires 参数,并成为 tagolym 的依赖项。

然后你可以使用下面的命令安装 tagolym。这将创建一个名为 tagolym.egg-info 的新文件夹,包含项目的元数据。

$ python3 -m pip install -e .

请注意,-e--editable 标志会从本地项目路径以可编辑模式安装包。换句话说,如果你在当前工作目录中使用一些函数,例如使用 from tagolym import main,然后对 tagolym/main.py 进行一些更改,你将能够使用这个更新版本,而无需使用 pip install 重新安装你的包。

设置数据源凭证

有一个小问题。这些项目中使用的数据是我自己的数据,存储在我的 BigQuery 中。在创建并下载一个 服务账户密钥 后,我将其重命名为 bigquery-key.json,并将其放置在 credentials 文件夹中。

要访问数据,你需要我的凭证,但不幸的是,这些凭证不能共享。不过不用担心,我会提供样本供你使用。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

创建服务账户密钥 | 图片由 作者 提供

你需要做的很简单:下载样本 labeled_data.json 在这里 并将文件保存在工作目录中名为 data 的文件夹里。

运行你的管道

现在你准备好了!在终端中输入 python3 命令,你就可以运行 Python 中的所有内容。你只需使用 tagolym/main.py 文件。

首先,我使用我的凭证和 elt_data 函数查询数据。当我看到 ✅ Saved data! 时,我知道过程顺利完成。如上所述,你可以跳过这一步,手动将我提供的样本放入 data 文件夹中。

然后,您可以使用 optimize 函数来优化模型,通过读取初始参数 config/args.json。我将试验次数设置为 10,但您可以尝试其他设置。由于您有一个两步优化过程,所以将创建一个新的 MLflow 研究,总共 20 次试验。找到的最佳验证 f1 分数是 0.7730。

使用一组优化后的参数 config/args_opt.json,您可以再次使用 train_model 函数训练模型,并使用 predict_tag 函数对文本列表进行推断。您可以看到下面的预测非常准确!

$ python3
>>> from pathlib import Path
>>> from config import config
>>> from tagolym import main
>>>
>>> # query data
>>> key_path = "credentials/bigquery-key.json"
>>> main.elt_data(key_path)
✅ Saved data!
>>>
>>> # optimize model
>>> args_fp = Path(config.CONFIG_DIR, "args.json")
>>> main.optimize(args_fp, study_name="optimization", num_trials=10)
2023/06/03 17:42:12 INFO mlflow.tracking.fluent: Experiment with name 'optimization' does not exist. Creating a new experiment.
[I 2023-06-03 17:41:45,657] A new study created in memory with name: optimization
[I 2023-06-03 17:42:12,343] Trial 0 finished with value: 0.7519199358796977 and parameters: {'nocommand': False, 'stem': True, 'ngram_max': 2, 'loss': 'modified_huber', 'l1_ratio': 0.6011150117432088, 'alpha': 0.001331121608073689}. Best is trial 0 with value: 0.7519199358796977.
[I 2023-06-03 17:42:38,441] Trial 1 finished with value: 0.7629559140596291 and parameters: {'nocommand': False, 'stem': True, 'ngram_max': 2, 'loss': 'modified_huber', 'l1_ratio': 0.43194501864211576, 'alpha': 7.476312062252303e-05}. Best is trial 1 with value: 0.7629559140596291.
[I 2023-06-03 17:42:57,713] Trial 2 finished with value: 0.7511576441724478 and parameters: {'nocommand': True, 'stem': False, 'ngram_max': 3, 'loss': 'hinge', 'l1_ratio': 0.5924145688620425, 'alpha': 1.3783237455007187e-05}. Best is trial 1 with value: 0.7629559140596291.
[I 2023-06-03 17:43:19,108] Trial 3 finished with value: 0.7106573336158825 and parameters: {'nocommand': True, 'stem': False, 'ngram_max': 4, 'loss': 'hinge', 'l1_ratio': 0.6842330265121569, 'alpha': 0.00020914981329035596}. Best is trial 1 with value: 0.7629559140596291.
[I 2023-06-03 17:43:37,349] Trial 4 finished with value: 0.741392879377292 and parameters: {'nocommand': False, 'stem': False, 'ngram_max': 2, 'loss': 'hinge', 'l1_ratio': 0.5467102793432796, 'alpha': 3.585612610345396e-05}. Best is trial 1 with value: 0.7629559140596291.
[I 2023-06-03 17:44:04,235] Trial 5 finished with value: 0.7426444422157734 and parameters: {'nocommand': True, 'stem': True, 'ngram_max': 3, 'loss': 'hinge', 'l1_ratio': 0.045227288910538066, 'alpha': 9.46217535646148e-05}. Best is trial 1 with value: 0.7629559140596291.
[I 2023-06-03 17:44:30,104] Trial 6 finished with value: 0.7337258988967691 and parameters: {'nocommand': True, 'stem': True, 'ngram_max': 2, 'loss': 'modified_huber', 'l1_ratio': 0.07455064367977082, 'alpha': 0.009133995846860976}. Best is trial 1 with value: 0.7629559140596291.
[I 2023-06-03 17:44:51,778] Trial 7 finished with value: 0.7700323704566581 and parameters: {'nocommand': True, 'stem': False, 'ngram_max': 4, 'loss': 'log_loss', 'l1_ratio': 0.3584657285442726, 'alpha': 2.2264204303769678e-05}. Best is trial 7 with value: 0.7700323704566581.
[I 2023-06-03 17:45:18,125] Trial 8 finished with value: 0.7559495178348377 and parameters: {'nocommand': True, 'stem': True, 'ngram_max': 2, 'loss': 'log_loss', 'l1_ratio': 0.8872127425763265, 'alpha': 0.00026100256506134784}. Best is trial 7 with value: 0.7700323704566581.
[I 2023-06-03 17:45:47,029] Trial 9 finished with value: 0.7730089901544794 and parameters: {'nocommand': False, 'stem': True, 'ngram_max': 4, 'loss': 'log_loss', 'l1_ratio': 0.02541912674409519, 'alpha': 2.1070472806578224e-05}. Best is trial 9 with value: 0.7730089901544794.
[I 2023-06-03 17:45:47,056] A new study created in memory with name: optimization
[I 2023-06-03 17:46:16,061] Trial 0 finished with value: 0.7730089901544794 and parameters: {'learning_rate': 'optimal'}. Best is trial 0 with value: 0.7730089901544794.
[I 2023-06-03 17:46:48,008] Trial 1 finished with value: 0.7701884982320516 and parameters: {'learning_rate': 'adaptive', 'eta0': 0.15930522616241014}. Best is trial 0 with value: 0.7730089901544794.
[I 2023-06-03 17:47:18,651] Trial 2 finished with value: 0.7331091235928242 and parameters: {'learning_rate': 'invscaling', 'eta0': 0.0265875439832727, 'power_t': 0.17272998688284025}. Best is trial 0 with value: 0.7730089901544794.
[I 2023-06-03 17:47:49,429] Trial 3 finished with value: 0.7196639813595901 and parameters: {'learning_rate': 'invscaling', 'eta0': 0.038234752246751866, 'power_t': 0.34474115788895177}. Best is trial 0 with value: 0.7730089901544794.
[I 2023-06-03 17:48:21,601] Trial 4 finished with value: 0.7727673901952036 and parameters: {'learning_rate': 'adaptive', 'eta0': 0.3718364180573207}. Best is trial 0 with value: 0.7730089901544794.
[I 2023-06-03 17:48:51,330] Trial 5 finished with value: 0.7576010292654753 and parameters: {'learning_rate': 'invscaling', 'eta0': 0.16409286730647918, 'power_t': 0.16820964947491662}. Best is trial 0 with value: 0.7730089901544794.
[I 2023-06-03 17:49:21,906] Trial 6 finished with value: 0.7428637006524251 and parameters: {'learning_rate': 'invscaling', 'eta0': 0.040665633135147955, 'power_t': 0.13906884560255356}. Best is trial 0 with value: 0.7730089901544794.
[I 2023-06-03 17:49:52,034] Trial 7 finished with value: 0.746701310091385 and parameters: {'learning_rate': 'constant', 'eta0': 0.011715937392307063}. Best is trial 0 with value: 0.7730089901544794.
[I 2023-06-03 17:50:21,383] Trial 8 finished with value: 0.7683160697730758 and parameters: {'learning_rate': 'constant', 'eta0': 0.10968217207529521}. Best is trial 0 with value: 0.7730089901544794.
[I 2023-06-03 17:50:51,373] Trial 9 finished with value: 0.7338062675694838 and parameters: {'learning_rate': 'invscaling', 'eta0': 0.7568292060167615, 'power_t': 0.4579309401710595}. Best is trial 0 with value: 0.7730089901544794.
Best value (f1): 0.7730089901544794
Best hyperparameters: {
  "nocommand": false,
  "stem": true,
  "ngram_max": 4,
  "loss": "log_loss",
  "l1_ratio": 0.02541912674409519,
  "alpha": 2.1070472806578224e-05,
  "learning_rate": "invscaling",
  "eta0": 0.7568292060167615,
  "power_t": 0.4579309401710595,
  "threshold": [
    0.59,
    0.79,
    0.55,
    0.7000000000000001,
    0.5,
    0.72,
    0.76,
    0.63,
    0.7000000000000001,
    0.77
  ]
}
>>>
>>> # train model
>>> args_fp = Path(config.CONFIG_DIR, "args_opt.json")
>>> main.train_model(args_fp, experiment_name="baselines", run_name="sgd")
2023/06/03 17:52:01 INFO mlflow.tracking.fluent: Experiment with name 'baselines' does not exist. Creating a new experiment.
Run ID: fbdba0c7cab640bc853611ba6cd75cee
>>> text = [
...     "Let $c,d \geq 2$ be naturals. Let $\{a_n\}$ be the sequence satisfying $a_1 = c, a_{n+1} = a_n^d + c$ for $n = 1,2,\cdots$.Prove that for any $n \geq 2$, there exists a prime number $p$ such that $p|a_n$ and $p \not | a_i$ for $i = 1,2,\cdots n-1$.",
...     "Let $ABC$ be a triangle with circumcircle $\Gamma$ and incenter $I$ and let $M$ be the midpoint of $\overline{BC}$. The points $D$, $E$, $F$ are selected on sides $\overline{BC}$, $\overline{CA}$, $\overline{AB}$ such that $\overline{ID} \perp \overline{BC}$, $\overline{IE}\perp \overline{AI}$, and $\overline{IF}\perp \overline{AI}$. Suppose that the circumcircle of $\triangle AEF$ intersects $\Gamma$ at a point $X$ other than $A$. Prove that lines $XD$ and $AM$ meet on $\Gamma$.",
...     "Find all functions $f:(0,\infty)\rightarrow (0,\infty)$ such that for any $x,y\in (0,\infty)$, $$xf(x²)f(f(y)) + f(yf(x)) = f(xy) \left(f(f(x²)) + f(f(y²))\right).$$",
...     "Let $n$ be an even positive integer. We say that two different cells of a $n \times n$ board are [b]neighboring[/b] if they have a common side. Find the minimal number of cells on the $n \times n$ board that must be marked so that any cell (marked or not marked) has a marked neighboring cell."
... ]
>>> main.predict_tag(text=text)
[
  {
    "input_text": "Let $c,d \\geq 2$ be naturals. Let $\\{a_n\\}$ be the sequence satisfying $a_1 = c, a_{n+1} = a_n^d + c$ for $n = 1,2,\\cdots$.Prove that for any $n \\geq 2$, there exists a prime number $p$ such that $p|a_n$ and $p \not | a_i$ for $i = 1,2,\\cdots n-1$.",
    "predicted_tags": [
      "number theory"
    ]
  },
  {
    "input_text": "Let $ABC$ be a triangle with circumcircle $\\Gamma$ and incenter $I$ and let $M$ be the midpoint of $\\overline{BC}$. The points $D$, $E$, $F$ are selected on sides $\\overline{BC}$, $\\overline{CA}$, $\\overline{AB}$ such that $\\overline{ID} \\perp \\overline{BC}$, $\\overline{IE}\\perp \\overline{AI}$, and $\\overline{IF}\\perp \\overline{AI}$. Suppose that the circumcircle of $\triangle AEF$ intersects $\\Gamma$ at a point $X$ other than $A$. Prove that lines $XD$ and $AM$ meet on $\\Gamma$.",
    "predicted_tags": [
      "geometry"
    ]
  },
  {
    "input_text": "Find all functions $f:(0,\\infty)\rightarrow (0,\\infty)$ such that for any $x,y\\in (0,\\infty)$, $$xf(x²)f(f(y)) + f(yf(x)) = f(xy) \\left(f(f(x²)) + f(f(y²))\right).$$",
    "predicted_tags": [
      "algebra",
      "function"
    ]
  },
  {
    "input_text": "Let $n$ be an even positive integer. We say that two different cells of a $n \times n$ board are [b]neighboring[/b] if they have a common side. Find the minimal number of cells on the $n \times n$ board that must be marked so that any cell (marked or not marked) has a marked neighboring cell.",
    "predicted_tags": [
      "combinatorics"
    ]
  }
]
>>> exit()

您可以在美观的 MLflow UI 中查看您的实验:

$ mlflow ui --backend-store-uri stores/model

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

MLflow 用户界面 | 图片由 作者 提供

这些过程中的一些在后台创建了新的文件,大多数是模型训练的输出。您可以在下一节中解释的 README.md 文件中查看当前的项目目录。

杂项

您项目的安全性至关重要。因此,凭据应仅存在于本地仓库中;您不希望将其推送到 GitHub。

作为预防措施,在 .gitignore 文件末尾添加 credentials/。这将忽略您在开发项目时对 credentials 文件夹所做的任何更改。

其他您可能想要添加到 .gitignore 的内容包括 data/stores/,因为它们可能包含敏感信息或占用大量磁盘空间。如果您使用的是 macOS,还需添加 .DS_Store。这是一个存储其所在文件夹自定义属性的文件,对您的项目没有用处。

完成所有这些之后,您可以选择更新 README.md 中的项目描述。只需输入您在这个故事中完成的高层次过程,以便每个人都可以轻松复制您的工作。这可能看起来是这样的。

将您的项目推送到 GitHub

您的项目很酷,但它对其他人有用吗?要回答这个问题,您可以开源您的项目,以便每个人都可以从中受益,提供反馈,甚至贡献。这样做非常简单。

您需要的是下面的三个命令:

  1. 将您所做的每一项更改添加到 Git 索引中。

  2. 将索引中的更改提交到本地仓库,并

  3. 将本地仓库推送到远程,这将创建一个新的分支 code_migration 在远程仓库中。

$ git add .
$ git commit -m "Initial commit"
$ git push origin code_migration

您可以在 这里 查看结果。

了解更多关于 Git 的信息:

## 作为数据科学家使用 Git 命令的真实案例研究

完成分支说明

towardsdatascience.com

总结

恭喜你!🍻 你已经阅读完了这个故事。你学会了如何将你的数据科学实验从 Jupyter Notebook 转化为一个干净且可维护的项目。除此之外,你还知道了如何打包你的项目,运行整个数据管道,并使用 GitHub 和 BigQuery。

不过,这只是你 MLOps 旅程的开始。还有很长的路要走。敬请关注!📻

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由 Matese FieldsUnsplash 提供

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

🔥 你好!如果你喜欢这个故事并想支持我作为一个作家,可以考虑 成为会员。每月只需 $5,你就可以无限制访问 Medium 上的所有故事。如果你通过我的链接注册,我将获得一小笔佣金。

🔖 想了解更多关于经典机器学习模型如何运作以及如何优化其参数的信息?或者 MLOps 大型项目的示例?还有精选的顶尖文章?继续阅读:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

Albers Uzila

MLOps 大型项目 - 第二部分

查看列表3 篇故事!外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

Albers Uzila

从零开始的机器学习

查看列表8 篇故事!外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

Albers Uzila

高级优化方法

查看列表7 篇故事!外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

Albers Uzila

我的最佳故事

查看列表10 个故事外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

Albers Uzila

R 中的数据科学

查看列表7 个故事外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

关于 A/B 测试和携带效应

原文:towardsdatascience.com/on-ab-tests-and-carryover-effect-43668dbd52e2?source=collection_archive---------11-----------------------#2023-05-23

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由 Ron Hansen 提供,来源于 Unsplash

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Denis Vorotyntsev

·

关注 发布于 Towards Data Science · 7 分钟阅读 · 2023 年 5 月 23 日

在复杂的数据驱动决策世界中,A/B 测试脱颖而出,成为一个强大的工具,帮助企业优化策略和改善用户体验。但当一个测试的效果渗透到下一个测试中时,会发生什么情况呢?这会使结果变得模糊不清,扭曲结果。

这种现象被称为“滞后效应”,可能对理解测试中变更的真实影响构成重大挑战。在本文中,我们将深入探讨 A/B 测试和滞后效应的细微差别,讨论有效管理这种现象的策略。我们将探索用户分组的机制、分桶技术以及如何识别和解决滞后效应,以确保你的 A/B 测试提供可靠的、可操作的结果。

用户与桶

AB 测试是比较两个版本功能的基础方法,通常用于确定哪一个表现更好。为了执行这些测试,我们通常将用户分成两个组——对照组和处理组,基于用户 ID。为了简化,我们可以将所有“偶数”用户分配给对照组,将所有“奇数”用户分配给处理组。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

直观的 AB 测试设置:所有用户被分成对照组和处理组

初步步骤涉及估算样本大小——根据我们选择的指标(如点击率或每用户平均收入)来确定需要收集的用户或事件数量。这些估算考虑了方差(基于这些指标的历史观察)和预期效果(基于建议模型的离线结果)。在收集了足够的数据后,我们深入进行统计分析,如使用 t 检验比较对照组和处理组的平均收入,以确定表现更好的模型。

然而,这种方法面临几个障碍:

  1. 同时多重测试:同时启动多个测试成为挑战。例如,如果一个新模型需要测试,没有剩余流量来容纳这个测试。解决方案是暂停当前的 AB 测试,并将用户分成三组。但如果我们不知道未来将运行多少个变体呢?

  2. 可扩展性问题:处理广泛的用户基础时,扩展分析成为一项艰巨的任务。即使在用户级别缓存结果,对最近几天的 AB 测试进行统计计算也可能非常费力,尤其是在处理大量用户时。

为了绕过这些问题,我们采用了一种称为“分桶”的技术。

分桶方法

分桶将几个用户组合成一个称为“桶”的单元。你可以把这个桶看作是一个“元用户”。为了确定桶的 ID,我们使用以下公式:

bucket_id = hash(user_id + salt) % number_of_buckets

在这里,salt是一个固定的随机字符串,而number_of_buckets是系统的预定义参数。根据系统设计,桶的 ID 可以实时计算(当用户访问网站时)或在用户访问网站时计算一次并存储在键值存储中。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

分桶理念:用户根据上述公式被分配到不同的桶中。处理是在用户级别应用的,但分析是在桶级别进行的。

在启动 AB 测试之前,我们估算要分配给对照组和处理组的桶数量。例如,如果总桶数为 1000,我们将桶 0–99 的用户分配给对照模型,将桶 100–199 的用户分配给处理模型。这为两个模型提供了 10%的流量。

分桶允许我们在桶级别分析结果,而不是用户级别,这使我们可以在桶级别缓存指标,从而消除了繁重的重新计算需求。

了解延续效应

想象一个场景,你正在为一个大型电子商务网站设计一个新的推荐模型,使用高级神经网络。这个模型包括一个创新的功能,即 OpenAI GPT3 API 调用,用于生成商品标题的嵌入。该模型的离线结果显示出显著的性能提升,因此决定进行在线测试。

AB 测试的结构是跨越一周,将网站流量的 10%分配给对照组,10%分配给处理组。目标是比较两组之间的点击数量,以确定哪个模型表现更好。

然而,在 AB 测试上线几小时后,所有指标出现了令人担忧的下降。深入分析数据发现,由于复杂的模型,页面加载时间显著增加。这是一个意料之外的问题,在离线测试中没有遇到,也未在在线测试中考虑到。

针对加载时间缓慢的问题,分配到处理组的用户变得沮丧,导致一些用户减少使用或完全流失。这一不幸事件扰乱了用户在桶中的分布平衡,这是新 AB 测试中未考虑的一个方面。

这种持续的不平衡,即初始测试的直接后果,被称为“延续效应”。它发生在同一组用户在多个测试中不断经历变化时。本质上,由于分桶分配中使用了一致的盐或种子,桶“记住”了之前的 AB 测试,从而影响了后续测试的结果。

当用户行为因先前的处理而改变时,延续效应变得特别明显。例如,如果正在测试的新功能需要用户学习曲线,处理组的成员可能由于早期接触而更快适应,从而在其他组用户中获得优势。

在大规模和成熟的系统中,即使是 1%的微小变化也可能意味着数百万的收入,因此这一效应变得极为重要。数据科学家和机器学习工程师通常力求在指标上获得 0.1%的提升。然而,即使是轻微的延续效应也可能使多个 AB 测试失效,从而导致一个次优的模型被采纳或一个优质的模型因 AB 分组偏差而被弃用。

识别问题

为识别这一问题,数据科学家应定期进行 AA 测试。设计良好的 AB 测试系统应在 AA 测试中产生均匀的 p 值分布。AA 测试中 p 值直方图的不均匀性表明桶存在不平衡,可能由各种因素造成,包括带来的影响。

应对问题的策略

重排所有用户

避免桶内存储问题的最快且最简单的解决方案是定期更换盐值。这能确保用户在每次重排后在桶内随机分布,从而打破与先前分割的关联。然而,这种方法在进行 AB 测试期间并不实际,因为它扰乱了对照组和处理组之间的用户分布,破坏了 AB 测试的独立同分布(i.i.d.)前提,从而使结果无效。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

简单重排:更改盐值将导致用户在桶之间重新分配

当同时进行多个 AB 测试时,这种方法也会带来挑战,协调所有测试的终止可能很困难,任何延误都可能代价高昂。此外,持续时间长的负面测试无法停止,否则会失去 AB 测试的进展。

重排非 AB 测试用户

另一种替代方案是对不涉及任何 AB 测试的用户进行重排。采用这种方法,在每次完成 AB 测试后,那些不参与任何测试的用户将被重新分配到可用的桶中。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

重排未参与 AB 测试的用户:在测试 2 结束后,来自桶 2、3、4 和 5 的用户被重排。测试 1 中的用户保持不变。在测试 1 结束后,所有桶中的用户都被重排。

尽管这种方法不需要停止所有 AB 测试,但其实施更为复杂。我们需要跟踪实验中的用户并存储用户-ID 与桶的映射,频繁更新——这在大型系统中可能比较棘手。

结论

处理 AB 测试的复杂性,从用户设置和桶管理到处理带来的影响,需要精心规划和策略处理。理解这些复杂性可以帮助确保测试提供有价值的、可操作的见解,并对你正在进行的开发工作产生积极的贡献。通过采用有效的解决方案来克服潜在的障碍,你可以优化测试过程,提高用户体验,并最终优化产品的成功。

进一步阅读

为了深入理解带来的影响和 AB 测试的其他细微差别,这里有一些有价值的资源供进一步阅读:

网页上的对照实验:调查与实用指南 对桶化设计进行了详细的探讨,这在执行对照网页实验中是一个关键要素。

避免 A/B 测试中的三个陷阱的分离策略 对分组设计进行了详细的讲解,分组设计是进行受控网络实验中的一个关键要素。

如何解读 p 值直方图 对 p 值直方图的解读进行了深入探讨,这在 AA 测试中至关重要。它有助于检测延续效应,使读者对 AB 测试的统计方面有更深入的理解。

关于人工智能与推理的类型

原文:towardsdatascience.com/on-ai-and-types-of-reasoning-fc6980295158?source=collection_archive---------3-----------------------#2023-01-20

人工智能如何做决策?

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Jazmia Henry

·

关注 发表在 Towards Data Science · 5 分钟阅读 · 2023 年 1 月 20 日

嗨,各位数据领域的朋友们!

我工作在算法领域的时间越长,我就越确信,算法只是人类让机器模仿我们思维方式的一种方法。

在任何给定的时刻,我们会接收 1100 万比特的信息,但仅处理其中的 40 到 50 个。我们已经进化到只关注对生存最有价值的信息。

在构建算法时,我们使用数据来进行预测或协助决策,其中一些特征对我们的分析更有价值或更有用。

处理我们数据的算法与处理我们周围世界的思维之间的区别在于理解上下文的能力,并在归纳推理(当我感到热时,我会出汗。因此,当未来温度高时,我将会出汗)、演绎推理(如果 A = B,B = C,那么 A = C)和演绎推理(我把食物放在有狗的房间的台面上。我回来后发现我的食物不见了,而我的狗看起来很内疚。我的狗一定吃了我的食物)之间轻松切换的能力。算法可以被可靠地训练来执行所有这些类型的推理,但无法像人类一样同时可靠地进行所有这些推理。

归纳推理

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

归纳推理

归纳推理遵循特定的路径。它从特定的观察开始(观察到的树上的叶子是绿色的),注意到一个模式(我面前的这群树都有绿色的叶子),然后得出一个一般性的结论(所有树木的叶子都是绿色的)。分类算法如逻辑回归在归纳推理方面表现良好。它们有一个目标变量,并利用特定的特征来得出更大的结论。

这里有一个这种现象的例子。假设你正在执行一个逻辑回归算法,该算法能够识别苹果和橙子的区别。你的目标变量是一个二进制变量——1 代表苹果或 0 代表橙子。你的特征是颜色和皮肤质地的分类变量,以及是否有果梗的布尔变量。当运行模型时,算法得出结论:如果水果的颜色不是橙色,皮肤光滑,并且有果梗,那么该水果是苹果。特征是具体的,算法能够检测到一个模式,最终得到的结果是一般性的。

演绎推理

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

演绎推理

归纳推理从具体开始,得出一般性结论,而演绎推理则从一般性结论开始,得出具体结论。这就像是开车经过一片树木繁茂的森林,注意到所有树上的叶子都是绿色的,然后提出假设:森林中的任何一棵树也会有绿色的叶子。

基本的聚类算法在演绎推理方面表现良好。它们将模型中的特征用于识别围绕一个质心最近的数据点,并根据接近度对其进行分组,从而利用一般信息(所有数据点都在这个平面上)来得出具体结论(在欧几里得空间中最接近的数据点在某种有价值的方式上是最相似的)。

演绎推理

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

演绎推理

演绎推理发生在算法在注意到模式后用不完全的数据得出结论时。比如,你想仅通过观察人们穿的衣物来推断外面的温度。当人们感到寒冷时,他们通常会穿上外套。通过观察外面没有人穿外套,你得出结论外面一定很温暖。

强化学习算法在归纳推理方面表现良好。代理使用模拟环境在面对不完整观察时通过计算轨迹和优化奖励来得出结论。

让我们来看一个例子。假设你正在构建一个 Q-Learning 算法作为自驾车的基本模型,以将包裹送到社区中的人们那里。你希望确保你的自驾车能在一天结束前安全高效地送达所有包裹。为了训练你的自驾车,你创建了一个数字代理,每次车辆安全送达包裹时,你可以奖励它。

你的代理能够观察人类专家驾驶交付路径并沿途做出决策的行动。在每次观察后,代理会尝试行驶该路线,直到做出最佳决策以获得最佳奖励。代理可以做出的决策包括最佳驾驶速度、转动方向盘的方式、何时刹车以及何时加速。然而,当代理驾驶其典型路径时,它遇到了导致交通堵塞的施工。这可能看起来像这样:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

哎呀,有交通了!

研究人员可以促使代理在遇到新的未曾遇到的情况时做出最佳选择。经过更多的训练,代理能够预测最佳行动,同时继续执行其主要目标:按时安全地送达所有包裹。在计算过程中,代理可能发现还有另一个包裹需要送达——一个可能传统上在其路线后期送达的包裹。通过绕道将包裹送到另一个家庭,然后再返回原来的路线,代理可能会获得最大奖励。这项决策可能看起来像这样:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

这种推理方式是归纳推理的一个例子。尽管代理没有关于导致交通的施工的完整信息,但它能够决定最佳的行动路线。它能够意识到交通会增加交付时间,并且在未被告知的情况下决定尝试另一条路线。

结论

讨论的每种推理方式都有其优点和缺点,具体取决于应用任务。通过理解三种主要的 AI 推理方式,可以推动 AI 的可能性,使我们更接近更有用和强大的通用 AI。

然而,是否这确实是我们未来 AI 的最佳目标尚待观察,但如果是这样,具有这种复杂推理能力的 AI 将会改变游戏规则。

你觉得怎么样?在下面告诉我。

关注我,了解更多关于数据和人工智能的文章!

** 所有图片均由作者创作。

数据驱动的方程发现

原文:towardsdatascience.com/on-data-driven-equation-discovery-5069795d239d?source=collection_archive---------3-----------------------#2023-12-01

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 乔治·米洛舍维奇

·

关注 发表在 Towards Data Science ·14 分钟阅读·2023 年 12 月 1 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由 ThisisEngineering RAEng 提供,来源于 Unsplash

借助实验验证的分析表达式来描述自然,一直以来是科学特别是物理学从万有引力定律到量子力学及更广泛领域成功的标志。随着气候变化、聚变和计算生物学等挑战使我们将关注点转向更多的计算,迫切需要在降低成本的同时保持物理一致性的简明而强大的减少模型。科学机器学习是一个新兴领域,它承诺提供这样的解决方案。本文是对近期数据驱动方程发现方法的简要回顾,面向对机器学习或统计学非常基础的科学家和工程师。

动机与历史视角

单纯地将数据拟合得很好已被证明是一种短视的努力,这一点通过托勒密的地心说模型得到了证明,该模型是直到开普勒的日心说之前最符合观测的模型。因此,将观察与基本物理原理相结合在科学中发挥了重要作用。然而,在物理学中,我们常常忽略了我们的世界模型已经是数据驱动的程度。以粒子标准模型为例,它有 19 个参数,其数值是通过实验确定的。用于气象和气候的地球系统模型虽然在基于流体动力学的物理一致核心上运行,但也需要对其许多敏感参数进行仔细的观测校准。最后,减少阶次建模在聚变和空间天气社区中正在获得关注,并且很可能在未来保持相关性。在生物学和社会科学等领域,第一性原理方法效果较差,统计系统识别已经发挥了重要作用。

机器学习中有多种方法可以直接从数据中预测系统的演变。最近,深度神经网络在天气预报领域取得了显著进展,这一点由Google’s DeepMind团队等证明。这在一定程度上归因于他们拥有的巨大资源,以及气象数据和物理数值天气预报模型的一般可用性,这些模型通过数据同化将这些数据插值到全球。然而,如果数据生成的条件发生变化(例如气候变化),这些完全基于数据驱动的模型可能会表现不佳。这意味着将这些黑箱方法应用于气候建模及其他数据不足的情况可能会存在疑问。因此,在本文中,我将强调从数据中提取方程的方法,因为方程更具可解释性,且不易过拟合。在机器学习术语中,我们可以将这些范式称为高偏差——低方差

首先值得一提的方法是Schmidt 和 Lipson的开创性工作,该工作使用了遗传编程(GP)进行符号回归,从简单动力系统(如双摆等)的轨迹数据中提取方程。该过程包括生成候选符号函数,推导这些表达式中涉及的偏导数,并将其与从数据中数值估算的导数进行比较。这个过程会重复进行,直到达到足够的准确性。重要的是,由于潜在候选表达式的数量非常庞大且相对准确,因此选择符合“简约原则”的表达式。简约原则通过表达式中的项数的倒数来衡量,而预测准确性则通过仅用于验证的保留实验数据上的误差来衡量。这个简约建模的原则构成了方程发现的基础。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

遗传编程(GP)的思想是通过尝试一系列潜在的项来探索可能的分析表达式空间。这个表达式被编码在上面的树中,其结构可以表示为一种“基因”。通过突变这些基因的序列、选择和交叉最优候选项,可以获得新的树。例如,要获取右侧框中的方程,只需跟随右侧树的层级中的箭头即可。

这种方法的优点在于探索各种可能的解析表达式组合。它已在各种系统中尝试过,特别是,我将重点介绍 AI — 费曼,借助 GP 和神经网络,能够从数据中识别出费曼物理讲座中的 100 个方程。GP 的另一个有趣应用是发现 气候中的海洋参数化,其中实质上运行了一个高保真模型来提供训练数据,同时从训练数据中发现了较便宜的低保真模型的修正。然而,GP 并非没有缺陷,人工干预是不可或缺的,以确保参数化效果良好。此外,由于它遵循进化的过程:试错,因此可能非常低效。还有其他可能性吗?这将引导我们到近年来主导方程发现领域的方法。

稀疏系统识别

非线性动力学的稀疏识别(SINDy) 属于概念上简单但强大的方法家族。由 Steven L. Brunton 的团队介绍,以及 其他团队 ,并配有文档完善、支持良好的 代码库YouTube 教程。要获得一些实际操作经验,只需试用他们的 Jupyter notebooks。

我将根据 原始 SINDy 论文 描述该方法。通常,您拥有轨迹数据,其中包含 x(t)、y(t)、z(t) 等坐标。目标是从数据中重建一阶常微分方程(ODEs):

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

通常,x(t)(有时称为响应函数)是从观察或建模数据中获得的。目标是估计 f = f(x)(ODE 的右侧)的最佳选择。通常,会尝试一个单项式库,算法会继续寻找稀疏系数向量。系数向量的每个元素控制着这个单项式对整个表达式的重要性。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

在这里,函数 f = f(x) 被表示为单项式库与稀疏向量的乘积。有关更多说明,请参见下面的图形。

有限差分法(例如)通常用于计算常微分方程左侧的导数。由于导数估计容易出错,这会在数据中产生噪声,这通常是不希望的。在某些情况下,过滤可能有助于处理这些问题。接下来,选择一个单项式(基函数)库来拟合常微分方程右侧,如图所示:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

如[1]所示的非线性动力学的稀疏识别(SINDy)。其思想是提取一小部分基函数(例如单项式),即全基库的一个子集,当数据代入时,这些基函数能满足方程。在左侧写出时间导数(每列对应不同变量,每行对应数据样本,样本可能是时间),而右侧则是基库矩阵(其行跨度每个基函数)与稀疏向量相乘,稀疏向量是算法学习的对象。促进稀疏性意味着我们希望最终得到的大多数向量值为零,这符合节俭原则。

问题在于,除非我们拥有天文数字级的数据,否则这个任务将毫无希望,因为许多不同的多项式都会很好地工作,这将导致显著的过拟合。幸运的是,这正是稀疏回归的救援之处:重点是对右侧有太多活跃基函数进行惩罚。这可以通过多种方式实现。原始 SINDy 所依赖的一种方法叫做序列阈值最小二乘法(STLS),可以总结如下:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

来自 SINDy 论文补充材料的 Matlab 稀疏表示代码。

换句话说,使用标准最小二乘法求解系数,然后在每次应用最小二乘法时逐步消除小系数。该过程依赖于一个超参数,该超参数控制系数的小值容忍度。这个参数看似任意,但可以进行所谓的帕累托分析:通过保留一些数据并测试学习模型在测试集上的表现来确定这个稀疏化超参数。这个系数的合理值对应于学习模型的准确性与复杂性曲线(复杂性 = 包含的项数)中的“肘部”,即所谓的帕累托前沿。或者,某些其他文献推荐使用信息准则来推广稀疏性,而不是执行上述的帕累托分析。

作为 SINDy 的最简单应用,考虑如何使用 STLS 成功识别Lorenz 63 模型

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

将 SINDy 应用于 Lorenz 63 模型识别的示例。系数(颜色)大致对应于用于生成训练数据的系数。这些数据是通过解决带有这些参数的相关 ODE 生成的。

STLS 在应用于自由度较大的系统(如偏微分方程(PDEs))时存在局限性,在这种情况下,可以考虑通过 主成分分析(PCA)或 非线性自编码器 等进行降维。后来,SINDy 算法通过 PDE-FIND 论文 得到了进一步改进,该论文引入了顺序阈值岭回归 (STRidge)。在后者中,岭回归 指的是带有 L2 惩罚的回归,而在 STRidge 中则交替进行小系数的淘汰,如同 STLS。这使得从仿真数据中发现各种标准 PDE 成为可能,例如 布尔戈斯方程科尔特韦格-德弗里斯方程(KdV)、纳维-斯托克斯方程、反应-扩散方程,甚至是科学机器学习中常遇到的一个相当特殊的方程,Kuramoto-Sivashinsky 方程,由于需要直接从数据中估计其四阶导数项,这通常较为棘手。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

Kuramoto-Sivashinsky 方程描述了层流火焰流中的扩散-热不稳定性。

该方程的识别直接基于以下输入数据(这些数据是通过数值求解相同方程获得的):

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

Kuramoto-Sivashinsky 方程的解。右侧面板显示了场,而右侧面板则显示了其时间导数。

这并不是说该方法容易出错。事实上,将 SINDy 应用于现实观察数据的一个大挑战在于这些数据往往本身稀疏且噪声较大,通常在这种情况下识别效果较差。同样的问题也影响了基于符号回归的方法,如遗传编程(GP)。

弱 SINDy 是一种较新的发展,它显著提高了算法在噪声方面的鲁棒性。这种方法已由多位作者独立实施,尤其是 丹尼尔·梅森丹尼尔·R·古列维奇帕特里克·赖恩博德。其主要思想是,与发现 PDE 的微分形式相比,发现其 [弱] 积分形式,通过在一组域上对 PDE 进行积分,并乘以一些测试函数。这允许通过分部积分,从 PDE 的响应函数(未知解)中去除棘手的导数,而将这些导数应用于已知的测试函数。这种方法进一步应用于 Alves 和 Fiuza 进行的等离子体物理方程发现,其中恢复了 Vlasov 方程和等离子体流体模型。

SINDy 方法的另一个显而易见的局限性是,识别始终受到构成基的项库(例如单项式)的限制。虽然可以使用其他类型的基函数,如三角函数,但这仍然不够通用。假设 PDE 具有一个有理函数的形式,其中分子和分母都可以是多项式:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

这种情况使得像 PDE-FIND 这样的算法应用变得复杂

这种情况当然可以通过遗传编程(GP)轻松处理。然而,SINDy 也扩展到了这样的情况,引入了 SINDy-PI(并行隐式),该方法成功用于识别描述 贝洛乌索夫-扎博廷斯基反应 的 PDE。

最后,其他稀疏促进方法,如稀疏贝叶斯回归,也称为相关向量机(RVM),也被用于使用完全相同的拟合术语库的方法从数据中识别方程,但受益于边际化和统计学家高度尊重的“奥卡姆剃刀”原则。我在这里不覆盖这些方法,但可以说,像张和林这样的作者声称对 ODEs 的系统识别更为稳健,并且这种方法甚至尝试用于学习简单条带气候模型的闭合,其中作者认为 RVM 比 STRidge 更稳健。此外,这些方法为识别方程的估计系数提供了自然的不确定性量化(UQ)。话虽如此,集成 SINDy的最新发展更加稳健,提供 UQ,但则依赖于统计方法自助聚合(bagging),这一方法也广泛应用于统计学和机器学习。

物理信息深度学习识别

解决和识别偏微分方程(PDE)系数的另一种方法是物理信息神经网络(PINNs),该方法在文献中引起了极大关注。主要思想是使用神经网络对 PDE 的解进行参数化,并将运动方程或其他类型的基于物理的归纳偏置引入损失函数。损失函数在预定义的一组所谓的“协同点”上进行评估。在执行梯度下降时,神经网络的权重会被调整,从而“学习”解决方案。所需提供的唯一数据包括初始条件和边界条件,这些条件也在一个单独的损失项中受到惩罚。该方法实际上借鉴了旧的非神经网络的协同方法。虽然神经网络提供了自然的自动微分方式,使这种方法非常有吸引力,但事实证明,PINNs 与标准数值方法如有限体积/有限元等通常不具竞争力。因此,作为解决前向问题(数值求解 PDE)的工具,PINNs并不那么有趣。

它们成为解决逆问题的有趣工具:通过数据估计模型,而不是通过已知模型生成数据。在原始 PINNs 论文中,两个 Navier-Stokes 方程的未知系数是通过数据进行估计的。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

输入到 PINN 损失函数中的 Navier-Stokes 方程的假定形式。通过识别,获得了两个未知参数(位于黄色框内)。有关 PINNs 的 tensorflow 实现,请参阅

回顾起来,与 PDE-FIND 等算法相比,这似乎有些天真,因为方程的一般形式已经被假定。然而,这项工作的一个有趣方面是算法并没有输入压力数据,而是假设了不可压缩流动,并通过 PINN 直接恢复压力的解。

PINNs 已经在各种情况下应用,我想特别强调一个应用是空间天气,在这个应用中,展示了通过解决 Fokker-Planck 方程的逆问题来估计辐射带中的电子密度。这里,重新训练神经网络的集成方法在估计不确定性方面非常有用。最终,为了实现可解释性,进行学习的扩散系数的多项式扩展。将这种方法与直接使用类似 SINDy 的方法进行比较会非常有趣,后者也提供了多项式扩展。

“物理信息”这个术语已经被其他团队采纳,他们有时发明了自己将物理先验融入神经网络的版本,并称之为类似“基于物理”或“受物理启发”等引人注目的名称。这些方法有时可以被归类为软约束(惩罚不满足某些方程或对称性的损失)或硬约束(将约束实施到神经网络的架构中)。这种方法的例子可以在气候科学等其他学科中找到。

由于反向传播的神经网络提供了一种估计时间和空间导数的替代方法,因此稀疏回归(SR)或遗传编程(GP)与这些神经网络配合方法的结合似乎是不可避免的。虽然这样的研究有很多,但我将重点介绍一个相对文档齐全且支持良好的DeePyMoD以及代码库。了解这种方法的工作原理足以理解同时期或之后出现的所有其他研究,并在各种方面改进

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

DeePyMoD 框架:PDE 的解通过前馈神经网络(NN)进行参数化。在最新的论文中,损失函数由两个项组成:数据与 NN 预测之间的均方误差(MSE)项;正则化损失,它惩罚包括活跃库项在内的 PDE 函数形式。类似于 SINDy 的 STLS,当网络收敛到解时,稀疏性向量中的小项被消除,从而仅推广库中最大的系数。然后,NN 的训练会重复进行,直到满足收敛标准。

损失函数包括均方误差(MSE):

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

以及促进 PDE 函数形式的正则化

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

与较弱的 SINDy 相比,DeePyMoD 在噪声下显著更稳健,仅需要在时空域上很少的观测点,这对于从观测数据中发现方程是个好消息。例如,许多 PDE-FIND 能正确识别的标准 PDE 也可以由 DeePyMoD 识别,但只需在包含噪声数据的空间中采样几千个点。然而,使用神经网络进行这项任务的代价是更长的收敛时间。另一个问题是一些 PDE 对原始配合方法存在问题,例如由于高阶导数的 Kuramoto-Sivashinsky (KS) 方程。没有弱形式方法,从数据中识别 KS 通常很困难,尤其是在噪声存在的情况下。更多的近期发展涉及将弱 SINDy 方法与神经网络配合方法结合。另一个有趣且实际未探讨的问题是这些方法通常如何受到非高斯噪声的影响。

结论

总结来说,方程发现是基于物理的机器学习的自然候选者,正在全球多个团队积极开发。它已在流体动力学、等离子体物理、气候等多个领域找到应用。有关其他方法的更广泛概述,请参见综述文章。希望读者对该领域存在的不同方法有所了解,但我只是略微触及了表面,避免过于技术化。值得一提的是许多新的基于物理的机器学习方法,如神经常微分方程(ODEs)。

参考文献

  1. Camps-Valls, G. et al. 从数据中发现因果关系和方程。Physics Reports 1044,1–68 (2023)。

  2. Lam, R. et al. 学习高技能的中期全球天气预测。Science 0,eadi2336 (2023)。

  3. Mehta, P. et al. 物理学家机器学习的高偏差、低方差介绍。Physics Reports 810,1–124 (2019)。

  4. Schmidt, M. & Lipson, H. 从实验数据中提炼自由形式自然法则。Science 324,81–85 (2009)。

  5. Udrescu, S.-M. & Tegmark, M. AI Feynman: 一种受物理启发的符号回归方法。Sci Adv 6,eaay2631 (2020)。

  6. Ross, A., Li, Z., Perezhogin, P., Fernandez-Granda, C. & Zanna, L. 在理想化模型中对机器学习海洋子网格参数化的基准测试。Journal of Advances in Modeling Earth Systems 15,e2022MS003258 (2023)。

  7. Brunton, S. L., Proctor, J. L. & Kutz, J. N. 通过稀疏识别非线性动态系统从数据中发现主方程。Proceedings of the National Academy of Sciences 113,3932–3937 (2016)。

  8. Mangan, N. M., Kutz, J. N., Brunton, S. L. & Proctor, J. L. 通过稀疏回归和信息准则选择动态系统模型。Proceedings of the Royal Society A: Mathematical, Physical and Engineering Sciences 473,20170009 (2017)。

  9. Rudy, S. H., Brunton, S. L., Proctor, J. L. & Kutz, J. N. 数据驱动的偏微分方程发现。Science Advances 3,e1602614 (2017)。

  10. Messenger, D. A. & Bortz, D. M. 用于偏微分方程的弱 SINDy。Journal of Computational Physics 443,110525 (2021)。

  11. Gurevich, D. R., Reinbold, P. A. K. & Grigoriev, R. O. 对非线性 PDE 模型的鲁棒和最优稀疏回归。Chaos: An Interdisciplinary Journal of Nonlinear Science 29,103113 (2019)。

  12. Reinbold, P. A. K., Kageorge, L. M., Schatz, M. F. & Grigoriev, R. O. 通过物理约束的符号回归从噪声、不完整、高维实验数据中进行鲁棒学习。Nat Commun 12,3219 (2021)。

  13. Alves, E. P. & Fiuza, F. 从全动能模拟中数据驱动地发现简化的等离子体物理模型。Phys. Rev. Res. 4,033192 (2022)。

  14. Zhang, S. & Lin, G. 具有误差条的数据驱动的物理定律发现。皇家学会 A 卷:数学、物理和工程科学学报 474, 20180305 (2018)。

  15. Zanna, L. & Bolton, T. 数据驱动的海洋中尺度闭合方程发现。地球物理研究快报 47, e2020GL088376 (2020)。

  16. Fasel, U., Kutz, J. N., Brunton, B. W. & Brunton, S. L. Ensemble-SINDy:在低数据、高噪声极限下,通过主动学习和控制实现稳健的稀疏模型发现。皇家学会 A 卷:数学、物理和工程科学学报 478, 20210904 (2022)。

  17. Raissi, M., Perdikaris, P. & Karniadakis, G. E. 物理信息神经网络:解决涉及非线性偏微分方程的正向和逆向问题的深度学习框架。计算物理学杂志 378, 686–707 (2019)。

  18. Markidis, S. 旧与新:物理信息深度学习能否取代传统线性求解器?大数据前沿 4, (2021)。

  19. Camporeale, E., Wilkie, G. J., Drozdov, A. Y. & Bortnik, J. 数据驱动的 Fokker-Planck 方程发现:使用物理信息神经网络研究地球辐射带电子。地球物理研究杂志:空间物理学 127, e2022JA030377 (2022)。

  20. Beucler, T. 在模拟物理系统的神经网络中强制实施解析约束。物理评论快报 126, 098302 (2021)。

  21. Both, G.-J., Choudhury, S., Sens, P. & Kusters, R. DeepMoD:在噪声数据中进行模型发现的深度学习。计算物理学杂志 428, 109985 (2021)。

  22. Stephany, R. & Earls, C. PDE-READ:使用深度学习发现可读的偏微分方程。神经网络 154, 360–382 (2022)。

  23. Both, G.-J., Vermarien, G. & Kusters, R. 稀疏约束神经网络用于 PDE 模型发现。预印本于 doi.org/10.48550/arXiv.2011.04336 (2021)。

  24. Stephany, R. & Earls, C. Weak-PDE-LEARN:一种基于弱形式的 PDE 发现方法,适用于噪声大、数据有限的情况。预印本于 doi.org/10.48550/arXiv.2309.04699 (2023)。

在代表性不足的群体面前的学习

原文:towardsdatascience.com/on-learning-in-the-presence-of-underrepresented-groups-8937434d3c85?source=collection_archive---------11-----------------------#2023-07-11

改变是困难的:对亚群体偏移的更深入了解 (ICML 2023)

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Yuzhe Yang

·

关注 发表在 Towards Data Science ·8 min read·2023 年 7 月 11 日

让我向您介绍我们最新的工作,这项工作已被 ICML 2023 会议接受:改变是困难的:对亚群体偏移的更深入了解。机器学习模型在许多应用中表现出巨大的潜力,但它们在训练数据中代表性不足亚群体上往往表现较差。理解导致这种亚群体偏移的机制变异,以及算法在大规模不同偏移下的泛化能力仍然是一个挑战。在这项工作中,我们旨在通过提供对亚群体偏移及其对机器学习算法影响的细致分析来填补这一空白。

我们首先提出了一个统一的框架,剖析并解释了子群体中常见的变化。此外,我们引入了一个综合基准,包含 20 种最先进的算法,我们在 12 个现实世界的数据集上对其进行了评估,这些数据集涵盖了视觉语言医疗保健领域。通过我们的分析和基准测试,我们提供了关于子群体变化及机器学习算法在这些现实世界变化下如何泛化的有趣观察和理解。代码、数据和模型已经在 GitHub 上开源:github.com/YyzHarry/SubpopBench

背景与动机

机器学习模型在面对分布变化时通常表现出性能下降。这种变化发生在基础数据分布发生变化时(例如,训练分布与测试分布不同),导致模型部署时性能下降。构建对这些变化具有鲁棒性的机器学习模型对于在现实世界中安全部署这些模型至关重要。一种普遍存在的分布变化类型是子群体变化,其特征是在训练和部署之间某些子群体的比例发生变化。在这种情况下,模型可能在总体上表现良好,但在稀有子群体中表现较差。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

:在牛与骆驼分类任务中,牛通常出现在绿色背景中,而骆驼则通常出现在黄色背景中。因此,模型在这些背景下表现良好,但无法泛化到背景颜色不同的图像中。 :在医学诊断任务中,机器学习模型在代表性不足的年龄或种族群体上表现往往较差。(图片由作者提供)

例如,在牛和骆驼分类任务中,牛通常出现在绿色草地区域,而骆驼则通常出现在黄色沙地背景区域。然而,这种关联是虚假的,因为牛或骆驼的存在与背景颜色无关。因此,训练好的模型在上述图像上表现良好,但无法泛化到训练数据中稀少的不同背景颜色的动物,例如沙地上的牛或草地上的骆驼。

此外,研究发现,在医学诊断方面,机器学习模型在代表性不足的年龄或种族群体上表现往往较差,这引发了重要的公平性问题。

所有这些变化通常被称为子群体变化,但对于导致子群体变化的机制变异及算法如何在大规模的不同变化下泛化的了解甚少。那么,如何建模子群体变化

子群体变化的统一框架

我们首先提供了一个统一的子群体转移建模框架。在经典分类设置中,我们有来自多个类别的训练数据(其中我们使用不同的颜色密度来表示每个类别中的样本数量)。然而,当涉及子群体转移时,除了类别之外还存在属性——例如在牛骆驼问题中的背景颜色。在这种情况下,我们可以根据属性标签定义离散的子群体,而且在同一类别中,不同属性的样本数量也可能有所不同(见下图)。自然地,为了测试模型,类似于我们在所有类别中评估性能的分类设置,在子群体转移中我们测试模型在所有子群体上的表现,以确保所有子群体中的最差性能也足够好,或确保所有组的性能都同样优秀

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

在子群体转移中,我们需要考虑属性,而不仅仅是类别标签。(图片由作者提供)

具体而言,为了提供一个通用的数学公式,我们首先使用贝叶斯定理重写分类模型。我们进一步将每个输入x视为由一组潜在核心特征(X_core)和一个属性列表(a)完全描述或生成。在这里,X_core表示与标签特定的、支持稳健分类的潜在不变成分,而属性a可能具有不一致的分布,并且不是标签特定的。因此,我们可以将这种建模整合回方程,并进一步分解为三项,如下所示:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

一个通用的子群体转移建模框架。(图片由作者提供)

具体而言,第一项表示X_corey之间的点对点互信息(PMI),这是与潜在类别标签相关的稳健指标。第二项和第三项分别对应于属性分布和标签分布中可能出现的偏差。这种建模解释了属性和类别如何在子群体转移下影响结果。因此,给定训练和测试分布之间不变的X_core,我们可以忽略第一项的变化,关注属性类别在子群体转移下如何影响结果。

基于此框架,我们正式定义并描述了四种基本的子群体转移类型:虚假相关属性不平衡类别不平衡属性泛化。每种类型构成了子群体转移中可能出现的基本转移成分。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

四种基本的子群体转移类型。(图片由作者提供)

首先,当某些属性在训练数据中与标签y存在虚假相关性,但在测试数据中没有时,这意味着虚假相关性。此外,当某些属性的采样概率远小于其他属性时,会引发属性不平衡。类似地,类别标签可能会表现出不平衡的分布,导致对少数标签的偏好较低,这将导致类别不平衡。最后,某些属性在训练中可能完全缺失,但在测试中对于某些类别却存在,这促使了属性泛化的需求。每种转移的属性/类别偏差来源及其对分类模型的影响总结在下面的表格中:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

(图片由作者提供)

这四种情况构成了基本的转移组件,并且是解释真实数据中复杂子群体转移的重要元素。在实际应用中,数据集通常同时包含多种类型的转移,而不仅仅是一种。

SubpopBench:子群体转移基准测试

在建立了公式后,我们提出了SubpopBench,这是一个包括在 12 个真实世界数据集上评估的最先进算法的综合基准测试。特别是,这些数据集来自各种模态和任务,包括视觉、语言和医疗保健应用,数据模态范围从自然图像、文本、临床文本到胸部 X 光。这些数据集还展现了不同的转移组件。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

SubpopBench 基准测试。(图片由作者提供)

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

SubpopBench 基准测试。(图片由作者提供)

关于此基准测试的详细信息,请参阅我们的论文。通过建立的基准测试和使用 20 种最先进算法训练的超过 10K 模型,我们揭示了未来研究中的一些有趣观察。

对子群体转移的细粒度分析

SOTA 算法仅改善某些类型的转移

首先,我们观察到 SOTA 算法仅在某些类型的转移上改善子群体鲁棒性,而在其他类型的转移上则没有。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

(图片由作者提供)

我们在这里绘制了各种 SOTA 算法相对于 ERM 的最差组准确性改进。对于虚假相关性类别不平衡,现有算法可以提供相对于 ERM 的一致最差组增益,表明在解决这两种特定转移上已有进展。

然而,有趣的是,当涉及到属性不平衡时, across 数据集几乎没有观察到改进。此外,对于属性泛化,性能甚至变得更差。

这些发现强调了当前的进展仅针对特定的转移,而对于更具挑战性的转移,如 AG,没有进展

表示和分类器的作用

此外,我们受到启发去探讨表示分类器在子群体变化中的作用。具体来说,我们将整个网络分为两个部分:特征提取器f和分类器g,其中f从输入中提取潜在特征,而g输出最终预测。我们提出的问题是,表示和分类器如何影响子群体性能

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

(作者提供的图片)

首先,给定一个基础的 ERM 模型,当仅优化分类器学习并固定表示时,可以显著提高虚假相关类别不平衡的性能,这表明 ERM 学到的表示已经足够好。然而有趣的是,改进表示学习而非分类器可以带来显著的提升,特别是在属性不平衡方面,这表明我们可能需要更强大的特征来应对某些变化。最后,无分层学习的方式在属性泛化下没有性能提升。这突显了在面对现实中不同类型的变化时,需要考虑模型管道设计

关于模型选择与属性可用性

此外,我们观察到模型选择属性可用性对子群体变化评估有显著影响。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

(作者提供的图片)

具体而言,当逐渐去除训练和/或验证数据中的属性注释时,所有算法的性能都出现了显著下降,特别是当训练和验证数据中没有属性可用时。

这表明获取属性仍在子群体变化中发挥重要作用,未来的算法应该考虑更现实的场景以进行模型选择和属性可用性。

超越最差组准确率的指标

最后,我们揭示了基本的 权衡在评估指标之间。最差组准确率,或WGA,被认为是子群体评估的金标准。然而,改善 WGA 是否总是提升其他有意义的指标

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

(作者提供的图片)

我们首先展示了改善 WGA 可能导致某些指标性能提升,例如这里显示的调整准确率。然而,如果我们进一步考虑最差情况精度,它却与 WGA 显示出非常强的负线性相关性。这揭示了使用 WGA 作为唯一指标来评估模型在子群体变化中的表现的基本限制:表现良好的模型具有高 WGA,但其最差类别精度可能很低,这在医疗诊断等关键应用中尤其令人担忧。

我们的观察强调了在子群体转移中需要更多现实且广泛的评估指标。我们还展示了许多在本文中与 WGA 呈负相关的其他指标。

结语

总结本文,我们系统地研究了子群体转移问题,形式化了一个统一的框架来定义和量化不同类型的子群体转移,并进一步建立了一个全面的基准,以便在真实世界数据中进行评估。我们的基准包括 20 种 SOTA 方法和 12 个来自不同领域的真实数据集。基于超过 10K 训练模型,我们揭示了子群体转移中的有趣特性,这些特性对未来的研究具有重要意义。我们希望我们的基准和发现能够促进现实和严格的评估,并激发子群体转移领域的新进展。最后,我附上了几篇相关论文的链接;感谢阅读!

代码: github.com/YyzHarry/SubpopBench

项目页面: subpopbench.csail.mit.edu/

演讲: www.youtube.com/watch?v=WiSrCWAAUNI

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值