TowardsDataScience 2023 博客中文翻译(二百一十六)

原文:TowardsDataScience

协议:CC BY-NC-SA 4.0

在 DALL-E 3 翻译中迷失

原文:towardsdatascience.com/lost-in-dall-e-3-translation-b85a3958b9d6

多语言生成 AI 图像会导致不同的结果

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Yennie Jun

·发表于Towards Data Science ·11 分钟阅读·2023 年 11 月 2 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

使用 DALL-E 3 在六种语言中生成的“一个人的图像”示例。图由作者创建。

本文最初发表于 artfish intelligence

介绍

OpenAI 最近推出了最新的 DALL-E 3,这是他们 AI 图像生成模型系列中的最新版本。

但正如近期媒体报道研究所揭示的,这些 AI 模型带有偏见和刻板印象。例如,AI 图像生成模型如 Stable Diffusion 和 Midjourney 倾向于放大关于种族、性别国家身份的现有刻板印象。

然而,这些研究大多数主要测试使用英文提示的模型。这就提出了一个问题:这些模型如何响应非英语提示?

在这篇文章中,我深入探讨了 DALL-E 3 在不同语言提示下的表现。借鉴我以前的工作的主题,我提供了对最新 AI 图像生成模型的多语言视角。

DALL-E 3 的工作原理:提示转换

与以前的 AI 图像生成模型不同,DALL-E 模型的最新版本并不会直接生成你输入的内容。相反,DALL-E 3 包含自动提示转换,这意味着它将你的原始提示转换成一个不同的、更具描述性的版本

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

从 OpenAI 的论文中得到的一个提示转换示例,详细说明了标题改进过程:通过更好的标题改善图像生成。图由作者创建。

根据DALL-E 3 系统卡,这样做有几个原因:

  • 改进标题以使其更具描述性

  • 移除公共人物名称

  • 指定更多生成人物的多样化描述(例如,在提示转换之前,生成的人物往往主要是白人、年轻和女性)。

因此,图像生成过程如下:

  1. 您将提示输入到 DALL-E 3(通过 ChatGPT Plus 可用)。

  2. 您的提示在后台被修改为四种不同的转化提示。

  3. DALL-E 3 根据每个转化提示生成图像。

添加这种提示转换对图像生成领域来说相当新。通过添加提示修改,AI 图像生成的机制在后台变得更加抽象,用户更难以理解。

多语言提示转换

大多数研究文本到图像 AI 模型中的偏差都集中在使用英语提示。然而,对于这些模型在非英语语言提示下的行为知之甚少。这样做可以揭示潜在的语言特定或文化特定行为。

我让 DALL-E 3 使用以下英语提示生成图像:

  • “一个男性的图像”

  • “一个女性的图像”

  • “一个人的图像”

我使用了 GPT-4(没有 DALL-E 3)将这些短语翻译成以下语言:韩语、普通话、缅甸语、亚美尼亚语和祖鲁语。

然后,我使用 DALL-E 3 为每种语言生成 20 张图像,总共在 6 种语言中每个提示生成 120 张图像。在从 ChatGPT Plus 保存生成的图像时,图像文件名会自动保存为转化提示的文本。在本文的其余部分,我将分析这些转化提示。

元数据提取

在我的提示中,我从未指定特定的文化、种族或年龄。然而,转化后的提示通常包含这些指示符。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

一个提示转换的示例,标注了句子的哪个部分涉及艺术风格、年龄、种族和性别。图由作者创建。

从转化提示中,我提取了诸如艺术风格(“插图”)、年龄(“中年”)、种族(“非洲裔”)和性别标识(“女性”)等元数据。66%的转化提示包含了种族标记,58%包含了年龄标记。

观察 1:所有提示都被转换为英语。

无论原始提示是什么语言,修改后的提示总是被转换为英语。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

一张展示了 ChatGPT Plus 的截图,展示了“一个人的图像”原始韩语提示被修改为四种不同英文提示转换的示例。图由作者创建。

我对这种行为感到惊讶——虽然我预计提示会被转换成更具描述性的提示,但我没有预料到会发生翻译成英语的情况。

大多数 AI 生成模型,如 Stable Diffusion 和 Midjourney,主要以英语进行训练和测试。一般来说,这些模型在从非英语提示生成图像时表现较差,导致一些用户将提示从其母语翻译成英语。然而,这样做有可能会失去母语的细微差别。

然而,据我所知,这些其他模型中没有一个会自动将所有提示翻译成英语。在背后添加这种额外的翻译步骤(而且,我相信,大多数用户并不知道,因为在使用工具时没有明确说明)使得已经不透明的工具变得更加神秘。

观察 2:原始提示的语言影响修改后的提示

提示转换步骤似乎还包括了关于原始提示语言的未指定元数据。

例如,当原始提示是缅甸语时,即使提示没有提及缅甸语言或缅甸人,提示转换通常会提到缅甸人

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

一个缅甸语提示的示例,内容为“一个男人的图像”,经过 DALL-E 3 转换成关于缅甸男人的描述性提示。图由作者创建。

这并非所有语言的情况,结果因语言而异。对于某些语言,转换后的提示更可能提到与该语言相关的族裔。例如,当原始提示是祖鲁语时,转换后的提示提到非洲人的频率超过 50%(相比之下,当原始提示是英语时,提到非洲人的频率接近 20%)。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

DALL-E 3 生成的所有组合提示(一个人/男人/女人的图像)的族裔百分比,每种语言。图由作者创建。

我并不旨在对这种行为是否正确或错误做出价值判断,也没有规定应有的行为标准。不过,我发现 DALL-E 3 的行为在原始提示语言中变化如此之大是很有趣的。例如,当原始提示是韩语时,DALL-E 3 的提示转换中没有提到韩国人。同样,当原始提示是英语时,DALL-E 3 的提示转换中没有提到英国人。

观察 3:即便是中性提示,DALL-E 3 也会生成带有性别的提示

我将 DALL-E 3 的提示转换中的人物标识符名词映射到三个类别之一:女性、男性或中性:

  • woman, girl, lady → “女性”

  • man, boy, male doctor → “男性”

  • athlete, child, teenager, individual, person, people → “中性”

然后,我将原始提示(“人/男人/女人”)与转换提示(“中性/男性/女性”)进行了比较:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

给定原始提示(“一个人的图像/男人/女人”),转换提示中包含性别个体的百分比。图由作者创建。

毫不意外,“一个男人的图像”原始提示结果大多数是男性标识符(女性情况也是如此)。然而,我发现当使用中性提示“一个人的图像”时,DALL-E 3 75%的时间会将提示转换为包含性别(例如女性、男性)的术语。 DALL-E 3 生成的转换提示中,女性个体略多(40%)于男性个体(35%)。不到四分之一的中性提示转化为提及中性个体的提示。

观察 4:女性通常被描述为年轻,而男性的年龄则更为多样

有时,DALL-E 3 会在修改后的提示中包含一个年龄组(年轻、中年或年长)来描述个体。

在提示中提到女性个体的情况下,年龄描述往往偏向年轻。 具体来说,35%的转换提示将女性个体描述为“年轻”,是将她们描述为“年长”(13%)的两倍多,也比“中年”(7.7%)的频率高出四倍多。这表明,如果提示中提到女性,她很可能也会被描述为年轻。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

提及年龄组的转换提示数量,按提示中提到的个体性别分类。图由作者创建。

以下是一些提示转换的例子:

Illustration of a young woman of Burmese descent, wearing a fusion of modern and traditional attirePhoto of a young Asian woman with long black hair, wearing casual clothing, standing against a cityscape backgroundWatercolor painting of a young woman with long blonde braids, wearing a floral dress, sitting by a lakeside, sketching in her notebookOil painting of a young woman wearing a summer dress and wide-brimmed hat, sitting on a park bench with a book in her lap, surrounded by lush greenery

另一方面,提及男性个体的提示转换显示了更平衡的年龄分布。这可能表明,文化和社会观念持续认为女性的青春更具价值,而男性则被视为不论年龄都具吸引力和成功。

观察 5:个体年龄的变化取决于原始提示语言

年龄组的变化也取决于原始提示的语言。变换提示更有可能将某些语言(例如祖鲁语)描述为年轻,而其他语言(例如缅甸语)则较少如此。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

变换提示中提到年龄组的数量,按原始提示语言分开(图像中的男性/女性/人)。图像由作者创作。

观察 6:艺术风格的变化取决于个体性别

我预计艺术风格(例如摄影、插图)会在年龄组、语言和个体性别之间随机分布。也就是说,我预计女性个体和男性个体的照片数量会相似。

然而,情况并非如此。实际上,女性个体的照片更多,而男性个体的插图更多。描述个体的艺术风格并没有在性别之间均匀分布,而是更偏爱某些性别。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

变换提示中提到的每种艺术风格的数量,按提示中提到的个体性别分开。图像由作者创作。

观察 7:从年轻的亚洲女性到年长的非洲男性的陈词滥调重复

在我的实验中,提示变换中有 360 种独特的人口描述(例如年龄/种族/性别组合)。虽然许多组合仅出现了几次(如“年轻的缅甸女性”或“年长的欧洲男性”),但某些人口描述的出现频率较高。

一个常见的描述是“年长的非洲男性”,出现了 11 次。查看一些生成的图像可以看到,虽然面部表情、姿势、配饰和衣物相似,但还是有所不同。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

一个子集图像的变换提示包含了“年长的非洲男性”这一短语。图像由作者创作。

更常见的描述是“年轻的亚洲女性”,出现了 23 次。再次地,许多面部表情、面部特征、姿势和衣物都是相似的,甚至几乎相同。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

一个子集图像的变换提示包含了“年轻的亚洲女性”这一短语。图像由作者创作。

这一现象捕捉了充斥我们世界的偏见的本质。当我们观察韩国 K-Pop 明星中国偶像的面孔时,他们的面部结构有着惊人的相似性。这种缺乏变化 perpetuates a specific beauty standard, narrowing the range of accepted appearances.

同样,在 AI 生成的图像中,诸如“年长的非洲男性”和“年轻的亚洲女性”等人口描述的狭隘解读助长了有害的刻板印象。这些模型通过不断生成面部特征、表情和姿势缺乏多样性的图像,固化了对这些人群应有的外貌的有限和刻板的看法。这种现象特别令人担忧,因为它不仅反映了现有的偏见,还有可能加剧这些偏见,因为这些图像被社会接受和规范化。

但 DALL-E 3 与其他图像生成模型相比如何?

我使用另外两个流行的文本到图像 AI 工具:MidjourneyStable Diffusion XL,生成了 6 种语言的“一个人的图像”。

对于使用 Midjourney 生成的图像,非英语提示更可能生成风景图像而不是人类图像(尽管,公平地说,英语图像相当令人不安)。对于一些语言,如缅甸语和祖鲁语,生成的图像包含模糊(也许有些不准确)的文化表现或对原始提示语言的参考。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

使用Midjourney生成的六种语言的“一个人的图像”。图形由作者创建。

在使用 Stable Diffusion XL 生成的图像中观察到了类似的模式。非英语提示更可能生成风景图像。亚美尼亚语提示只生成了看起来像地毯图案的图像。中文、缅甸语和祖鲁语的提示生成的图像对原始语言的参考模糊不清。(而且,再次强调,使用英语提示生成的图像相当令人不安)。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

使用Stable Diffusion XL生成的六种语言的“一个人的图像”。我使用Playground AI来使用该模型。图形由作者创建。

从某种程度上说,DALL-E 3 的提示转换起到了人工引入更多变异性和多样性的作用。至少 DALL-E 3 在所有六种语言中一致地生成了人类形象,按照指示进行。

讨论与总结

自动提示词变换有其自身的考虑因素:它们可能会改变提示词的含义,可能带有固有的偏见,并且可能不总是符合个人用户的偏好。

—* DALL-E 3 系统卡片

在本文中,我探讨了 DALL-E 3 如何利用提示词变换来增强用户的原始提示。在此过程中,原始提示不仅被变得更加描述性,还被翻译成英语。可能会使用关于原始提示的附加元数据,例如其语言,来构建变换后的提示,尽管这只是推测,因为 DALL-E 3 系统卡片没有详细说明这一过程。

我对 DALL-E 3 的测试涵盖了六种不同的语言,但需要注意的是,这并不是对全球数百种语言的全面检查。然而,这是系统性探讨非英语语言中的 AI 图像生成工具的重要第一步,这是一个我尚未看到过多探讨的研究领域。

在通过 ChatGPT Plus 网络应用访问 DALL-E 3 时,提示词变换步骤对用户并不透明。这种缺乏清晰度进一步抽象了 AI 图像生成模型的工作原理,使得审视模型中编码的偏见和行为变得更加困难。

然而,与其他 AI 图像生成模型相比,DALL-E 3 在按照提示生成人物方面总体上 准确,在生成多种族面孔方面总体上 多样(由于提示词变换)。因此,尽管在某些种族类别的面部特征方面可能存在有限的多样性,但总体结果是生成图像的多样性(尽管是人为引起的)高于其他模型。

我以对 AI 文本到图像模型期望输出的开放性问题结束了本文。这些模型通常在大量互联网图像上训练,可能会不经意地延续社会偏见和刻板印象。随着这些模型的发展,我们必须考虑是否希望它们反映、放大或减轻这些偏见,特别是在生成人的图像或描绘社会文化机构、规范和概念时。认真思考这些图像的潜在规范化及其更广泛的影响是非常重要的。

注意:DALL-E 3 和 ChatGPT 都是定期演进的产品。即使我在一周前进行了实验,本文中的一些结果可能已经过时或无法再现。随着模型的持续训练和用户界面的不断更新,这种情况不可避免。虽然这是当前 AI 领域的常态,但在未来的研究中,对非英语语言的图像生成模型进行探讨的方法仍然适用。

如果你喜欢这篇文章,我鼓励你订阅我的通讯以支持我的工作,并阅读更多我的作品。谢谢!

低代码时间序列分析

原文:towardsdatascience.com/low-code-time-series-analysis-2d5d02b5474b

使用 Darts 来简化你的 Python 时间序列分析开发

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Pier Paolo Ippolito

·发布于 Towards Data Science ·6 分钟阅读·2023 年 3 月 8 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由 Afif Ramdhasuma 提供,来自 Unsplash

介绍

时间序列预测是机器学习中的一个独特领域。实际上,处理时间序列时,序列中不同点之间存在固有的时间依赖性,因此不同的观察值彼此高度依赖。如果你对学习时间序列分析的基础知识感兴趣,可以在 我之前的文章 中找到更多细节。

在经典的分类和回归问题中,scikit-learn 能够提供我们可能需要的大多数工具,以获得良好的基线(例如数据预处理、低代码模型、评估指标等……),但在时间序列中情况却截然不同。多年来,许多专门的库已出现,以覆盖时间序列分析工作流中的一些关键步骤(例如 statsmodelsProphet、自定义回测等……),但直到 Darts 出现之前,无法在单一解决方案中涵盖所有内容。

演示

作为本文的一部分,我们将通过一个实际示例演示如何使用 Darts 来分析 Kaggle 上的德里每日气候时间序列数据集 [1]。本文中使用的所有代码(及更多内容!)都可以在 我的 GitHubKaggle 账户 上找到。

首先,我们需要确保环境中安装了 Darts。

pip install darts

数据预处理

到这一步,我们已经准备好导入必要的库和数据集(图 1)。为了便于分析,首先将日期列从字符串转换为日期时间,然后将其设置为数据框的索引。

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import darts
from darts.ad import QuantileDetector

df = pd.read_csv('DailyDelhiClimateTrain.csv')
df["date"] = pd.to_datetime(df["date"])
df = df.set_index('date')
df.head(5)

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 1: 德里每日气候时间序列数据集(图像由作者提供)。

清理数据集后,我们现在可以将其划分为训练集和测试集,并可视化时间序列(图 2)。在我们的分析中,我们暂时只关注平均温度。

ts = darts.TimeSeries.from_series(df.meantemp)
train, val = ts.split_before(0.75)
train.plot(label="Training Data")
val.plot(label="Validation Data")

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 2: 德里每日温度时间序列(图像由作者提供)。

异常检测

由于其特性,时间序列通常作为实时或流服务的一部分进行处理,这可能使它们更容易受到错误测量和异常值生成的影响。为了监控我们的时间序列可能存在的异常值,可以使用不同的异常检测技术。两种可能的方法是使用分位数或阈值。使用分位数时,我们决定将序列中最高和最低的百分比值标记为异常值,而使用阈值时,我们指定固定的参考水平,超过或低于该水平的值被标记为异常。

在下面的示例中,将低于 3% 和高于 97% 的值视为异常值,会导致总体超出分位数的值百分比为 5.8%(图 3)。

anomaly_detector = QuantileDetector(low_quantile=0.03, high_quantile=0.97)
anomalies = anomaly_detector.fit_detect(ts)

l = anomalies.pd_series().values
print("Percentage of values outside quantiles:", 
      round(sum(l)/len(l)*100, 3), "%")

idx = pd.date_range(min(ts.pd_series().index), max(ts.pd_series().index))
anomalies = ts.pd_series()[np.array(l,dtype=bool)].reindex(idx,
                                                         fill_value=np.nan)
normal = ts.pd_series()[~np.array(l,dtype=bool)].reindex(idx, 
                                                         fill_value=np.nan)

normal.plot(color="black", label="Normal")
anomalies.plot(color="red", label="Anomalies")

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 3: 分位数异常检测(图像由作者提供)。

基线模型

此时,我们准备深入分析我们的时间序列,并检查是否存在任何季节性模式。如预期的那样,并且在下面的代码片段中所示,序列似乎在统计上每年大致遵循类似的季节性模式。

for m in range(2, 370):
        seasonal, period = darts.utils.statistics.check_seasonality(train, 
                                              m=m, max_lag=400, alpha=0.05)
        if seasonal:
            print("Seasonality of order:", str(period))
Seasonality of order: 354
Seasonality of order: 356
Seasonality of order: 361

根据这些信息,我们可以训练第一个简单的基线模型,该模型仅考虑序列中的季节性模式而不考虑其他信息(图 4)。使用这种方法,结果的 MAPE(平均绝对百分比误差)为 11.35%。使用 MAPE 作为评估指标的两个主要优点是:

  • 使用绝对值,正负误差不会相互抵消。

  • 错误不依赖于因变量的缩放。

k = 361
naive_model = darts.models.NaiveSeasonal(K=k)
naive_model.fit(train)
naive_forecast = naive_model.predict(len(val))

print("MAPE: ", darts.metrics.mape(ts, naive_forecast))
ts.plot(label="Actual")
naive_forecast.plot(label="Naive Forecast (K=" + str(k) + ")")

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 4:基准模型预测(作者提供的图像)。

统计模型选择

现在提供了一个良好的基准模型,我们准备尝试一些更先进的技术(例如指数平滑,ARIMA,AutoARIMA,Prophet)。如果需要,许多额外的模型如:CatBoost,卡尔曼滤波器,随机森林,递归神经网络和时间卷积网络可作为 Darts 的一部分使用。

def model_check(model):
    model.fit(train)
    forecast = model.predict(len(val))
    print(str(model) + ", MAPE: ", darts.metrics.mape(ts, forecast))
    return model

exp_smoothing = model_check(darts.models.ExponentialSmoothing())
arima = model_check(darts.models.ARIMA())
auto_arima = model_check(darts.models.AutoARIMA())
prophet = model_check(darts.models.Prophet())
ExponentialSmoothing(), MAPE:  37.758
ARIMA(12, 1, 0), MAPE:  41.819
Auto-ARIMA, MAPE:  32.594
Prophet, MAPE:  9.794

根据上述结果,Prophet 似乎是迄今为止考虑的模型中最有前景的。无论如何,通过一些额外的工作,结果甚至可以通过超参数优化得到改善,特别是利用传统统计模型如 ARIMA 和指数平滑的业务领域知识。有关 ARIMA 工作原理及其不同超参数的更多细节可以在这里找到。

回测

为了进一步验证我们模型的优度,我们现在可以通过使用现有的历史数据进行测试(图 5)。在这种情况下,记录到的 MAPE 为 7.8%。

historical_fcast = prophet.historical_forecasts(ts,
                           start=0.6, forecast_horizon=30, verbose=True)

print("MAPE: ", darts.metrics.mape(ts, historical_fcast))
ts.plot(label="Actual")
historical_fcast.plot(label="Backtest 30 days ahead forecast")

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 5:Prophet 回测(作者提供的图像)。

协变量分析

结论我们的分析后,我们现在可以检查使用数据集中其他列的信息如湿度和风速是否能帮助我们创建更高性能的模型。协变量主要有两种类型:过去和未来。对于过去的协变量,预测时仅有过去的值可用,而对于未来的协变量,预测时也有未来的值可用。

在这个例子中,N-BEATS(神经基础扩展分析时间序列)模型与湿度和风速列作为过去的协变量一起使用(图 6)。

humidity = darts.TimeSeries.from_series(df.humidity)
wind_speed = darts.TimeSeries.from_series(df.wind_speed)

cov_model = darts.models.NBEATSModel(input_chunk_length=361, 
                                     output_chunk_length=len(val))
cov_model.fit(train, past_covariates=humidity.stack(wind_speed))
cov_forecast = cov_model.predict(len(val), 
                               past_covariates=humidity.stack(wind_speed))

print("MAPE: ", darts.metrics.mape(ts, cov_forecast))
ts.plot(label="Actual")
cov_forecast.plot(label="Covariate Forecast")

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 6:协变量分析预测(作者提供的图像)。

作为训练过程的结果,记录到的 MAPE 分数为 10.9%,因此在此情况下表现不如我们原始的 Prophet 模型。

联系方式

如果你想了解我的最新文章和项目,请在 Medium 上关注我并订阅我的邮件列表。这些是我的一些联系方式:

参考文献

[1] “每日气候时间序列数据”(SUMANTHVRAO,许可协议 CC0: 公共领域)。访问地址: www.kaggle.com/datasets/sumanthvrao/daily-climate-time-series-data?select=DailyDelhiClimateTrain.csv

Lucene 透视 — 处理整数编码和压缩

原文:towardsdatascience.com/lucene-inside-out-dealing-with-integer-encoding-and-compression-fe28f9dd265d

深入探讨 PackedInts、VInt、FixedBitSet 和 RoaringDocIdSetRoaring Bitmaps

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Peggy Chang

·发布于 Towards Data Science ·13 分钟阅读·2023 年 6 月 28 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由 Gerd Altmann 提供,来自 Pixabay

早些时候,我们学习了使用 产品量化 进行相似性搜索的向量压缩。

## 产品量化用于相似性搜索

如何在内存中压缩和适配一个巨大的向量集,以便在不对称距离计算下进行相似性搜索……

towardsdatascience.com

在这篇文章中,我们将探讨并深入了解整数在 Lucene 中的编码和压缩方式,那里倒排索引是核心。

Lucene — 简介

Lucene 是一个用 Java 编写的开源搜索引擎库。由 Doug Cutting 于 1999 年创建,以全文搜索和索引著称。

这个开源软件项目在 Apache 软件基金会 旗下,经过二十多年仍在积极开发中。多年来,它不断发展壮大,成为一个强大、功能齐全的高性能搜索引擎库。

毫无疑问,Lucene 的成功在很大程度上归功于其强大的社区以及贡献者们的卓越工作。他们的参与和合作使得 Lucene 达到了今天的水平。许多流行的企业搜索平台和解决方案,如 SolrElasticsearch,都是建立在 Lucene 之上的。

“对于一个开源项目来说,20 年是很长的时间。毫无疑问,Lucene 的长期存在证明了其社区的力量和多样性” — 庆祝 Apache Lucene 20 年

反向索引

反向索引是 Lucene 的核心。反向索引包括两部分 —— 左边是术语字典,右边是每个术语的 postings。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 1: 术语字典和 postings 列表形成 Lucene 中的反向索引。所有图片均由作者提供,除非另有说明。

Postings 是关于术语在文档中出现的信息。Postings 列表包含术语出现的文档的 Doc ID。

如果定义了,它也可能包括诸如术语在文档中的频率,甚至位置、字符偏移量和有效负载等信息。

是的,这些都是整数,确实有大量的整数需要在 Lucene 中处理。正如 Apache 软件基金会的 Lucene 贡献者 Adrien Grand [1] 所引述的:

“搜索引擎最重要的构建块之一是能够高效地压缩和快速解码排序的整数列表”

在接下来的部分中,我们将深入探讨 Lucene 用于编码和压缩整数的技术,特别是来自 postings 列表的整数 —— Doc ID 和术语频率。

Delta 编码

让我们首先看看 Lucene 如何在磁盘上编码和存储 postings 数据。包含每个术语的文档列表保存在 .doc 文件中。跳过数据也保存在同一个文件中,但在本文中不作讨论。

首先,正如 图 1 所示,每个术语所指向的 Doc ID 基本上是一个排序好的整数列表。对于每个术语,我们首先将排序好的 Doc ID 列表转换为 Doc Deltas。

通过 delta 编码,Doc Deltas 是通过计算每个 Doc ID 和前一个 Doc ID 之间的差异得到的,除了第一个 Doc ID。

接下来,将 Doc Deltas 切分为固定的 128 个整数块。这些块被称为 PackedDocDeltaBlock。每个块随后使用 PackedInts 进行编码,这是一种 Lucene 实现的位打包方式。剩余的 Doc Deltas 则用 VInt 编码。

以下图是一个简化的示意图,其中 PackedDocDeltaBlock 的块大小为 4,而不是实际的 128。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 2: Doc ID 的编码过程

你是否注意到,大多数整数在经过 delta 编码后变得更小了?

较小的整数需要更少的位来表示,这对于使用 PackedIntsVInt 进行编码的下一步至关重要。

PackedInts

**Integer data types:**
Byte                  = 8 bits
Short integer (*short*) = 16 bits (2 bytes)
Integer       (*int*)   = 32 bits (4 bytes)
Long integer  (*long*)  = 64 bits (8 bytes)

通常,位打包将多个值在位级别组合成一个或多个字节(或一个或多个长整型*)。例如,四个 2 位的值可以打包成一个字节,八个 16 位的值可以打包成两个长整型。

通过位打包,存储整数所需的空间可以显著减少。

这是可能的,因为典型的 32 位 *int*(最常用的整数类型)几乎总是包含位级别的前导零,位打包会丢弃这些零。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 3: 将 4 个整数打包到 1 个字节中的示例,每个值占用 2 位。存储空间从 16 字节减少到 1 字节。

PackedInts 中,数据以每个值消耗固定数量位数的方式存储,这个位数在 1 到 64 之间,被称为bitsPerValue

在 Doc IDs 的编码过程之后,如果一个数据字段定义为包含术语频率(即该字段的索引选项设置为 IndexOptions.DOCS_AND_FREQS),那么每个 PackedDocDeltaBlock 都会紧随其后一个 PackedFreqBlock

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 4: PackedDocDeltaBlockPackedFreqBlock

正如其名,PackedFreqBlock 包含在前一个 PackedDocDeltaBlock 中出现的术语的对应频率。与文档 ID 不同,术语频率没有经过 delta 编码。

每个 PackedDocDeltaBlockPackedFreqBlock 都使用 PackedInts 独立编码。bitsPerValue 是根据块中表示最大整数所需的位数得出的。尽管如此,每个值所消耗的位数可能会比预期的更多。这是为什么,如何发生的?

实际上,Lucene 可能会调整 bitsPerValue 以基于一个称为 acceptableOverheadRatio 的参数来获得最佳的读写性能。这个参数是为了在内存效率与快速随机读取之间进行权衡所愿意接受的开销。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 5: PackedInts 中的压缩模式

在 Lucene 中,bitsPerValue调整 是以这样一种方式进行的:当 bitsPerValue 增加到 8、16、32 或 64 时,结果开销不会超过 acceptableOverheadRatio。但为什么是 8、16、32、64,而不是其他数字?

大多数情况下,当表示一个值的位数是字节对齐或是 8 的倍数(即 8、16、32、64)时,读写性能最佳。因为没有值在一个字节内共同存在,读写操作得以简化。

换句话说,1 字节、2 字节、4 字节或 8 字节的空间完全用于表示一个值。

在内存效率最差的情况下,bitsPerValue 从 1 调整到 8。对于每一个存在的有效位,消耗 7 个其他未使用的位。这导致了 700% 的内存开销。实际上,即使只有 1 位在使用,每个值也会消耗 8 位。

acceptableOverheadRatio 为 7 时,随机读取访问速度往往最快。这是当bitsPerValue从 1 到 7 调整为 8,bitsPerValue从 9 到 15 调整为 16,以此类推时的结果。实现的内存开销有不同程度,最高达到 700%。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 6:当每个值的位数调整为 8 时的内存开销

另一方面,当acceptableOverheadRatio为 0 时,bitsPerValue保持不变,不进行调整。数据被紧密打包以实现最佳内存效率,但随机读取可能较慢。可能会有多个值占据一个字节的空间。因此,表示一个值的位可能会溢出到下一个字节。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 7:具有每值 6 位的紧凑数据示例。字节数从 4 减少到 3。

综上所述,Lucene 默认使用的PackedInts压缩模式的acceptableOverheadRatio为 0.25。此设置确保任何产生的内存开销永远不会超过 25%。

VInt

VInt 是一种基础 128 压缩类型,生成可变长度整数。每个整数单独编码为 1 到 5 字节。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 8:将 17000 转换为VInt的示例

生成VInt时,位被分成 7 位一组,从右侧的低位开始。

对于每个从右到左的 7 位块,另一个位被添加以形成一个字节。这个位作为续接标志,构成字节的最高有效位。如果后面还有更多字节,这个位的值为1,否则为0

使用这种表示方式,1 字节足以表示从 0 到 127 的小整数。大多数整数需要 3 字节或更少,因为 3 字节的VInt能够表示 16,384 到 2,097,151 之间的值。

之前,参见图 2,我们提到剩余的文档增量是用VInt编码的。当术语频率被索引时会发生什么呢?

在这种情况下,文档增量现在定义了文档编号和频率。表示文档增量的位将向左移动一步,这样最不重要的位现在用于标记频率是否为 1。如果频率为 1,则最不重要的位为1,否则为0

如在Lucene90PostingsFormat中所记录,当文档增量为奇数时,频率为 1。当文档增量为偶数时,频率被读取为另一个VInt

下图显示了如何在VInt中对文档增量7, 10(其中词语分别出现一次和三次)进行编码,序列为15, 20, 3,当词频被索引时。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 9: VInt 编码与不编码词频的对比

FixedBitSet

在 Lucene 中,FixedBitSet是一个固定长度的位图实现,用于在内存中存储文档 ID。

位图是一组映射到整数列表的位。一个被设置为1的位表示一个整数,其值是该位的索引。

FixedBitSet在 Lucene 中内部实现为*long[]*整数数组,因此每个整数占 64 个位。该数组的长度(或数组中的整数数量)基于位图所需的位数来确定。

例如,要编码一个最大值为 190 的文档 ID 列表,需要至少 191 个位来表示从 0 到 190 的位图索引。因此,将分配一个长度为 3 的数组,该数组能够容纳3*64=192位。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 10: FixedBitSet 示例 — bitset[] 数组包含 3 个 64 位的长整数

在上述示例中,FixedBitSet使用 24 个字节来编码一个最大值为 190 的 6 个整数的列表。这是稀疏数据的一个例子,其中在位集中的 192 个位中仅有 6 个位被设置为1

在这里,如果这些整数使用*int*类型存储,则内存没有节省,所用的字节数相同。

这表明,当其表示的数据是稀疏时,FixedBitSet或一般位图的效率较低。

RoaringDocIdSet(Roaring 位图)

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由Glen Carrie拍摄,来源于Unsplash

在 Lucene 中,查询通过 LRU 缓存,LRU 是一种缓存方案,当缓存满时会驱逐最少使用的项以为新的项腾出空间。缓存允许快速访问经常查询的数据。

并非所有查询在 Lucene 中都被缓存。但是,对于那些被缓存的内容,缓存的内容包含文档 ID 结果集。

Lucene 中的LRU 查询缓存对密度小于 1% 的集合使用RoaringDocIdSet。否则,使用FixedBitSet

RoaringDocIdSet 是受Roaring Bitmaps的思想和设计结构启发的实现。那么Roaring Bitmaps是什么呢?正如 roaringbitmap.org所描述的那样,

Roaring 位图是压缩位图,其性能通常优于传统的压缩位图,如 WAH、EWAH 或 Concise。在某些情况下,它们的速度可以快几百倍,而且通常提供显著更好的压缩效果。

Roaring Bitmaps通过将数据分区并存储到不同的容器中来工作。Roaring Bitmaps中的稀疏和密集数据容器根据容器的基数以不同的方式存储。在 Lucene 的文献中,这些容器被称为块。

RoaringDocIdSet中,块号由 16 个最重要的位标识。剩余的 16 个最低有效位是将存储在块中的值。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 11:文档 ID 的二进制格式示例

从上述示例可以看出,前四个文档 ID 会被存储在Block 0中。接下来的两个文档 ID 将存储在Block 1中,而最后三个文档 ID 则存储在Block 4中。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 12:RoaringDocIdSet中数据的块分区

这样,RoaringDocIdSet 可以容纳最多2¹⁶ = 65536个块,每个块可以存储最多 65536 条记录。

那么这些块中的数据究竟是如何存储的呢?

每条记录 16 位(或 2 字节),一个*short[]*整型数组占用 128 kB 存储 65536 条记录。数组所需的空间随着记录数量线性增长。

相反,一个可以容纳 65536 位的位图仅占用 8 kB。与 128 kB 相比,这是一种巨大的差异,空间减少了 16 倍。因此,人们倾向于认为使用位图更为高效。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 13:使用short[]整型数组与位图

但稍等,我们来做些分析,仔细查看图表。可以观察到,当记录总数低于 4096 时,使用*short[]*整型数组实际上占用的空间不到 8 kB。

这就是决定每个块存储方法的原因。

使用混合数据结构,包含少于 4096 条记录的稀疏块使用*short[]*整型数组存储,而包含 4096 条或更多记录的密集块则使用位图存储。

Lucene 进一步改进了这一点,通过使用*short[]*整型数组存储超密集块的集合的反向数据。

这意味着当记录数超过 61440 时,存储的是具有不到 4096 个值的集合的逆。这是多么聪明的做法!

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 14:对于超密集块,Lucene 使用*short[]*整数数组存储集合的逆。

有趣的是,RoaringDocIdSet在与FixedBitSet进行基准测试时的表现。根据这个补丁,从图表可以观察到,当 Doc ID 集的密度低于 1%时,

  • RoaringDocIdSet的内存占用可比FixedBitSet小超过 128 倍。

  • RoaringDocIdSet的构建时间可快约 64 倍于FixedBitSet

  • RoaringDocIdSet的迭代性能和跳过性能(使用*nextDoc()**advance()*)可快约 90 倍于FixedBitSet

相反,当 Doc ID 集的密度高于 1%时,FixedBitSet的表现优于RoaringDocIdSet

关键要点

  • 在压缩方面,没有一种通用的方法。为了实现最佳结果,Lucene 使用了多种技术和策略来处理整数压缩。

  • Delta 编码在有效减少整数大小后,再进行PackedIntsVInt的压缩中起着重要作用。

  • 如果存在大值会怎样?数据压缩质量会受到影响,因为块中的最大整数决定了用于PackedInts的每个值的位数。将 Doc Deltas 和术语频率拆分成固定大小的块是缓解此问题的明智方法。其影响仅限于块内的数据,而其他数据保持不变。

  • 尽管位图非常适合于密集整数集合,但看到RoaringDocIdSetRoaring Bitmaps的一种变体)以巧妙的方式处理密集和稀疏集合,确实令人着迷。

结论

Lucene 的大部分工作涉及整数。因此,整数压缩在减少存储和内存占用,以及缩短从磁盘或内存读取或写入数据的传输时间方面至关重要。

如 Lucene 所示,采用正确的策略来匹配用例,并通过创新方式优化高效访问,是促成搜索引擎领域持续增长和发展的成功因素之一。

这些实现可以在存储、内存和网络带宽方面带来显著的成本节约,同时提升性能。

参考

[1] A. Grand, 参考框架与 Roaring 位图(2015)

[2] 庆祝 Apache Lucene 成立 20 周年

[3] 咆哮位图:更好的压缩位集

[4] S. Chambi, D. Lemire, O. Kaser 和 R. Godin, 使用咆哮位图提高位图性能 (2016)

[5] D. Lemire, G. Ssi-Yan-Kai 和 O. Kaser, 使用咆哮位图实现一致更快且更小的压缩位图 (2018)

[6] D. Lemire, O. Kaser, N. Kurz, L. Deri, C. O’Hara, F. Saint-Jacques 和 G. Ssi-Yan-Kai, 咆哮位图:优化软件库的实现 (2022)

[7] V. Oberoi, 咆哮位图简介:它们是什么以及如何工作 (2022)

[8] D. Lemire 和 L. Boytsov, 通过向量化每秒解码数十亿个整数 (2021)

在你离开之前…

🙏 感谢你阅读这篇帖子,希望你喜欢了解 Lucene 中的整数编码和压缩。

👉 如果你喜欢我的帖子,不要忘记点击 关注订阅,以便在我发布新内容时通过电子邮件收到通知。

😃 可选地,你也可以 注册 成为 Medium 会员,以获得对 Medium 上每个故事的完全访问权限。

📑 访问这个 GitHub 仓库,获取我在帖子中分享的所有代码和笔记本。

© 2023 保留所有权利。

Ludwig — 一个“更友好”的深度学习框架

原文:towardsdatascience.com/ludwig-a-friendlier-deep-learning-framework-946ee3d3b24

使用这个低代码、声明式框架让深度学习变得简单

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 John Adeojo

·发表于 Towards Data Science ·阅读时间 11 分钟·2023 年 6 月 26 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片来源:作者:使用 Midjourney 生成

背景 — 深度学习,是否过于复杂?

我一直倾向于避免将深度学习应用于行业用例。这并不是因为缺乏兴趣,而是因为我觉得流行的深度学习框架很繁琐。我欣赏 PytorchTensorFlow 是用于研究目的的绝佳工具,但它们的 API 并不是最用户友好的。在需要为客户快速交付概念验证的情况下,我最不希望做的就是捣鼓 Pytorch 张量。

在伦敦参加 AI 峰会时,我偶然发现一个团队声称他们有解决我的深度学习问题的方案。他们采用了一种不同的方法,他们将其描述为“介于 TensorFlow 和 AutoML 之间的中点”,这是一个名为 Ludwig 的框架。

什么是 Ludwig?

Ludwig 是由 Uber 开发的,Ludwig 是一个用于构建深度学习模型的开源框架。它是声明式的,这意味着你无需像在 TensorFlow 中那样一层层构建复杂的模型,而是只需通过配置文件声明模型的结构。这听起来好得令人难以置信,所以我决定亲自体验一下。在这篇文章的其余部分,我将通过一个我从 Kaggle 上获得的示例项目详细描述我对 Ludwig 的体验。在此过程中,我将讨论它的一些优点、痛点,并给出是否值得使用的结论。

注意 — 尽管最初由 Uber 开发,Ludwig 是一个 开源 库,采用 Apache 2.0 许可证。该项目由 Linux Foundation AI & Data主办。我与 Uber 或 Ludwig 的开发者没有商业关系。

项目 — 需求预测

项目简介:预测零售商 WOMart 各个门店的最后 30 天订单。

你的客户 WOMart 是领先的营养和补充品零售连锁店,提供全面的产品以满足你的健康和健身需求。

WOMart 遵循多渠道分销战略,在 100 多个城市拥有 350 多个零售店。

数据

数据集共有 22,265 个观察值,每个观察值对应于一个特定门店的一天销售数据。为了简洁起见,我不会详细介绍数据集的所有细节,但你可以在这里查看一些描述性统计数据。

注意:数据在 Open Data Commons 许可证下可以用于任何目的。

数据字典:

方法论概述

我不会在这里详细介绍方法论,因为这不是本文的主要目的。我将高层次地介绍我如何框定问题,以便为你提供一些背景。

我将预测问题框定为一个“伪”序列到序列深度学习问题。这种方法涉及利用 360 天的时间序列数据点来预测接下来的 30 天的客户订单。我引入了一些分类变量,并且需要为每一天的订单生成单独的预测,这导致了一个略显非传统的设置——因此使用了“伪”序列到序列来描述这个问题。我将在本文后面详细讨论特征工程的具体细节。

除此之外,我遵循的方法论对于模型开发来说是标准的。我将数据分为训练数据集和保留数据集,并对特征和标签进行了重新缩放。模型训练在训练数据上进行,测试在保留数据上进行。

注意:Ludwig 确实提供了在 API 中本地拆分数据的功能。然而,为了保持严谨性,我建立了一个单独的保留数据集。训练数据集随后被进一步划分为训练、验证和测试子集。保留数据集完全被排除,仅用于分析模型生成的预测。

特征工程

在撰写本文时,我认为在 Ludwig 中进行时间序列预测的序列到序列建模是很棘手的。这是因为特征工程。Ludwig API 在处理序列作为输入方面表现出色,但它们尚未(还未)开发出对时间序列作为输出的连贯方法。你可以通过声明多个输出来开发一个“伪”序列到序列模型,但整体特征工程体验感觉相当“黑客”。

序列特征:除了那些随时间变化的特征外,我将所有预测特征工程为“Ludwig 格式”的序列。每个输入序列是每个“时间序列”特征在预定义时间范围内的水平堆叠。每个特征序列在商店级别确定,并封装在数据框的一个单元格中(看起来就像听起来那么乱)。

序列标签:对于序列标签,你必须将序列中的每一点声明为模型的单独标签。结果是我为每个商店声明了 30 个标签,每天一个标签,用于预测订单。

下面是特征工程过程的示例:

数据示例:粗体值将用于构造序列标签,常规值将用于构造序列特征。

特征工程示例:Order_sequence 是一个“Ludwig 格式”的序列。标签会被单独返回,以便后续声明为模型输出(标签)。

设计你的模型

Ludwig API 允许你通过声明方式构建相当复杂和可定制的模型。Ludwig 通过 .yaml 文件来实现这一点。现在,我知道许多数据科学家可能没有使用过*.yaml* 文件,但在软件开发中,这些文件通常用于配置。文件乍一看可能显得吓人,但实际上非常友好。让我们逐步了解一下我创建模型时使用的文件的主要部分。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图像:模型架构

在深入配置之前,值得简要介绍一下 Ludwig 深度学习框架的核心架构:架构:编码器、组合器和解码器。你在 Ludwig 中配置的大多数模型将主要遵循这一架构。理解这一点可以简化堆叠组件的过程,从而快速构建你的深度学习模型。

声明你的模型

在文件的最上方,你声明所使用的模型类型。Ludwig 提供了两种选择:基于树的模型和深度神经网络,我选择了后者。

model_type: ecd

声明数据拆分

你可以通过声明拆分百分比、拆分类型以及你要拆分的列或变量来本地拆分数据集。出于我的目的,我希望确保一个商店只能出现在一个数据集中,哈希拆分正好适合这一点。

最佳实践是,我建议在 Ludwig API 之外构建一个保留集,尤其是在进行初步特征工程(如独热编码或归一化)时。这有助于防止数据泄漏。

model_type: ecd
split:
    type: hash
    column: Store_id
    probabilities:
    - 0.7
    - 0.15
    - 0.15
#...omitted sections...

声明模型输入

你通过名称、类型和编码器来声明输入。根据模型输入的类型,你有多种编码器选项。编码器本质上是一种将输入转换为模型可以解读的方式。编码器的选择实际上取决于数据和建模任务。

model_type: ecd
split:
    type: hash
    column: Store_id
    probabilities:
    - 0.7
    - 0.15
    - 0.15
input_features:
  - name: Sales
    type: sequence
    encoder: stacked_cnn
    reduce_output: null
  - name: Order
    type: sequence
    encoder: stacked_cnn
    reduce_output: null
  - name: Discount
    type: sequence
    encoder: stacked_cnn
    reduce_output: null
  - name: DayOfWeek
    type: sequence
    encoder: stacked_cnn
    reduce_output: null
  - name: MonthOfYear
    type: sequence
    encoder: stacked_cnn
    reduce_output: null
  - name: Holiday
    type: sequence
    encoder: stacked_cnn
    reduce_output: null
  - name: Store_Type
    type: category
    encoder: dense
  - name: Location_Type
    type: category
    encoder: dense
  - name: Region_Code
    type: category
    encoder: dense
#...omitted sections...

声明组合器

组合器,顾名思义,用于合并你的编码器的输出。Ludwig API 提供了多种不同的组合器,每种都有其特定的使用场景。组合器的选择可能取决于模型的结构和特征之间的关系。例如,如果你想简单地将编码器的输出进行连接,可以使用“concat”组合器;如果你的特征有顺序关系,可以使用“sequence”组合器。

model_type: ecd
split:
    type: hash
    column: Store_id
    probabilities:
    - 0.7
    - 0.15
    - 0.15
input_features:
  - name: Sales
    type: sequence
    encoder: stacked_cnn
    reduce_output: null
  - name: Order
    type: sequence
    encoder: stacked_cnn
    reduce_output: null 
  # ... omitted sections ...

  - name: Location_Type
    type: category
    encoder: dense
  - name: Region_Code
    type: category
    encoder: dense
combiner:
    type: sequence
    main_sequence_feature: Order
    reduce_output: null
    encoder:
    # ... omitted sections ...

与深度学习的许多方面一样,最佳的组合器选择通常取决于你的数据集和问题的具体情况,并可能需要一些实验。

声明模型输出

完成你的网络就像声明输出一样简单,输出就是你的标签。我对 Ludwig 的时间序列处理的一个小抱怨是,当前你无法(还)声明时间序列输出。正如我之前提到的,你必须通过单独声明时间序列中的每个点来“破解”它。这让我有了三十个单独的声明,说实话看起来非常杂乱。对于每个输出,你还可以指定损失函数,增加额外的可配置性。Ludwig 为不同的输出类型预设了大量选项,但我不确定你是否能够像在 Pytorch 中那样实现自定义损失函数。

model_type: ecd
split:
    type: hash
    column: Store_id
    probabilities:
    - 0.7
    - 0.15
    - 0.15
input_features:
  - name: Sales
    type: sequence
    encoder: stacked_cnn
    reduce_output: null
  - name: Order
    type: sequence
    encoder: stacked_cnn
    reduce_output: null
# ...omitted sections...

  - name: Location_Type
    type: category
    encoder: dense
  - name: Region_Code
    type: category
    encoder: dense
combiner:
    type: sequence
    main_sequence_feature: Order
    reduce_output: null
    encoder:
        type: parallel_cnn
output_features:
  - name: Order_sequence_label_2019-05-02
    type: number
    loss:
      type: mean_absolute_error
  - name: Order_sequence_label_2019-05-03
    type: number
    loss:
      type: mean_absolute_error
#...omitted sections...

      type: mean_absolute_error
  - name: Order_sequence_label_2019-05-30
    type: number
    loss:
      type: mean_absolute_error
  - name: Order_sequence_label_2019-05-31
    type: number
    loss:
      type: mean_absolute_error
#...omitted sections...

声明训练器

Ludwig 的训练器配置虽然是可选的(因为 Ludwig 提供了合理的默认设置),但允许高度的自定义。这让你能够控制模型训练的具体细节。这包括指定所用优化器的类型、训练轮数、学习率以及早停标准等参数。

model_type: ecd
split:
    type: hash
    column: Store_id
    probabilities:
    - 0.7
    - 0.15
    - 0.15
input_features:
  - name: Sales
    type: sequence
    encoder: stacked_cnn
    reduce_output: null
  - name: Order
    type: sequence
    encoder: stacked_cnn
    reduce_output: null
# ...omitted sections...

  - name: Location_Type
    type: category
    encoder: dense
  - name: Region_Code
    type: category
    encoder: dense
combiner:
    type: sequence
    main_sequence_feature: Order
    reduce_output: null
    encoder:
        type: parallel_cnn
output_features:
  - name: Order_sequence_label_2019-05-02
    type: number
    loss:
      type: mean_absolute_error
  - name: Order_sequence_label_2019-05-03
    type: number
    loss:
      type: mean_absolute_error
#...omitted sections...

      type: mean_absolute_error
  - name: Order_sequence_label_2019-05-30
    type: number
    loss:
      type: mean_absolute_error
  - name: Order_sequence_label_2019-05-31
    type: number
    loss:
      type: mean_absolute_error
trainer:
    epochs: 200
    learning_rate: 0.0001
    early_stop: 20
    evaluate_training_set: true
    validation_metric: mean_absolute_error
    validation_field: Order_sequence_label_2019-05-31

对于你的特定用例,你可能会发现自己定义这些参数会更有益。例如,你可能希望根据模型的复杂性和数据集的大小调整学习率或训练轮数。同样,早停可以成为防止过拟合的有用工具,通过在模型在验证集上的表现不再改善时停止训练过程。

训练你的模型

训练你的模型可以通过 Ludwig 的 Python 实验 API 轻松完成。请参见下面的脚本示例:

其他配置

除了我提到的,Ludwig 还有大量可能的配置。它们都记录得很好且结构清晰。我建议阅读他们的 文档 来熟悉它们。

模型性能分析——简要视图

本文旨在通过一个实际的示例项目来探讨 Ludwig 框架的一些功能。虽然展示模型性能是其中的一部分,但无需深入探讨指标的细节。我将讨论限制在展示一些从模型分析中生成的图表。请注意,全面的端到端脚本在我的 GitHub 上可以找到,链接在文章结尾。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图片:过去 30 天的数据中,模型预测(红色)与实际订单(蓝色)的对比。这些例子来自保留集。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图片:误差分布,其中误差为实际订单减去预测订单。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图片:模型训练的损失曲线。损失指标为均方绝对误差

我的判决

我将首先承认,最初我对 Ludwig 持怀疑态度。然而,在自己实验之后,我相信它的能力,并认为它如承诺般有效。我认为有几个真正令人印象深刻的功能值得突出。

编码体验:编码体验更像是在构建一个精致的乐高模型。你可以通过玩弄组件和不同的架构来找到你的完美模型,确实非常有趣。

文档:文档清晰且结构良好。很容易搞清楚如何实现不同的架构和更改模型配置。大部分文档似乎也很及时更新,这是一大优势。

后端:后端体验非常出色。库的开发者在抽象掉训练深度神经网络所需的大部分常规配置方面做得很好。我在 Google Collab 上训练了我的模型,Ludwig 自动将工作负载转移到 GPU 上。

Ludwig 还有一个很棒的特点,就是后端高度可配置。例如,如果你在运行大规模工作负载并需要一个 GPU 集群,你也可以进行配置!

实验追踪:Ludwig 提供了一个实验 API,可用于在实验运行之间跟踪模型工件。我相信它也与 MLflow 集成,这对于商业规模的 MLOps 来说非常棒。

个人偏好

有一些领域可以进一步增强这个框架,让我们一起来探讨一下。

可视化:Ludwig 确实提供了一个可视化 API 来跟踪数据集中的训练损失。然而,在撰写本文时,它的功能并不特别好,其使用也不够直观。我尝试在 Google Collab 中运行,但没有成功。最终,我通过编写一个 Python 函数来解析 Ludwig 在每次实验运行后保存的 training_statistics.json 文件,创建了自己的损失曲线可视化。

支持:虽然有一定的支持可用,但 Ludwig 的社区似乎还没有 TensorFlow 或 Pytorch 那么广泛。在 GitHub 上提出了一些问题,有些线程可能提供帮助,但大部分情况下,感觉你只能依靠自己。至于 ChatGPT,它提供了截至 2021 年的一定程度的支持。

透明度:Ludwig 在消除构建深度学习模型中更具挑战性的方面表现出色。然而,这也以牺牲透明度为代价,偶尔使日志显得有些难以理解和调试。

结论

在我看来,Ludwig 是一个出色的工具,适合那些希望开始使用深度学习的人,无论是在商业环境中还是仅仅为了学习。虽然它可能对前沿研究目的来说过于高层次,但它非常适合快速解决明确定义的问题。尽管仍然需要对深度学习有一定的理解,但一旦掌握了概念,Ludwig 的入门门槛比 TensorFlow 或 Pytorch 低得多。

端到端的笔记本可以在我的 GitHub 仓库中找到,请随意进行实验。

关注我在 LinkedIn

订阅 Medium 以获取更多我的见解:

[## 通过我的推荐链接加入 Medium — John Adeojo

我分享数据科学项目、经验和专业知识,以助你一臂之力。你可以通过 Medium 注册…

johnadeojo.medium.com](https://johnadeojo.medium.com/membership?source=post_page-----946ee3d3b24--------------------------------)

如果你有兴趣将 AI 或数据科学集成到业务运营中,我们邀请你预约一次免费的初步咨询:

[## 在线预约 | 数据驱动解决方案

通过免费的咨询发现我们在帮助企业实现雄心勃勃目标方面的专业知识。我们的数据科学家和…

www.data-centric-solutions.com](https://www.data-centric-solutions.com/book-online?source=post_page-----946ee3d3b24--------------------------------)

机器学习算法第一部分:线性回归

原文:towardsdatascience.com/machine-learning-algorithms-part-1-linear-regression-a7079238edc9?source=collection_archive---------13-----------------------#2023-01-06

使用线性回归预测钻石价格

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Rohan Vij

·

关注 发表在 Towards Data Science · 12 分钟阅读 · 2023 年 1 月 6 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由 Bas van den Eijkhof 提供,发布在 Unsplash

线性回归是一种强大但相对简单的工具,可以用来理解变量之间的关系。本教程将以初学者友好的方式探索线性回归的基础知识。在本教程结束时,你将对线性回归有一个扎实的理解,并知道如何使用实际数据实现它。

什么是线性回归?

线性回归,作为一种统计方法,首次用于 1877 年,用于预测因变量的值。实质上,它通过对提供给模型的多个点进行“拟合”来最准确地匹配因变量和自变量之间的关系,这类似于散点图。通过图表最容易观察到这一点:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

来源:维基百科

线性回归通过创建一条线性直线(形式为y=mx+b)来最准确地预测因变量的值,通过求解值m(斜率)和b(y 截距)。

最小二乘法

为此,模型使用了一种称为最小二乘法的方法,以最准确地找到最佳拟合线。该方法的目标是尽可能减少特定数据点到最拟合线的偏差平方总和。最拟合的直线将具有最小的最小二乘函数结果值。我们可以使用以下方程计算从每个提供的点到最拟合线的偏差:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图像由作者提供。

实质上,线性函数的输出因变量(y 值)从给定数据点的因变量中减去。这个值可以是正数也可以是负数,取决于函数的值是否大于或小于数据点的值。然而,偏差是正还是负并不重要——无论如何,数值都会被平方。

更简单地说,我们可以用一个总和来找到所有偏差平方的总值:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图像由作者提供。

最小二乘法声称,最准确/拟合的数据线将具有最小的总和(S)。

示例

使用点:(1,2), (3,5), (5,2)

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图像由作者提供。

现在是提到为什么最小二乘法中的值被平方的好时机。首先,如前所述,它确保所有偏差都是正值。然而,更重要的是,它确保较大的偏差被赋予更多的权重。这使得拟合线能够更多地关注异常值。

以最左边的图为例。我们可以看到,从直线到前两个数据点的偏差要么是 0,要么可以忽略不计。相比之下,其他两个图的直线在平均上更接近数据点。比较原始(未平方的)偏差值时,我们可以看到它们彼此相当接近。比较平方偏差值时,我们可以看到最左边的图的偏差比右边的图大 600%以上。这是因为较大的偏差受到的惩罚更大,这意味着异常值对最终直线的影响更大。

使用最小二乘法

最小二乘法可以通过两种方式实现。虽然使用矩阵运算是计算效率最高、最广泛使用的方法,我们将探讨使用梯度下降来寻找最佳直线。梯度下降是一种优化算法,我们将在其中计算和的导数,然后根据导数指示的方向调整系数值。这个过程会重复进行,直到找到最优解。这只是梯度下降的简要概述;请关注将为不懂微积分的人解释梯度下降的文章。

MSE 与 SSE

我们将使用均方误差(Mean of Squared Errors)作为我们的成本函数。基本上,我们想要最小化这个成本函数的值,以输出最拟合的直线。之前,我们使用 SSE(平方误差和)来确定哪条直线最适合其数据点。MSE 相当直观——它与 SSE 相同,但我们将最终的和除以数据点的数量:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图片。

MSE 比 SSE 更受欢迎,有很多原因。

首先,均方误差(MSE)比平方误差和(SSE)对离群点的敏感度更低。由于 MSE 通过数据点的数量来归一化误差,离群点的影响较小。假设数据集中的误差为1, 4, 1, 25。离群点(25)只占通过 MSE 计算误差的 25%。因此,MSE 将是7.75。SSE 将是31

其次,使用 MSE 也允许比较不同直线的拟合度,即使它们使用的数据点数量不同。例如,考虑使用不同数量数据点的两个模型,模型 A 和模型 B。如果模型 A 使用 100 个数据点,模型 B 使用 50 个数据点,大多数情况下模型 A 将有更高的 SSE。然而,如果通过使用 MSE 来归一化误差,无论模型使用多少数据点,这些模型都可以直接进行比较。

上述因素的结合意味着 MSE 比 SSE 更易于解释。数据集中的离群点可能会使一个模型看起来比另一个模型显著更好,如果它们使用 SSE 进行比较,即使该模型可能更好地拟合大多数数据点。

代码时间!

有了我们对线性回归的所有知识,我们现在可以使用 Python 自行实现它!

开始使用

对于本教程,你需要:

  • Python(版本 3.7 或更高)— 推荐有基础经验

(安装教程:www.tutorialspoint.com/how-to-install-python-in-windows

然后,在你的终端中使用pip安装三个包:

  • pip install notebook

  • pip install numpy

  • pip install matplotlib

在终端中运行jupyter notebook。你的默认网页浏览器会打开一个选项卡,在其中你会看到文件资源管理器。简单地进入你希望创建程序的目录,然后创建一个 Python 3 Notebook(在右上角选择new)。你现在应该会看到以下界面:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由作者提供。

你可以通过点击“Untitled”来重命名文件。

初始数据

我们将开始导入 Python 库 NumPy,它在对数字数组进行数学操作时非常有用。然后,我们将定义点的 NumPy 数组,并将斜率和截距变量初始化为 0。看这些点,很容易推断出这些点的最适合的线是y=1x+0。我们使用这样的可预测值,以便对模型进行基准测试。

# Importing numpy for number processing
import numpy as np

# Define the data points as a matrix, where each row represents a data point
# and each column represents a variable
points = np.array([(1, 1), (2, 2), (3, 3)])

# Defining initial values for the slope and y-intercept of the line
slope = 0
intercept = 0

alt + enter来创建一个新单元格。

线性函数

由于我们在执行线性回归,拥有一个可以评估线性函数的函数是有用的:

def result_of_function(independent_variable, slope, intercept):
    """
    Function to model y=mx+b

    :param slope: the slope of the linear function (m)
    :param intercept: the y-intercept of the linear function (b)
    :param independent_variable: the independent variable (x value) being inputted into the linear function (x)

    :returns: the value of the dependent variable of the function (y)
    """
    return independent_variable * slope + intercept

alt + enter来创建一个新单元格。

成本函数

包含一个成本函数来衡量我们回归的有效性也会有帮助:

def cost_function(x, y, slope, intercept):
    """
    Calculate the mean squared error of a linear function with given parameters.

    :param x: The independent variable (x-values) of the data points.
    :param y: The dependent variable (y-values) of the data points.
    :param slope: The slope of the linear function.
    :param intercept: The y-intercept of the linear function.

    :returns: The mean squared error of the linear function.
    """

    # Predict the y-values using the given slope and intercept
    y_preds = result_of_function(x, slope, intercept)

    # Calculate the squared errors between the predicted and actual y-values
    squared_errors = (y_preds - y)**2

    # Return the mean of the squared errors
    return squared_errors.mean()

alt + enter来创建一个新单元格。

梯度下降

我们将开始定义我们的输入和输出值(x 坐标和 y 坐标):

# Define the input and output data
X = np.array(points[:, 0])
Y = np.array(points[:, 1])

最后,我们将实现梯度下降:

alpha = 0.01

# Iterate for a 1000 of epochs
for i in range(1000):
  # Calculate the gradients of J with respect to the slope and intercept
  grad_slope = -2 * ((Y - result_of_function(X, slope, intercept)) * X).mean()
  grad_intercept = -2 * ((Y - result_of_function(X, slope, intercept))).mean()

  # Update m and b using the gradients and the learning rate
  slope -= alpha * grad_slope
  intercept -= alpha * grad_intercept

  print(cost_function(X, Y, slope, intercept))

# Print the final values of m and b
print(f'Final values: slope = {slope}, intercept = {intercept}')

如果这不太有意义,别担心——梯度下降过于复杂,无法在本文中深入解释,所以请留意一篇解释梯度下降的文章,特别是对于不懂微积分的读者!

但你理解周期是很重要的。也就是说,这一行:

for i in range(1000):
  # code

一个周期基本上是梯度下降程序的一次迭代。在每次迭代中,线性函数的斜率和截距会使用梯度下降计算中的数学公式进行调整。运行的周期越多,值调整和微调得越多。

运行此程序后,你应该得到以下斜率和截距值:

Final values: slope = 0.98539125301466, intercept = 0.033209115678908344

如果我们运行更多周期,这些值会更接近 1 和 0。然而,这些值已经非常接近预期结果。

从图形上看,结果如下:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由作者提供。

结果分析

最终的均方误差(MSE)是0.00015821003618271694——这是一个极低的值。然而,如果我们将 MSE 图形化显示每个周期(或迭代),我们将得到以下图表:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由作者提供。

这些似乎是非常、非常小的收益。实际上,在第 25 个周期左右,MSE 似乎完全没有变化!让我们从不同的角度看这个图,省略前 50 个周期:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由作者提供。

看似直线的并非直线——从第 50 个 epoch 到第 1000 个 epoch,MSE 几乎减小了 100 倍。你可能会问——MSE 约 0.015 不是已经够低了吗?让我们尝试再次运行梯度下降,但这次只用 50 个 epochs:

Final values: slope = 0.8539016923117737, intercept = 0.32575579906831564

接近,但还不够接近:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图片。

相反,让我们用 100,000 个 epochs 运行梯度下降:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图片。

完美!似乎运行 100,000 个 epochs 的模型给出了几乎完美的结果。虽然用更多的 epochs 运行线性回归可以提高模型的准确性,但重要的是要考虑准确性和时间之间的平衡。一般来说,您应该使用足够的 epochs 来拟合模型的数据,但不要使用过多,以至于模型训练的时间不必要地延长。通常在各种模型中使用一种叫做早期停止的技术,当模型达到一定的准确性时会自动停止。这允许模型尽可能快地训练,但仍确保一定的准确性。

应用线性回归

最后但同样重要的是,是时候将我们的线性回归知识应用到实际数据上了!

查找数据集

让我们从寻找数据集开始。Kaggle是一个很好的资源,可以找到高质量且在结构或主题上各异的免费数据集。对于这个小项目,我选择使用钻石数据分析数据集,以便开发钻石克拉(自变量)与其价格(因变量)之间的线性关系。

选择数据集

通常,在选择或构建用于线性回归的数据集时,需要考虑以下因素:

  1. 自变量和因变量之间的强线性相关性——如果这些变量似乎没有相关性,或其相关性是非线性的,可能需要选择不同的回归方法。

  2. 异常值——一个好的数据集应该相对没有异常值,因为它们会严重影响回归的性能。

  3. 适用性——数据集必须与您试图解决的问题相关。例如,如果您想根据房屋的平方英尺预测纽约市的房价,那么用蒙大拿乡村农场的数据来训练模型就不合适。

下载数据集

在 Kaggle 上下载数据集非常直观。点击屏幕右上角的黑色下载按钮,将.zip文件保存到计算机上。然后解压缩文件,将其中的.csv文件移动到与 Jupyter Notebook 文件相同的目录中。

实施

现在,剩下的就是使用数据进行线性回归模型。再次地,在这个实例中,我们将根据钻石的克拉数预测价格。注释掉以下行:

X = np.array(points[:, 0])
Y = np.array(points[:, 1])

用以下内容替换它们:

diamond_data = np.genfromtxt('diamonds.csv', delimiter=',')
Y = diamond_data[1:][:, 7] # Costs of the diamonds
X = diamond_data[1:][:, 1] # Carats of the diamonds

运行程序(1000 个周期)后的输出结果是:

Final values: slope = 7756.425617968576, intercept = -2256.3605800455275

最终的 MSE 是:

2397955.0500126793

哇!这是一个极其庞大的数字。诚然,我们使用的钻石数据集存在缺陷。图上的线性回归是:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由作者提供。

以 1 克拉钻石的不同价格为例,它的价格范围可以从 ~$1000 到近$20,000!这是一个典型的例子,说明这个数据集在两个变量之间的线性关系不足。在这种情况下,钻石的切工、颜色和清晰度也都对价格产生了重大影响。同时,需要考虑的是,现实世界的数据中,MSE 为 0 几乎是不可能的。现实世界现象受到众多因素的影响,捕捉所有这些因素并在回归模型中反映出来几乎是不可能的。留给读者作为练习的是探索 Kaggle 上更强的两个给定变量之间的线性相关性数据集。

测试

让我们来测试一下我们的模型。根据CreditDonkey,1 克拉钻石的最佳价值在$4500 到$6000 之间。使用以下code

carat = 1
function_result = result_of_function(carat, slope, intercept)

print(f"A {carat}-carat diamond will cost: ${round(function_result, 2)}")

模型输出结果:

A 1-carat diamond will cost: $5488.47

成功!

结论

总结一下——线性回归是一种统计方法,用于理解两个线性相关变量之间的关系。这是通过将一条形式为y=mx+b的直线拟合到提供的自变量和因变量上来完成的。通过使用称为最小二乘法的方法,可以找到最适合的直线,该方法最小化每个点到其对应直线上的点的平方偏差之和。最小二乘法可以通过矩阵运算和梯度下降来实现,本文重点介绍了梯度下降的应用。使用 MSE 成本函数(即平方误差的均值)来确定模型的准确性。通过最小化 MSE,我们可以优化模型并提高其准确性。

我给你留下了一个令人满意的 GIF,展示了模型逐渐收敛到最适合的直线:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由作者提供。

谢谢你的阅读!

机器学习不仅仅预测未来,它还积极地创造未来

原文:towardsdatascience.com/machine-learning-does-not-only-predict-the-future-it-actively-creates-it-1615895c80a9

关于位置偏差的入门(以及它为何重要)

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Samuel Flender

·发表于Towards Data Science ·4 min read·2023 年 1 月 11 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由 Stable Diffusion 生成

标准的机器学习课程教导我们,机器学习模型从过去存在的模式中学习,以预测未来。

这是一个很好的简化,但一旦这些模型的预测被用于生产环境中,情况会发生戏剧性的变化,因为它们会产生反馈循环:现在,模型的预测本身正在影响模型试图从中学习的世界。我们的模型不再仅仅是预测未来,它们实际上是在创造未来。

其中一个反馈循环是位置偏差,这一现象已经在排名模型中被观察到,这些模型支持搜索引擎、推荐系统、社交媒体信息流和广告排名器等。

什么是位置偏差?

位置偏差意味着排名最高的项目(Netflix 上的视频、Google 上的页面、Amazon 上的产品、Facebook 上的帖子或 Twitter 上的推文)之所以创造了最多的互动,不是因为它们实际上是用户最需要的内容,而仅仅是因为它们的排名最高。

这种偏差的表现形式是因为排名模型非常好,以至于用户开始盲目相信排名最高的项目,而不再进一步查看(“盲目信任偏差”),或者用户根本没有考虑其他可能更好的项目,因为它们的排名太低,用户甚至没有注意到(“展示偏差”)。

为什么这是一个问题?

让我们回到基础。排名模型的目标是展示最相关的内容,按参与概率的顺序排序。这些模型是基于隐式用户数据进行训练的:每次用户点击搜索结果页面上的一个项目或参与界面时,我们将该点击作为下一个模型训练迭代中的正标签。

如果用户只是因为内容的排名而非其相关性而开始与内容互动,我们的训练数据就会被污染:模型不仅仅是从用户真正想要的东西中学习,而是从自身过去的预测中学习。随着时间的推移,预测会变得静态,缺乏多样性。结果,用户可能会感到厌倦或烦恼,并最终转向其他地方。

另一个位置偏差的问题是离线测试变得不可靠。根据定义,位置偏倚的用户参与数据总是会偏向现有的生产模型,因为这是生成用户看到的排名的模型。一个实际上更好的新模型在离线测试中可能看起来更差,可能会被过早地丢弃。只有在线测试才能揭示真相。

我们如何减轻位置偏差?

模型从数据中学习,因此为了去偏模型,我们需要去偏训练数据。正如Joachims et al(2016)所示,这可以通过根据位置偏差的倒数加权每个训练样本来实现,为低偏差的样本赋予更多权重,为高偏差的样本赋予较少的权重。直观地,这很有意义:点击排名第一的项目(具有高位置偏差)可能比点击第十个项目(具有低位置偏差)信息量少。

因此,减轻位置偏差的问题归结为测量它。我们如何做到这一点?

一种方法是结果随机化:对于服务人群中的一个小子集,简单地随机重新排序前 N 项,然后测量在该人群中排名变化所引起的参与度变化。这种方法有效,但成本较高:随机搜索结果或推荐,尤其是对于较大的 N,会导致用户体验较差,从而影响用户留存率和商业收入。

因此,更好的替代方法可能是干预采集,由Argawal et al(2018)在全文档搜索的背景下提出,同时由Aslanyan et al(2019)在电子商务搜索的背景下提出。关键思想是,成熟排名系统中记录的用户参与数据已经包含了来自多个不同排名模型的排名,例如来自历史 A/B 测试或仅仅来自时间上推出的不同版本的生产模型。这种历史多样性在排名中创造了固有的随机性,我们可以“采集”这些数据来估计位置偏差,而无需任何昂贵的干预。

最后,还有一个更简单的想法,即谷歌的“规则 36”。他们建议在训练模型时将排名本身作为另一个特征添加,然后在推断时将该特征设置为默认值(例如 -1)。直觉是,通过提前将所有信息提供给模型,它会在后台隐式地学习参与模型和位置偏见模型。无需额外步骤。

最终思考

让我们回顾一下。位置偏见是一个在整个行业中都被观察到的真实问题。它之所以成问题,是因为它可能会限制排名模型的多样性。但我们可以通过使用偏见估计对训练数据进行去偏差来减轻它,这些偏见估计可以通过结果随机化或干预采集获得。另一种减轻策略是直接将排名作为模型特征,并让模型隐式地学习偏见,无需额外步骤。

从整体上考虑,位置偏见的存在确实有些讽刺。如果我们不断改进我们的排名模型,这些改进可能会导致越来越多的用户盲目相信排名最高的结果,从而增强位置偏见,并最终降低我们的模型效果。除非我们采取有意识的步骤来监测和减轻位置偏见,否则任何模型改进最终可能会变得适得其反。

[## 现实中的机器学习:真实世界 ML 应用的设计与操作

这本书是什么?机器学习最奇特的事情之一是学术 ML 研究的二分法…

samflender.gumroad.com](https://samflender.gumroad.com/l/mlontheground?source=post_page-----1615895c80a9--------------------------------)

机器学习工程师——他们实际上做什么?

原文:towardsdatascience.com/machine-learning-engineers-what-do-they-actually-do-e666081c3181?source=collection_archive---------7-----------------------#2023-08-09

“机器学习工程师”对我们领域来说意味着什么新事物吗?如果是,那么是什么呢?

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Stephanie Kirmer

·

关注 发表在 Towards Data Science ·4 分钟阅读·2023 年 8 月 9 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片来源:Letizia Bordoni,来自 Unsplash

这个标题当然是一个圈套问题。就像之前的数据科学家一样,机器学习工程师这个头衔正在成为我们职业市场的一个趋势,但对这个头衔的含义或它应该包含的职能和技能没有共识。我想新进入数据科学/机器学习领域的从业者会发现这很难解读。(即便是有经验的人也会如此!)所以,让我们谈谈根据谁在说话,它可能意味着什么。

当我前几天与朋友讨论这个问题时,我用“MACHINE LEARNING 工程师”或机器学习 ENGINEER 来描述它。基本上,根据我看到的情况,这些头衔下的角色和期望要么是:

  • A. 期望具备广泛的软件工程技能,并且对机器学习有一定的经验或至少熟悉,或者

  • B. 对机器学习经验有较高期望,通常包括深度学习或生成式 AI,并且他们希望你能够在需要时编写一个函数。

以前这一类人可能只是“软件工程师”,而后一类人则舒适地归入“数据科学家”之下,回到我刚开始职业生涯的时候(尽管当时生成式 AI 确实不是游戏的一部分)。

这反映了我们职业领域更广泛的发展中的一个有趣模式。我们一直没有很好地将我们领域的角色划分为明确的子类别,以清晰界定角色的技能集(或职责)。这是一个快速发展、不断变化的年轻领域,所以这并不令人惊讶!这一直以来都是数据科学家这个头衔的特点,它本质上是“比数据分析师更具技术技能”的一个标识。曾经有些人把数据科学家称为能够处理非结构化或无序数据的人,而从我看来,这个定义因素已经不再存在。

我强烈怀疑 MLE 的增长是因为招聘 SWE 类型的人才时,雇主们对找不到懂得机器学习模型的人感到不满,而招聘数据科学家时,他们得到的是分析专家而不是具有机器学习技能的建模师。他们从两个方向交汇,形成了一个新头衔,在这个头衔下,对于每项技能的重视程度存在内部分歧。因此,现在我们有了一个新的领域需要思考。

虽然这一领域的细分可能非常自然,作为对这种困难的回应,我想指出这对候选人和领域的意义。每当发生新的分化,职业路径有了新的可能分支时,这两个方向会被赋予不同的地位和特权,最常通过每个方向的薪资差异来体现。现在,随着数据科学领域的正规化及更多教育机会的出现,人们进入这一职业的途径变得更加容易。这包括在更广泛社会中处于劣势或边缘化的人。我相信我们面临着数据科学家“粉红领”效应的风险。

(简而言之,粉色领带效应是指当女性在某一领域中的工作比例增加时,那些以她们为主的角色的薪资和社会地位系统性地降低。兽医学是一个常见的例子。情况也会相反,例如 1960 年代和 1970 年代初,女性在计算机编程领域占主导地位,当男性在该领域的代表性增加时,他们的薪资和声望也随之上升。)

这真的发生了吗?我不完全确定。我仅从像 Harnham 和 Burtch Works 这样的行业报告以及浏览 LinkedIn 等地方的招聘信息中看到一些轶事证据,这些证据表明数据科学家和机器学习工程师之间的薪资差距似乎正在出现。我确实比五年前遇到更多年轻女性、有色人种及不同性别认同和性取向的人在数据科学家角色中。

我非常希望研究人员能够发现这一薪资变化是否在统计上显著,如果显著的话,是否与我怀疑的员工人口统计变化相对应。

无论如何,对招聘领域的挑战是,不让更具声望的、更“技术性的”职位(例如现在的机器学习工程师)被男性和具有社会优势的人主导,同时相应地确保数据科学家职位不会成为一个较低地位的变种,导致其他人无论能力如何都被排挤。给这些职位支付符合你业务价值的薪资,但不要让这影响你考虑或设想每个角色中的人员组成。这是我们在不断发展的游戏阶段中能做到的最低限度。

你可以在 www.stephaniekirmer.com上找到更多我的作品。

使用机器学习进行柔术

原文:towardsdatascience.com/machine-learning-for-jiu-jitsu-94a0b44f57ab

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

照片由 Kampus Production 提供,来自 Pexels: https://www.pexels.com/photo/a-judoka-throwing-an-opponent-to-the-ground-6765024/

使用 mediapipe 的姿态估计来跟踪柔术动作

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Lucas Soares

·发表于Towards Data Science ·阅读时间 18 分钟·2023 年 3 月 13 日

姿态跟踪以提升柔术水平

巴西柔术是一种最近因其在实际战斗中的有效性和适用性而变得非常受欢迎的武术。

我已经练习巴西柔术超过 10 年,并决定将我对武术和机器学习的兴趣结合起来,提出一个位于这两个非常有趣领域交汇处的项目。

因此,我转向了姿态估计,作为一种有前景的技术,用于作为辅助工具帮助我在柔术中的发展。

在本文中,我想与大家分享如何使用姿态跟踪来增强战斗动作中的反馈纠正。

如果你更喜欢视频,可以在这里查看我关于此主题的 YouTube 视频:

什么是姿态跟踪?

姿态跟踪是利用计算机视觉技术实时检测和跟踪人体运动的过程。它涉及使用算法捕捉和解释各种身体部位(如手臂、腿部和躯干)的运动。

这一技术对于分析运动中的身体动作具有相关性,因为它允许教练和运动员识别和纠正可能对表现产生负面影响或导致受伤的运动模式。

通过提供实时反馈,运动员可以调整他们的技术,从而提高表现并减少受伤风险。此外,这项技术还可以用于将动作与顶级运动员的动作进行比较,例如,帮助初学者识别需要改进的领域,并相应地完善他们的技术。

什么是 Jiu Jitsu?

Jiu Jitsu 是一种以通过组合使用固定技和提交保持(如关节锁和窒息)来制服对手的武术。

Jiu Jitsu 专注于抓取和地面战斗技巧。它最初在日本开发,后来在巴西进行了修改和推广。但现在,由于其在美国特别是普及率的增加,它已经传播到全球。

基本原则是,较小、较弱的人可以通过使用杠杆和技巧来防御较大、较强的对手。练习者的目标是控制对手的身体,并将自己置于一个主导位置,以便执行诸如窒息、关节锁和投掷等技巧。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由 Timoth Eberly 提供,链接: https://unsplash.com/photos/7MRajrPiTqw*

Jiu Jitsu 现在是一项在全球范围内流行的运动和自我防卫系统。它要求身体和心理上的纪律,以及学习和适应的意愿。

研究还发现,它具有许多好处,包括 改善身体健康和心理敏锐性、增加自信和自尊、以及缓解压力

为什么选择 Jiu Jitsu 的姿势追踪

对技巧的高度重视使得这门武术相当独特,在 Jiu Jitsu 俱乐部的环境中,通常是黑带教练的职责来给学生反馈,评价他们对不同技巧的执行是否得当。

然而,人们常常希望学习,但要么无法接触到专家,要么班级人数过多,导致授课者难以提供具体和个人化的反馈,无法确定学生是否正确执行了动作。

在这种反馈的空白中,我认为像姿势追踪这样的工具可以极大地惠及武术世界,尤其是 Jiu Jitsu(尽管可以对柔道、摔跤和以打击为基础的武术做同样的论证),因为它们可以无缝集成到智能手机中,只需运动员在尝试改进的动作时拍摄自己。

这种反馈的形式需要进行开发,本文旨在提供有关这种基于机器学习的反馈系统如何帮助学生提高运动基础动作的指导。

我为什么要这样做?

好的,故事是这样的。

通常,当你发展你的柔术技能时,你最终会落入两个类别之一:下位选手或上位选手。这意味着你是否倾向于从下方使用“防守”(指使用双腿对对手进行攻击)进行比赛,还是从上方先将对手摔倒,然后继续穿过对方的双腿(通常)达到如骑乘对手或抓住对手背部等占据主导地位的目标。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由诺兰·肯特提供,来源于 unsplash.com/photos/x_V62hOwnDk?utm_source=unsplash&utm_medium=referral&utm_content=creditShareLink

这种二元性显然是人为的,通常,大多数经验丰富的选手能很好地掌握两种位置。

然而,许多人在柔术旅程的开始时倾向于偏好某些技巧,这可能会严重影响他们在其他领域的进步,如果他们不断重复同样的策略的话。

在某种程度上,这就是我所经历的,我曾经作为防守选手打斗很多,这主要是由于巴西的主流文化,鼓励从膝盖开始摔跤以避免受伤,或者因为比赛垫的空间不像美国大中学体育馆中的摔跤垫那么大。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图片。比赛中我进行防守的照片。

这种主动坐下来并从背后打斗的习惯,未能让对手参与站立战斗,对我的武术发展产生了负面影响,因为随着我在柔术中变得越来越好,我意识到阻碍我的一个因素是我缺乏高水平的摔倒对手的知识。

这激发了我从站立位置开始更多地进行训练,于是我在获得棕带的几年后,开始学习和练习摔跤和柔道。

在过去的 2 年里,我主要是一个上位选手,确实在将对手摔倒在地的能力上有了很大提升。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图片。我的摔跤之旅。

不过,例如在柔道中有一些基础动作非常难以掌握,因为我不认识任何柔道专家,也不住在任何高水平的柔道或摔跤馆附近,我意识到我需要另一种方式来提高某些基础动作,特别是像“内腿技”和其他基于髋部的摔法的髋部灵活性。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

照片由 Kampus Production 提供,来源于: https://www.pexels.com/photo/a-judoka-throwing-an-opponent-to-the-ground-6765024/

机器学习的作用

好的,为了提高我执行像内股这样的柔道投掷技术的能力,我制定了一个“书呆子”的计划:我要使用机器学习(我知道,这个计划真是太具体了)。

我决定要调查是否可以使用姿态追踪来获取关于如何纠正脚的速度和方向以及执行这些动作的其他方面的见解。

那么让我们来看看我是怎么做的。

使用姿态追踪生成柔术见解的步骤

整体计划是这样的:

1. 找到一个包含我想要模仿动作的视频参考

2. 录制自己多次执行该动作

3. 使用姿态追踪和 Python 可视化生成见解。

为了做到这一点,我需要一个顶级练习者执行我试图学习的动作的参考视频。对于内股,我找到了一段奥林匹克级别选手在墙边进行热身的技术视频,这直接与我想要学习的内容相关:

然后我开始录制自己执行这些动作的视频,至少是我正在积极学习的某些动作。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由作者提供。

拥有了参考视频,并且录制了一些自己的镜头之后,我准备尝试一些有趣的机器学习内容。

姿态追踪来追踪身体关节

对于姿态追踪,我使用了一个叫做mediapipe的工具,这是谷歌的开源项目,旨在促进机器学习在实时和流媒体中的应用。

## GitHub - google/mediapipe: 跨平台、可定制的机器学习解决方案,适用于实时和流媒体。

MediaPipe 提供跨平台、可定制的机器学习解决方案,适用于实时和流媒体。端到端加速……

github.com

这个选项的易用性让我很兴奋,迫不及待地想要尝试。

本质上,我做了以下工作:

1. 首先,我创建了一些叠加姿态估计的视频

2. 创建了实时绘图,展示脚的 x、y 和 z 坐标,以说明动作的主要方面

3. 创建了表示某个动作在特定时间执行的轨迹

4. 将我的尝试所产生的轨迹与专家视频生成的参考轨迹进行比较

初步结果

1. 姿态估计叠加

我写了这段代码来创建模型估计身体关节位置的视频

并将其叠加在实际镜头中,以展示模型的鲁棒性。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由作者提供

是的,是的,我知道,我看起来并不完全像顶级选手。但给我一点时间,我的柔道技能还在建设中!

我用来做这个的代码是:

from base64 import b64encode
import cv2
import mediapipe as mp
import matplotlib.pyplot as plt
import seaborn as sns
sns.set()
import numpy as np
from natsort import natsorted
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.animation import FuncAnimation
from IPython.display import clear_output
%matplotlib inline
import pandas as pd
import plotly.graph_objects as go
import plotly.express as px
from IPython.display import HTML, display
import ipywidgets as widgets
from typing import List # I don't think I need this!

# Custom imports
from pose_tracking_utils import *

mp_drawing = mp.solutions.drawing_utils
mp_drawing_styles = mp.solutions.drawing_styles
mp_pose = mp.solutions.pose

def create_pose_tracking_video(video_path):
    # For webcam input:
    cap = cv2.VideoCapture(video_path)
    frame_width = int(cap.get(3))
    frame_height = int(cap.get(4))
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    output_path = pathlib.Path(video_path).stem + "_pose.mp4" 
    out = cv2.VideoWriter(output_path, fourcc, 30.0, (frame_width, frame_height))
    with mp_pose.Pose(min_detection_confidence=0.5,
                      min_tracking_confidence=0.5) as pose:
        while cap.isOpened():
            success, image = cap.read()
            if not success:
                print("Ignoring empty camera frame.")
                break
            # To improve performance, optinally mark the iamge as 
            # not writeable to pass by reference.
            image.flags.writeable = False
            image= cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            results = pose.process(image)
            # Draw the annotation on the image.
            image.flags.writeable = True
            image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
            mp_drawing.draw_landmarks(image, results.pose_landmarks,
                                      mp_pose.POSE_CONNECTIONS,
            landmark_drawing_spec=mp_drawing_styles.get_default_pose_landmarks_style())

            # Flip the image horizontally for a self-view display.
            out.write(cv2.flip(image, 1))
            if cv2.waitKey(5) & 0xFF == 27:
                break

    cap.release()
    out.release()
    print("Pose video created!")

    return output_path

这基本上利用了 mediapipe 包来生成一个可视化,它检测关键点并将其覆盖在视频画面上。

2. 脚部的 X、Y 和 Z 坐标的实时图表

VIDEO_PATH = "./videos/clip_training_session_1.mp4"
# Initialize MediaPipe Pose model
body_part_index = 32
pose = mp_pose.Pose(static_image_mode=False, min_detection_confidence=0.5, min_tracking_confidence=0.5)

# Initialize OpenCV VideoCapture object to capture video from the camera
cap = cv2.VideoCapture(VIDEO_PATH)

# Create an empty list to store the trace of the right elbow
trace = []

# Create empty lists to store the x, y, z coordinates of the right elbow
x_vals = []
y_vals = []
z_vals = []

# Create a Matplotlib figure and subplot for the real-time updating plot
# fig, ax = plt.subplots()
# plt.title('Time Lapse of the X Coordinate')
# plt.xlabel('Frames')
# plt.ylabel('Coordinate Value')
# plt.xlim(0,1)
# plt.ylim(0,1)
# plt.ion()
# plt.show()
frame_num = 0

while True:
    # Read a frame from the video capture
    success, image = cap.read()
    if not success:
        break
    # Convert the frame to RGB format
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Process the frame with MediaPipe Pose model
    results = pose.process(image)

    # Check if any body parts are detected

    if results.pose_landmarks:
        # Get the x,y,z coordinates of the right elbow
        x, y, z = results.pose_landmarks.landmark[body_part_index].x, results.pose_landmarks.landmark[body_part_index].y, results.pose_landmarks.landmark[body_part_index].z

        # Append the x, y, z values to the corresponding lists
        x_vals.append(x)
        y_vals.append(y)
        z_vals.append(z)

        # # Add the (x, y) coordinates to the trace list
        trace.append((int(x * image.shape[1]), int(y * image.shape[0])))

        # Draw the trace on the image
        for i in range(len(trace)-1):
            cv2.line(image, trace[i], trace[i+1], (255, 0, 0), thickness=2)

        plt.title('Time Lapse of the Y Coordinate')
        plt.xlabel('Frames')
        plt.ylabel('Coordinate Value')
        plt.xlim(0,len(pose_coords))
        plt.ylim(0,1)
        plt.plot(y_vals);
        # Clear the plot and update with the new x, y, z coordinate values
        #ax.clear()
        # ax.plot(range(0, frame_num + 1), x_vals, 'r.', label='x')
        # ax.plot(range(0, frame_num + 1), y_vals, 'g.', label='y')
        # ax.plot(range(0, frame_num + 1), z_vals, 'b.', label='z')
        # ax.legend(loc='upper left')
        # plt.draw()
        plt.pause(0.00000000001)
        clear_output(wait=True)
        frame_num += 1

    # Convert the image back to BGR format for display
    image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

    # Display the image
    cv2.imshow('Pose Tracking', image)

    # Wait for user input to exit
    if cv2.waitKey(1) & 0xFF == ord('q'):
        break

# Release the video capture, close all windows, and clear the plot
cap.release()
cv2.destroyAllWindows()
plt.close()

然后,我生成了一个包含 x、y、z 坐标时间线的图表:

plt.figure(figsize=(15,7))
plt.subplot(3,1,1)
plt.title('Time Lapse of the x Coordinate')
plt.xlabel('Frames')
plt.ylabel('Coordinate Value')
plt.xlim(0,len(pose_coords))
plt.ylim(0,1)
plt.plot(x_vals)

plt.subplot(3,1,2)
plt.title('Time Lapse of the y Coordinate')
plt.xlabel('Frames')
plt.ylabel('Coordinate Value')
plt.xlim(0,len(pose_coords))
plt.ylim(0,1.1)
plt.plot(y_vals)

plt.subplot(3,1,3)
plt.title('Time Lapse of the z Coordinate')
plt.xlabel('Frames')
plt.ylabel('Coordinate Value')
plt.xlim(0,len(pose_coords))
plt.ylim(-1,1)
plt.plot(z_vals)

plt.tight_layout();

这个想法是为了对诸如执行动作时脚的位置方向等细节进行细致控制。

现在我对模型能够正确捕捉我的身体姿势感到自信,我创建了一些相关身体关节的轨迹可视化,比如脚部(在执行摔跤技术时非常重要)。

3. 创建运动轨迹

为了了解动作的执行情况,我制作了一个可视化,表示从身体部位的角度(在这种情况下是脚)执行该动作的过程:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图片

我为我的训练课程和包含我试图模仿的动作的参考视频做了这个。

这里需要注意的是,进行此操作存在许多问题,包括相机的分辨率、动作执行的距离以及每个视频的帧率。然而,我只是绕过了这些问题,创建了一个花哨的图表(哈哈)。

这是这种方法的代码:

def create_joint_trace_video(video_path,body_part_index=32, color_rgb=(255,0,0)):
    """
    This function creates a trace of the body part being tracked.
    body_part_index: The index of the body part being tracked.
    video_path: The path to the video being analysed.
    """
    # Initialize MediaPipe Pose modelpose = mp_pose.Pose(static_image_mode=False, min_detection_confidence=0.5, min_tracking_confidence=0.5)

    # Initialize OpenCV VideoCapture object to capture video from the camera
    cap = cv2.VideoCapture(video_path)
    frame_width = int(cap.get(3))
    frame_height = int(cap.get(4))
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    output_path = pathlib.Path(video_path).stem + "_trace.mp4" 
    out = cv2.VideoWriter(output_path, fourcc, 30.0, (frame_width, frame_height))

    # Create an empty list to store the trace of the body part being tracked
    trace = []

    with mp_pose.Pose(min_detection_confidence=0.5,
                        min_tracking_confidence=0.5) as pose:
        while cap.isOpened():
            success, image = cap.read()
            if not success:
                print("Ignoring empty camera frame.")
                break

            # Convert the frame to RGB format
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

            # Process the frame with MediaPipe Pose model
            results = pose.process(image)

            # Check if any body parts are detected
            if results.pose_landmarks:
                # Get the x,y coordinates of the body part being tracked (in this case, the right elbow)
                x, y = int(results.pose_landmarks.landmark[body_part_index].x * image.shape[1]), int(results.pose_landmarks.landmark[body_part_index].y * image.shape[0])

                # Add the coordinates to the trace list
                trace.append((x, y))

                # Draw the trace on the image
                for i in range(len(trace)-1):
                    cv2.line(image, trace[i], trace[i+1], color_rgb, thickness=2)

            # Convert the image back to BGR format for display
            image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

            # Display the image
            out.write(image)
            if cv2.waitKey(5) & 0xFF == 27:
                break

    cap.release()
    out.release()
    print("Joint Trace video created!")

在这里,我简单地处理每一帧,就像之前制作姿势视频一样,然而,我还将特定身体部位的 x 和 y 坐标追加到一个我称之为trace的列表中,这个列表用于生成伴随身体部位的轨迹线。

4. 比较轨迹

拥有这些能力后,我终于可以进入从这种方法中获取见解的阶段。

为了做到这一点,我需要一种比较这些轨迹的方法,以生成一些视觉丰富的反馈,这可以帮助我理解自己动作执行的不足与顶级运动员的表现相比如何。

现在,没有背景视频的实际轨迹已被绘制成图表。

def get_joint_trace_data(video_path, body_part_index,xmin=300,xmax=1000,
                             ymin=200,ymax=800):
    """
    Creates a graph with the tracing of a particular body part,
    while executing a certain movement.
    """
    cap = cv2.VideoCapture(video_path)
    frame_width = int(cap.get(3))
    frame_height = int(cap.get(4))

    # Create an empty list to store the trace of the body part being tracked
    trace = []
    i = 0
    with mp_pose.Pose(min_detection_confidence=0.5,
                    min_tracking_confidence=0.5) as pose:
        while cap.isOpened():
            success, image = cap.read()
            if not success:
                print("Ignoring empty camera frame.")
                break

            # Convert the frame to RGB format
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

            # Process the frame with MediaPipe Pose model
            results = pose.process(image)

            # Check if any body parts are detected
            if results.pose_landmarks:
                # Get the x,y coordinates of the body part being tracked (in this case, the right elbow)
                x, y = int(results.pose_landmarks.landmark[body_part_index].x * image.shape[1]), int(results.pose_landmarks.landmark[body_part_index].y * image.shape[0])

                # Add the coordinates to the trace list
                trace.append((x, y))

                # Plot the trace on the graph
                fig, ax = plt.subplots()
                #ax.imshow(image)
                ax.set_xlim(xmin,xmax)
                ax.set_ylim(ymin,ymax)
                ax.invert_yaxis()
                ax.plot(np.array(trace)[:, 0], np.array(trace)[:, 1], color='r')
                # plt.savefig(f'joint_trace{i}.png')
                # plt.close()
                i+=1
                plt.pause(0.00000000001)
                clear_output(wait=True)
                # Display the graph
                #plt.show()

            if cv2.waitKey(5) & 0xFF == 27:
                break

        cap.release()

        return trace

video_path = "./videos/clip_training_session_2.mp4"
body_part_index = 31
foot_trace = get_joint_trace_data(video_path, body_part_index)

video_path = "./videos/uchimata_wall.mp4"
body_part_index = 31
foot_trace_reference = get_joint_trace_data(video_path, body_part_index,xmin=0,ymin=0,xmax=1300)

foot_trace_clip = foot_trace[:len(foot_trace_reference)]
plt.subplot(1,2,1)
plt.plot(np.array(foot_trace_clip)[:, 0], np.array(foot_trace_clip)[:, 1], color='r')
plt.gca().invert_yaxis();

plt.subplot(1,2,2)
plt.plot(np.array(foot_trace_reference)[:, 0], np.array(foot_trace_reference)[:, 1], color='g')
plt.gca().invert_yaxis();

好的,通过这些,我们开始更清楚地看到在不同背景下脚部移动的特征形状之间的差异。

首先,我们看到,虽然顶级运动员在转弯时做了一个比较直的步骤,用脚生成了一个几乎完整的半圆,而我则在初步步骤内有一个弯曲的外观,并且在将腿抬到空中时也没有生成半圆。

此外,虽然顶级运动员在抬起腿时会生成一个宽大的圆圈,而我则会生成一个浅圆圈,几乎像是一个椭圆。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图片,比较动作执行的轨迹

我发现这些初步结果相当不错,因为它们表明,尽管比较存在局限性,但通过观察这些轨迹,可以评估动作执行的特征形状之间的差异。

除此之外,我还想看看是否可以比较动作执行的速度,为此我可视化了身体关节的实时运动,将我和专家的图放在一起,看看我的时机偏差有多大。

这项分析的挑战在于,由于视频的速度各不相同且未对齐,我首先需要以有意义的方式对齐它们。

我不确定该使用哪种技术,但与我的朋友 Aaron(里斯本 Champalimaud 神经科学研究所的神经科学家)的对话让我有了一个选择:动态时间规整。

使用动态时间规整比较速度和时机

动态时间规整(DTW)是一种用于测量两个具有不同速度的时间序列之间相似度的技术。

基本思想是你有两个不同的时间序列,它们可能包含一些你希望分析的模式,因此你试图通过应用一些规则将它们对齐,从而计算两个序列之间的最佳匹配。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

两次步态序列,尽管速度各异,我们可以观察到四肢的轨迹非常相似;取自维基百科,参考(Olsen et al, 2017)

我在这篇文章中找到了对这个话题的很好的介绍:

## 动态时间规整

解释与代码实现

towardsdatascience.com

作者:Jeremy Zhang

为了使用动态时间规整,我做了以下工作:

1. 将值归一化到相同范围

2. 使用了 DTW 算法的 Python 实现。

from fastdtw import fastdtw
from scipy.spatial.distance import euclidean

max_x = max(max(foot_trace_clip, key=lambda x: x[0])[0], max(foot_trace_reference, key=lambda x: x[0])[0])
max_y = max(max(foot_trace_clip, key=lambda x: x[1])[1], max(foot_trace_reference, key=lambda x: x[1])[1])

foot_trace_clip_norm = [(x/max_x, y/max_y) for (x, y) in foot_trace_clip]
foot_trace_reference_norm = [(x/max_x, y/max_y) for (x, y) in foot_trace_reference]

distance, path = fastdtw(foot_trace_clip_norm, foot_trace_reference_norm, dist=euclidean)

我得到的输出是:

1. distance:两个时间序列向量之间的欧几里得距离

2. path:两个时间序列之间的映射,以嵌套的元组列表形式存在

现在,我可以使用存储在path变量中的输出,创建一个对齐了两个序列的图:

foot_trace_reference_norm_mapped = [foot_trace_reference_norm[path[i][1]] for i in range(len(path))]
foot_trace_clip_norm_mapped = [foot_trace_clip_norm[path[i][1]] for i in range(len(path))]

plt.subplot(1,2,1)
plt.plot(np.array(foot_trace_reference_norm_mapped)[:, 0], np.array(foot_trace_reference_norm_mapped)[:, 1], color='g')
plt.gca().invert_yaxis();

plt.subplot(1,2,2)
plt.plot(np.array(foot_trace_clip_norm_mapped)[:, 0], np.array(foot_trace_clip_norm_mapped)[:, 1], color='r')
plt.gca().invert_yaxis();
plt.show()

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由作者提供,使用 DTW 算法对齐的时间序列

现在,由于参考轨迹的数据不足,我不能说这个图比之前讨论的元素给了我更多的见解,然而,它确实有助于突出我之前提到的运动形状。

然而,作为未来的一个备注,我的想法是,如果可以满足某些条件来帮助使两个视频更一致,我希望有一个参考轨迹,我可以用来比较我的尝试轨迹,以便用于即时反馈。

我将使用 DTW 算法输出的欧几里得距离作为我的反馈指标,并且会有一个应用程序可以突出显示我是否接近或远离我尝试模仿的签名形状。

为了说明这一点,让我给你展示一个例子。

def find_individual_traces(trace,window_size=60, color_plot="r"):
    """
    Function that takes in a liste of tuples containing x,y coordinates
    and plots them as different clips with varying sizes to allow the user to find
    the point where a full repetition has been completed
    """

    clip_size = 0
    for i in range(len(trace)//window_size):
        plt.plot(np.array(trace[clip_size:clip_size+window_size])[:, 0], np.array(trace[clip_size:clip_size+window_size])[:, 1], color=color_plot)
        plt.gca().invert_yaxis()
        plt.title(f"Trace, clip size = {clip_size}")
        plt.show()
        clip_size+=window_size

def get_individual_traces(trace, clip_size):
    num_clips = len(trace)//clip_size
    trace_clips = []
    i = 0
    for clip in range(num_clips):
        trace_clips.append(trace[i:i+clip_size])
        i+=clip_size

    return trace_clips

find_individual_traces(foot_trace_clip_norm)

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由作者提供。由我执行的脚部动作的轨迹。

这里我展示了视频中的剪辑,我在每个单独的动作中执行。每个这些轨迹可以与类似获得的参考轨迹进行比较:

find_individual_traces(foot_trace_reference_norm, window_size=45,color_plot="g")

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由作者提供。由精英玩家执行的脚部动作的轨迹。

当我获得参考轨迹时,也会得到一些噪声信号,但我将使用第三个作为我的参考:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由作者提供

现在我可以循环遍历代表我实际动作的轨迹,并查看它们如何与在几次训练课程中得到的参考轨迹进行比较。

video_path = "./videos/clip_training_session_3.mp4"
body_part_index = 31
foot_trace_clip = get_joint_trace_data(video_path, body_part_index)

video_path = "./videos/uchimata_wall.mp4"
body_part_index = 31
foot_trace_reference = get_joint_trace_data(video_path, body_part_index,xmin=0,ymin=0,xmax=1300)

# Showing a plot with the tracings from the training session
plt.plot(np.array(foot_trace_clip)[:, 0], np.array(foot_trace_clip)[:, 1], color='r')
plt.gca().invert_yaxis();

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由作者提供。几次执行动作中,脚的 x,y 坐标的轨迹。

现在我从两个轨迹中获取标准化值。

max_x = max(max(foot_trace_clip, key=lambda x: x[0])[0], max(foot_trace_reference, key=lambda x: x[0])[0])
max_y = max(max(foot_trace_clip, key=lambda x: x[1])[1], max(foot_trace_reference, key=lambda x: x[1])[1])

foot_trace_clip_norm = [(x/max_x, y/max_y) for (x, y) in foot_trace_clip]
foot_trace_reference_norm = [(x/max_x, y/max_y) for (x, y) in foot_trace_reference]

我从训练剪辑和参考轨迹中获取轨迹,以帮助我设定目标。

剪辑大小是手动设置的。

traces = get_individual_traces(foot_trace_clip_norm, clip_size=67)
traces_ref = get_individual_traces(foot_trace_reference_norm, clip_size=60)

我展示了在经过经验观察手动分类为噪声后去除了一些轨迹的例子。

# Here I show an example trace from the new clip
index = 0
color_plot = "black"
plt.plot(np.array(traces[index])[:, 0], np.array(traces[index])[:, 1], color=color_plot)
plt.gca().invert_yaxis()
plt.title(f"Trace {index}")
plt.show()

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由作者提供

然后,我循环遍历轨迹,并将它们的得分与从精英玩家视频中获得的参考轨迹进行比较:

trace_ref = traces[2]
trace_scores = []

for trace in traces:
    distance, path = fastdtw(trace, trace_ref, dist=euclidean)
    trace_scores.append(distance)

plt.plot(trace_scores, color="black")
plt.title("Trace Scores with DTW")
plt.xlabel("Trace Index")
plt.ylabel("Euclidean Distance Score")
plt.show()

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由作者提供

现在,我注意到的第一个奇怪的现象是指标的上下波动,这只能通过一些获得的轨迹指向脚下落而非上升来解释。

然而,这个图表的有趣之处在于,轨迹的得分似乎甚至略有改善,并且至少保持在 20(在这种情况下是两序列之间的欧几里得距离的度量)。

尽管此时无法明确解释这些数字,但我发现像这样的处理方法可以转换为一个可衡量的指标,用于比较一个动作相对于另一个动作的质量,这一点相当有见地。

最终备注

未来,我希望研究如何更好地提取训练片段,以获得每次动作执行的完美对齐段,以便产生更一致的结果。

总的来说,我认为做这些实验相当有趣,因为它突出了这种技术在提供动作的详细评估方面的力量,尽管它仍需大量工作才能成为一个有用的洞察工具。

如果你喜欢这篇文章, 加入 Medium,并订阅我的 Youtube 频道我的新闻通讯。谢谢,下次见! 😃

参考文献

使用不平衡数据进行回归的机器学习

原文:towardsdatascience.com/machine-learning-for-regression-with-imbalanced-data-62629d7ad330

为什么在数据集中预测异常值如此困难,以及你可以采取什么措施来应对

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Caroline Arnold

·发表于 Towards Data Science ·阅读时间 6 分钟·2023 年 8 月 10 日

什么是不平衡数据?

许多现实世界的数据集存在不平衡的问题,其中某些类型的样本在数据集中被过度代表,而其他类型则出现较少。一些例子包括:

  • 在将信用卡交易分类为欺诈性或合法交易时,大多数交易将属于后者类别

  • 强降雨发生的频率低于中等降雨,但可能对人类和基础设施造成更大的损害

  • 在尝试识别土地用途时,代表森林和农业的像素比城市定居点的像素更多

在这篇文章中,我们旨在提供对机器学习算法为何在不平衡数据上表现不佳的直观解释,展示如何使用分位数评估来量化算法的性能,并展示三种不同的策略来提高算法的性能。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

Elena MozhviloUnsplash 上的照片

回归示例数据集:加州住房

数据集不平衡通常在分类问题中表现为一个多数类掩盖了一个少数类。在这里,我们关注的是回归问题,其中目标是一个连续的数值。我们将使用scikit-learn 提供的加利福尼亚住房数据集。该数据集包含超过 20,000 个房屋样本,特征包括位置、房间和卧室数量、房龄、面积和邻里中位收入。目标变量是中位房价,以百万美元计。为了查看数据集是否不平衡,我们绘制了目标变量的直方图。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

加利福尼亚房价数据集中目标变量的直方图。红线表示均值为 1.9 M$。

显然,并不是所有的中位房价都被同等地表示。目标变量的均值为 1.9 M ,标准差为 0.98 M ,标准差为 0.98 M ,标准差为0.98M,但这些值并不遵循正态分布——分布偏向于中位值超过 4.0 M$ 的昂贵房屋。

均方误差损失函数

我们在 Keras 中实现了一个小型神经网络,并使用它来预测中位房价。由于这是一个回归问题,均方误差函数是一个合适的损失函数。对于给定的一批样本,均方误差(MSE)计算如下:

在这个公式中,预测值和实际值的距离是损失的决定性因素。

模型训练很快,损失曲线看起来合理。训练集的最终损失为 0.2562,验证集为 0.2584,因此我们似乎达到了偏差-方差权衡的最佳点。

使用目标变量的分位数进行评估。

总体而言,我们的机器学习算法在保留的测试集上产生了 0.27 的均方误差。但这与不同房价相比如何呢?我们将目标变量分成每个 1 M$ 价格区间的分箱,并分别计算每个分箱中样本的均方误差。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

加利福尼亚房价数据集中目标变量的均方误差。

正如我们所见,我们的机器学习算法在接近目标变量均值的样本上表现最佳。在最高的分箱中,房价超过 4 M$,误差几乎高出 10 倍!

分位数评估是探索机器学习算法在数据集不同区域表现的好方法。这种快速分析可以直接指出实验设置中的问题,你应始终在仅报告整个数据集的平均性能指标之前考虑它。

为什么模型难以预测高房价?

简单规则学习

我们甚至不能责怪机器学习算法,因为它正好做了我们要求的事情:对于大多数样本,它具有良好的预测能力。只是价格超过 400 万美元的房屋在数据集中不足代表——只有 4%的训练数据落在这个范围内——因此算法没有足够的激励来优先考虑这些样本。

我们应该始终记住,机器学习算法在学习我们提出的任务。它们容易简单规则学习,这由 Geirhos 等(2021)定义:

简单规则是在标准基准测试中表现良好,但在更具挑战性的测试条件下(如真实世界场景)失败的决策规则。相关问题在比较心理学、教育学和语言学中已知,表明简单规则学习可能是学习系统(无论是生物的还是人工的)共同的特征。

提出正确的问题

比如说,我们正在为一个主要关注估算高端房屋中位数价值的房地产经纪人开发算法。对于这个客户,目前的算法无法提供期望的预测能力,因为它在他们感兴趣的房屋类型上的表现不佳。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

专注于最昂贵的房屋,而不是在平均样本上表现出色。照片由 Daniel Barnes 提供,来源于 Unsplash

处理不平衡数据

策略 1:增加批量大小

增加批量大小时,每个训练批次包含来自不足代表组的样本的可能性更高。我们将批量大小设置为 512,并重复模型训练。

策略 2:在损失函数中引入权重

在这里,我们要求机器学习算法关注在训练中不足代表的样本。这些样本具有更高的权重。权重可以直接计算并传递给 Keras 中的函数 model.fit(..., sample_weights=...)

用语言解释权重的计算:

  • 计算每个区间的样本出现次数

  • 除以样本总数——你将获得给定区间的样本频率

  • 逆数是相关的权重

以及代码:

策略 3:将目标值转换为正态分布

在这种情况下,我们将目标变量转换为符合正态分布。正态分布最适合用于均方误差损失,并减少异常值的特征。重要——评估之前不要忘记重新缩放预测值!

评估

哪种策略表现最好?

现在是时候比较三种不同的策略了。我们跟踪了每种策略下每个目标变量区间的均方误差,如下图所示。仅关注最高区间时,带权重的损失表现最佳。虽然它在平均房价的算法表现有所降低,但在我们最感兴趣的区域,它增加了预测能力。其他两种策略,即增加批量大小和缩放目标变量,甚至增加了最高区间的均方误差。因此,它们没有被证明对解决我们的问题有帮助。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

各策略下目标变量每个区间的均方误差。

因此,我们建议我们的客户——高端房地产经纪人——使用一种机器学习算法,该算法使用权重来强调数据集中表现不足的样本。然而,请注意,机器学习是一门经验科学,对于其他数据集,您可能会找到解决数据集不平衡问题的不同方案。

《机器学习插图:分类的评估指标》

原文:towardsdatascience.com/machine-learning-illustrated-classification-evaluation-metrics-dfc33b373c43

一个全面(且丰富多彩)的指南,介绍你需要了解的关于评估分类模型的一切

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Shreya Rao

·发表于 Towards Data Science ·12 分钟阅读·2023 年 4 月 20 日

我在学习过程中意识到自己是一个非常视觉化的学习者,我喜欢使用颜色和有趣的插图来学习新概念,特别是那些通常像这样解释的科学概念:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

从我之前的文章中,通过大量可爱的评论和消息(感谢所有的支持!),我发现有很多人对这种情感产生了共鸣。因此,我决定开始一个新的系列,我将尝试用插图来讲解机器学习和计算机科学的概念,希望能使学习变得有趣。所以,请系好安全带,享受这段旅程吧!

让我们通过探索机器学习中的一个基本问题来开始这个系列:我们如何评估分类模型的性能

在之前的文章中,如 决策树分类逻辑回归,我们讨论了如何构建分类模型。然而,量化这些模型的表现至关重要,这就提出了一个问题:我们应该使用什么指标来做到这一点?

为了说明这个概念,让我们构建一个贷款还款分类模型

我们的目标是预测一个人是否可能还清贷款,基于他们的信用评分。虽然年龄、薪水、贷款金额、贷款类型、职业和信用历史等其他变量也可能影响这样的分类器,但为了简便起见,我们只考虑信用评分作为我们模型的主要决定因素。

根据逻辑回归文章中列出的步骤,我们构建了一个分类器,该分类器根据信用评分预测某人是否会还款。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

从中我们可以看到,信用评分越低,越可能该人不会还款,反之亦然。

目前,该模型的输出是一个人会还款的概率。然而,如果我们想将贷款分类为会还款不会还款,我们需要找到将这些概率转换为分类的方法。

一种方法是设定 0.5 作为阈值,将低于该阈值的人分类为不会还款,高于该阈值的人分类为会还款

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

从中我们推断出,这个模型会将信用评分低于 600 的人分类为不会还款(粉色),高于 600 的人分类为会还款(蓝色)。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

使用 0.5 作为分界线,我们将一个信用评分为 420 的人分类为…

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

不会还款。而这个信用评分为 700 的人则为…

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

…会还款

现在为了测试我们的模型效果,我们需要远远超过 2 个人的数据。因此,让我们深入挖掘过去的记录,收集 10000 人的信用评分以及他们是否还款的信息。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

注意:在我们的记录中,有 9500 人还款,只有 500 人未还款。

然后我们对每个人运行我们的分类器,根据他们的信用评分预测他们是否会还款。

混淆矩阵

为了更好地可视化我们的预测与实际情况的比较,我们创建了一个称为混淆矩阵的东西。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

在这个特定的混淆矩阵中,我们将还款的个体视为正例标签,将未还款的个体视为负例标签。

  • 真阳性(TP):实际还款的人,并且被模型正确地分类为会还款

  • 假阴性(FP):实际还款的人,但被模型错误地分类为不会还款

  • 真阴性(TN):实际上未还款的人,并且被模型正确地分类为不会还款

  • 假阳性(FP):实际上未还款的人,但被模型错误地分类为会还款

现在假设,我们将 10000 人的信息输入到我们的模型中。我们得到的混淆矩阵如下:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

从中我们可以推导出——

  • 在 9500 名还款的人中——9000 人被正确分类(TP),500 人被错误分类(FN)

  • 在 500 名没有还款的人中——200 人(TN)被正确分类,300 人(FP)被错误分类。

准确率

直观上,我们首先要问自己的是:我们的模型有多准确?

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

在我们的案例中,准确率是:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

92%的准确率确实令人印象深刻,但需要注意的是,准确率通常是评估模型性能的一个简化指标。

如果我们更仔细地查看混淆矩阵,我们可以看到,虽然许多还款的个体被正确分类,但在 500 名未还款的个体中,只有 200 人被模型正确标记,其余 300 人被错误分类。

那么,让我们深入探讨一些其他常用的指标,以评估我们模型的表现。

精准度

我们可以问的另一个问题是:被预测为会还款的个体中,实际还款的百分比是多少

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

计算精准度时,我们可以将真正例数除以预测为正例的总数(即,分类为会还款的个体)。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

因此,当我们的分类器预测一个人会还款时,我们的分类器在 96.8%的情况下是正确的。

敏感性(也叫召回率)

接下来,我们可以问自己:我们模型正确识别的实际还款个体的百分比是多少

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

计算敏感性时,我们可以取真正例数,并将其除以实际还款的总人数。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

分类器正确标记了 94.7%实际还款的人,而其余的则错误标记为不会还款

注意:精准度和敏感性公式中的术语有时可能会令人困惑。一个简单的记忆法是记住两个公式都使用 TP(真正例),但分母不同。精准度的分母是(TP + FP),而敏感性的分母是(TP + FN)。

为了记住这个区别,可以将 FP 中的“P”与精准度中的“P”联系起来:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

这就剩下 FN,我们在敏感性的分母中找到它:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

F1 Score

另一个结合了敏感性和精准度的有用指标是 F1 分数,它计算了精准度和敏感性的调和平均值。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

在我们的案例中,F1 分数是:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

通常,F1 分数提供了对模型性能更全面的评估。因此,F1 分数通常比准确率在实际中更有用。

特异性

另一个需要考虑的关键问题是特异性,它提出了这样一个问题:未偿还贷款的个体中有多少百分比被正确识别为不会偿还

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

要计算特异性,我们将真正负例除以未偿还贷款的个体总数。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

我们可以看到,我们的分类器仅正确识别了 40%的未偿还贷款的个体。

特异性与其他评估指标之间的明显差异强调了选择适当指标评估模型性能的重要性。考虑所有评估指标并进行适当解释是至关重要的,因为每个指标可能提供对模型有效性的不同视角。

注意:我经常发现结合各种指标或根据问题制定自己的指标是有帮助的

在我们的场景中,准确识别不会偿还贷款的个体更为关键,因为向这些个体提供贷款可能会带来相较于拒绝那些会偿还的个体更高的成本。因此,我们需要考虑改善性能的方法。

实现这一点的一种方法是调整分类的阈值

虽然这样做可能看起来违反直觉,但对我们来说,重要的是正确识别那些不会偿还贷款的个体。因此,错误标记实际上会偿还贷款的人对我们来说并不是那么重要。

通过调整阈值,我们可以让模型对负类(不会偿还的人)更敏感,代价是对正类(会偿还的人)的敏感度下降。这可能会增加假阴性(将偿还的人分类为不会偿还),但可能减少假阳性(未能正确识别未偿还的人)。

直到现在,我们使用了 0.5 的阈值,但让我们尝试调整一下,看看我们的模型是否能表现得更好。

让我们从将阈值设置为 0 开始。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

这意味着每个人都会被分类为将偿还(由蓝色表示)

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

这将导致以下混淆矩阵:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

每个人都被分类为将偿还

…准确率为:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

…敏感性和精确度:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

…以及特异性:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

当阈值=0 时,我们的分类器无法正确分类任何没有偿还贷款的个人,即使准确性和敏感度看起来可能令人印象深刻,它也是无效的。

让我们尝试 0.1 的阈值:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

因此,任何信用分数低于 420 的人将被分类为不会偿还。这将导致如下的混淆矩阵和指标:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

我们再次看到,除了特异性外,所有指标都非常出色。

接下来,让我们去到另一个极端,将阈值设置为 0.9:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

因此,任何信用分数低于 760 的个人都将被标记为不会偿还。这将导致如下的混淆矩阵和指标:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

在这里,我们看到指标几乎是翻转的。特异性和精准度很好,但准确性和敏感度很差。

你明白了。我们可以为更多的阈值(0.004, 0.3, 0.6, 0.875…)进行类似操作。但是这样会导致大量的混淆矩阵和指标,从而造成很多混淆。这绝对是有意为之。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

ROC 曲线

这就是接收者操作特征(ROC)曲线的作用,用以消除这种混淆。

ROC 曲线总结并允许我们可视化分类器在所有可能阈值下的表现

曲线的 y 轴是真正例率,即敏感度。x 轴是假正例率,即 1-特异性。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

假阳性率告诉我们那些没有偿还却被错误分类为 将要偿还 (FP)的人的比例。

所以当阈值=0 时,从之前我们看到的混淆矩阵和指标是:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

我们知道真正例率 = 敏感度 = 1假正例率 = 1 — 特异性 = 1 — 0 = 1。

现在让我们将这些信息绘制在 ROC 曲线上:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

这条虚线蓝色线显示了真正例率=假正例率的位置:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

这条线上的任何点都意味着正确分类为偿还的人的比例与错误分类为未偿还的人的比例相同。

关键在于我们希望我们的阈值点尽可能远离左侧的线,并且我们不希望有任何点低于这条线。

现在当阈值=0.1 时:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

在 ROC 曲线上绘制这个阈值:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

由于新点 (0.84, 0.989) 位于蓝色虚线的左侧,我们知道偿还的正确分类人群比例大于未偿还的错误分类人群比例。

换句话说,新阈值比蓝色虚线上的第一个阈值更好。

现在让我们将阈值提高到 0.2。我们计算该阈值的真正例率和假阳性率,并绘制图表:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

新点 (0.75, 0.98) 更远离蓝色虚线,显示新阈值比之前的更好。

现在我们继续使用其他几个阈值(=0.35, 0.5, 0.65, 0.7, 0.8, 1)重复相同的过程,直到阈值=1。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

在阈值=1 时,我们处于点 (0, 0),其中真正例率 = 假负例率 = 0,因为分类器将所有点分类为不会还款

现在无需排序所有混乱的矩阵和指标,我可以看到:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

因为在紫色点处,当 TPR = 0.8 且 FPR = 0,

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

换句话说,这个阈值没有产生假阳性。而在蓝色点处,尽管 80%还款的人被正确分类,但未还款的人的正确分类率只有 80%(相比于之前阈值的 100%)。

现在如果我们连接所有这些点……

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

…我们最终得到 ROC 曲线。

AUC

现在假设我们想比较我们构建的两个不同的分类器。例如,第一个分类器是我们迄今为止看到的逻辑回归分类器,它产生了这个 ROC 曲线:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

我们决定建立另一个决策树分类器,结果得到了这个 ROC 曲线:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

比较两个分类器的一种方法是计算它们各自曲线下的面积或 AUC。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

由于逻辑回归曲线的 AUC 值更大,我们得出结论它是一个更好的分类器。

总结一下,我们讨论了评估分类模型的常用指标。然而,选择指标是主观的,取决于对问题和业务需求的理解。使用这些指标的组合或创建更适合特定模型需求的新指标也可能是有用的。

向 StatQuest 致以巨大的感谢,我最喜欢的统计学和机器学习资源。欢迎在LinkedIn上与我联系,或发邮件至shreya.statistics@gmail.com

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值