概率机器学习与分位数匹配:一个 Python 示例
原文:
towardsdatascience.com/probabilistic-ml-with-quantile-matching-an-example-with-python-c367eee85f18
一种将分位数回归预测转化为概率分布的鲜为人知的技术。
·发表于 Towards Data Science ·阅读时长 8 分钟·2023 年 9 月 4 日
–
“分位数匹配”,由Giulia Roggia。已获许可使用。
-
分位数回归
-
分位数匹配
-
Python 示例:预测糖尿病进展
-
结论
当我们训练回归模型时,我们获得的是点预测。然而,在实际应用中,我们通常对估计每个预测值的不确定性感兴趣。为实现这一目标,我们假设我们试图预测的值是一个随机变量,目标是估计其分布。
目前有许多方法可以估计预测的不确定性,如方差估计、贝叶斯方法、保形预测等。分位数回归是这些著名方法之一。
分位数回归
分位数回归包括为每个感兴趣的分位数估计一个模型。这可以通过使用一种不对称的损失函数来实现,这种损失函数称为pinball 损失。分位数回归简单易懂,并且在高效的库中如LightGBM中很容易获得。然而,分位数回归也存在一些问题:
-
没有保证分位数的顺序是正确的。例如,你对 50%分位数的预测可能会比 60%分位数的预测要大,这显然是不合理的。
-
为了获得整个分布的估计,你需要训练许多模型。例如,如果你需要每个百分位点的估计,你必须训练 99 个模型。
这就是分位数匹配如何提供帮助的。
分位数匹配
分位数匹配的目标是给定一组分位数估计来拟合分布函数。我们可以将此视为回归问题,因此曲线不必完全符合分位数。相反,它应该“尽可能接近”,同时保持使其成为分布函数的特性。
具体来说,我们感兴趣的是估计逆累积分布函数:给定一个概率alpha,我们想知道P(X<v)=alpha的值是什么,其中P代表概率,X是我们尝试预测的随机变量。
在以下示例中,我们提供了 3 种适配这种分布的替代方案。
Python 示例:预测糖尿病进展
图片由Towfiqu barbhuiya提供,来源于Unsplash
在本节中,我们展示了一个应用于糖尿病数据集的分位数匹配示例,该数据集可在Sklearn中获得:
对每个 n = 442 糖尿病患者,获得了十个基线变量,包括年龄、性别、体重指数、平均血压和六项血清测量值,以及感兴趣的响应,即基线后一年病情进展的定量测量。
让我们首先导入所需的库:
from abc import ABC, abstractmethod
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from lightgbm import LGBMRegressor
from scipy import optimize, stats
from scipy.interpolate import PchipInterpolator
from sklearn.datasets import load_diabetes
分位数匹配方法
我们定义了三种替代方案来从一组分位数中估计逆累积分布函数:
-
拟合正态分布
-
拟合“半正态”分布:由两个不同标准差的正态分布组成的分布,一个在中位数以下,一个在中位数以上(与半正态分布的绝对值不同,后者也被称为半正态)。
-
三次插值:使用三次样条来估计平滑的递增曲线。
注意这三种方法逐渐更灵活。第一种方法将输出限制为遵循正态分布。第二种方法允许存在不对称性,这在实际世界的例子中很常见,例如预测价格回报。第三种方法对基础分布没有任何假设,例如,它允许多模态。
为了实现这些方法,我们使用了一个易于扩展的设计模式:
-
一个定义匹配器类接口的基础抽象类
-
一组实现不同算法的具体类
-
一个返回所需方法类的工厂
首先,我们定义基类和工厂。为了简单起见,我们建立一个接口来一次拟合和预测一个样本:
class QuantileMatcherBase(ABC):
@abstractmethod
def fit_one(self, alphas, quant_values):
pass
@abstractmethod
def predict_one(self, alphas):
pass
def quantile_matcher_factory(match, **kwargs) -> QuantileMatcherBase:
matcher_map = {
"normal": QuantileMatcherNormCurvFit,
"half_normal": QuantileMatcherHalfNormCurvFit,
"cubic_interpolation": QuantileMatcherCubicInterpolation,
}
if match not in matcher_map:
raise ValueError(f"Unknown matcher {match}")
return matcher_mapmatch
然后,我们可以继续进行具体的实现。首先是常规分布:我们将问题框定为一个非线性优化问题,其中我们需要估计参数以最小化拟合曲线与观察值之间的平方差。
class QuantileMatcherNormCurvFit(QuantileMatcherBase):
"""Normal distribution quantile matcher."""
def __init__(self):
self.params = None
def fit_one(self, alphas, quant_values):
self.params, _ = optimize.curve_fit(
lambda x, mu, sigma: stats.norm.isf(x, mu, sigma),
alphas,
1 - quant_values,
)
def predict_one(self, alphas):
return 1 - stats.norm.isf(alphas, *self.params)
对于半常规分布,我们重复使用上面定义的类:一次用于中位数以下的值,一次用于中位数以上的值。
class QuantileMatcherHalfNormCurvFit(QuantileMatcherBase):
"""Half-Normal distribution quantile matcher."""
def __init__(self):
self.below = QuantileMatcherNormCurvFit()
self.above = QuantileMatcherNormCurvFit()
def fit_one(self, alphas, quant_values):
self.below.fit_one(alphas[alphas<=0.5],quant_values[alphas<=0.5])
self.above.fit_one(alphas[alphas>=0.5],quant_values[alphas>=0.5])
# trick to ensure same median
mu = (self.below.params[0] + self.above.params[0]) / 2
self.below.params[0] = mu
self.above.params[0] = mu
def predict_one(self, alphas):
pred = self.above.predict_one(alphas)
pred_below = self.below.predict_one(alphas)
pred[alphas<0.5] = pred_below[alphas<0.5]
return pred
请注意,在 fit_one 方法中,我们应用了一个小技巧,以确保两个分布具有相同的中位数。
三次插值的实现很简单:
class QuantileMatcherCubicInterpolation(QuantileMatcherBase):
"""Increasing cubic interpolation quantile matcher."""
def __init__(self):
self.params = None
def fit_one(self, alphas, quant_values):
self.interp = PchipInterpolator(alphas, quant_values)
def predict_one(self, alphas):
return self.interp(alphas)
量化回归包装器
我们定义一个类,用于拟合一些 Lightgbm 模型,并进行量化回归,以适应预定义的量化集合。我们实现了一个方法 predict_raw 来获取每个模型的原始预测,以及一个方法 predict_cdf 来使用之前定义的 QuantileMatcher 类在量化网格上获取(反向)累积分布函数。
class ProbLGBMRegressor:
_forbidden_keys = (
"objective",
"objective_type",
"app",
"application",
"loss",
"alpha",
)
def __init__(
self,
alphas=np.array([0.01, 0.1, 0.25, 0.5, 0.75, 0.9, 0.99]),
**lgbm_args
):
self.alphas = alphas
for key in self._forbidden_keys:
if key in lgbm_args:
raise ValueError(f"{key} parameter is not allowed.")
self._models = {}
for alpha in self.alphas:
self._models[alpha] = LGBMRegressor(
objective="quantile", alpha=alpha, **lgbm_args
)
def fit(self, x, y):
for alpha in self.alphas:
self._models[alpha].fit(x, y)
def predict_raw(self, x):
return pd.DataFrame(
{alpha: model.predict(x) for alpha, model in self._models.items()}
)
def predict_cdf(
self,
x,
inference_alphas=np.linspace(0.001, 0.999, 999),
match="normal_curve_fit",
**matcher_params,
):
# Compute predictions for the limited set of quantiles.
raw_preds = self.predict_raw(x)
# Estimate the cumulative distribution for each sample.
matcher = quantile_matcher_factory(match, **matcher_params)
predictions = []
for _, row in raw_preds.iterrows():
matcher.fit_one(self.alphas, row.values)
preds = matcher.predict_one(inference_alphas)
predictions.append(preds)
return pd.DataFrame(predictions, columns=inference_alphas)
糖尿病数据集:拟合和预测
现在我们可以加载糖尿病数据集,并使用上面定义的类来训练模型并预测目标值的分布。
x,y = load_diabetes(return_X_y=True, as_frame=True)
# Fit a regressor
prob_lgbm = ProbLGBMRegressor()
prob_lgbm.fit(x,y)
# Predict the distributions with all methods
predicted_cdf = {}
for match in ["normal","half_normal","cubic_interpolation"]:
predicted_cdf[match] = prob_lgbm.predict_cdf(x, match=match)
# For visualization purposes, we predict also the "raw" values
predicted_raw = prob_lgbm.predict_raw(x)
预测的图形分析
为了了解我们的模型是什么样的,我们可以绘制几个样本的预测分布。让我们定义一个帮助函数来绘制预测的累积分布函数。
def get_fig_cumulative_distribution_function(predicted_cdf, predicted_raw, idx):
# Small artifact to ensure same range in figures
max_limit = max([pred.iloc[idx, -1] for pred in predicted_cdf.values()]) + 5
min_limit = max([pred.iloc[idx, 0] for pred in predicted_cdf.values()]) - 5
# Create traces for each distribution.
trace = []
for match, pred_cdf in predicted_cdf.items():
x = [min_limit] + list(pred_cdf.iloc[idx].values) + [max_limit]
y = [0] + list(pred_cdf.columns) + [1]
trace.append(go.Scatter(x=x, y=y, mode="lines", name=match.title()))
# Add trace for raw quantile predictions.
trace.append(
go.Scatter(
x=predicted_raw.iloc[idx],
y=predicted_raw.columns,
mode="markers",
name="Raw Predictions",
marker={"size": 10},
)
)
# Create the figure
fig = go.Figure(trace)
fig.update_layout(
title="Cumulative Distribution Functions",
yaxis_title="alpha",
xaxis_title="quantile",
)
# Set x-axis limits
fig.update_xaxes(range=(min_limit, max_limit))
return fig
这里是通过预测数据集中第一个样本获得的图表:
索引 0 的预测累积分布。图片由作者提供。
我们可以看到三种方法产生了不同的曲线。虽然常规分布和半常规分布非常接近且平滑,但三次插值则较为不规则,完美拟合了所有“原始”预测。
虽然评估其量化值很方便,但从累积分布的角度分析分布的全球行为可能比较困难。为了获得更好的视图,我们可以通过使用 有限差分法 来估计相应的概率分布。让我们定义一个帮助函数来完成这项工作:
def get_fig_probability_distribution_function(predicted_cdf, idx):
trace = []
for match, pred_cdf in predicted_cdf.items():
quantiles = pred_cdf.iloc[idx].values
icdf_values = pred_cdf.columns.values
# Estimate the PDF using finite differences
diff_icdf = np.diff(icdf_values)
diff_quantiles = np.diff(quantiles)
pdf_est = diff_icdf / diff_quantiles
# Create a Plotly figure for the estimated PDF
trace.append(
go.Scatter(
x=quantiles[:-1],
y=pdf_est,
mode="lines",
fill="tozeroy",
name=match,
)
)
fig = go.Figure(data=trace)
# Add labels and title
fig.update_layout(
xaxis_title="Quantiles",
yaxis_title="Estimated PDF",
title="Estimated Probability Density Function",
)
return fig
通过对数据集中第一个样本应用上述函数,我们得到以下图表:
索引 0 的预测概率分布。图片由作者提供。
让我们再看几个样本的累积分布和概率分布。
这些是我们对第二个样本得到的结果:
索引 1 的预测累积分布。图片由作者提供。
索引 1 的预测概率分布。图片由作者提供。
这些是我们对第三个样本得到的结果:
索引 2 的预测累计分布。图像由作者提供。
索引 2 的预测累计分布。图像由作者提供。
我们可以看到,正常分布和半正态分布并不一致,这表明真实的潜在分布存在不对称性。
我们还注意到,三次插值给出了多模态且常常极端的结果。这是因为插值不受特定形式的约束,并且在拟合接近的点时往往具有高导数。这些结果可能不切实际,平滑技术可能有助于缓解这个问题。
结论
从初步来看,半正态分布似乎是最佳选择,因为它提供了现实的分布,同时能够建模不对称行为。然而,选择匹配算法的最佳方式是交叉验证预测并评估相关指标,例如预测区间的宽度及其准确性(90%的区间应在大约 90%的时间内包含目标)。
如开头所述,这种技术并不很受欢迎,我还没有机会在实际场景中使用它。因此,如果你在项目中使用了它,请告诉我!
本示例中使用的完整代码可在 此处 获取。
喜欢这篇文章? 查看我的其他文章 并关注我以获取更多内容! 点击这里 阅读无限制文章并在不增加你额外成本的情况下支持我 ❤️
主成分分析的概率视角
原文:
towardsdatascience.com/probabilistic-view-of-principal-component-analysis-9c1bbb3f167
潜在变量、期望最大化与变分推断
·发表于 Towards Data Science ·9 分钟阅读·2023 年 7 月 12 日
–
寻找隐藏变量(图片来源:作者)
在数据科学和机器学习中,主要使用的降维技术之一是主成分分析(PCA)。之前,我们已经讨论过将 PCA 应用于 管道与 支持向量机 的一些例子,在这里我们将从概率的角度来看 PCA,以提供对数据结构的更全面和稳健的理解。概率 PCA(PPCA)的一个最大优点是它能够处理数据集中缺失的值,这是经典 PCA 无法做到的。 由于我们将讨论潜在变量模型和期望最大化算法,你还可以查看 这篇详细的文章。
你可以从这篇文章中学到什么?
-
PCA 简短介绍。
-
PPCA 的数学构建块。
-
期望最大化(EM)算法或变分推断?用于参数估计时应该使用哪一个?
-
使用 TensorFlow Probability 在玩具数据集上实现 PPCA。
让我们深入探讨一下吧!
1. 奇异值分解(SVD)和 PCA:
线性代数中一个重要的概念是SVD,它是一种对实数或复数矩阵进行分解的技术,例如,一个矩阵(假设为A)可以被分解为:
方程 1:矩阵 A 的 SVD。
其中 U,Vᵀ 是正交矩阵(转置等于逆),而 Σ 将是一个对角矩阵。A 不需要是方阵,例如它是一个 N×D 矩阵,因此我们可以将其视为我们的数据矩阵,其中 N 个实例和 D 个特征。U,V 分别是方阵 (N×N) 和 (D×D),Σ 将是一个 N×D 矩阵,其中 D×D 的子集是对角的,其余条目为零。
我们也知道特征值分解。给定一个可以对角化的方阵 (B) 可以分解为:
方程 2:矩阵的特征值分解
其中 Q 是一个方形的 N×N 矩阵,其第 i 列是 B 的特征向量 q_i,而 Λ 是对角矩阵,其对角元素是对应的特征值。
让我们尝试通过乘以Aᵀ来修改方程 (1)。
方程 3:乘以 A 的转置。
在这里,AᵀA 将是一个方阵,即使 A 起初不需要是(可以是 m×n)。Σ Σᵀ 是一个对角矩阵,而 V 是一个正交矩阵。现在,这基本上是矩阵 AᵀA 的特征分解。这里的特征值是方程 (1) 中 A 的奇异值的平方。
对于正半定矩阵,SVD 和特征分解是等效的。PCA 最终归结为协方差矩阵的特征分解。找到最大特征值和相应的特征向量基本上可以视为找到最大方差的方向。
给定 D 维数据(特征),完全的特征分解将是昂贵的 ∼O(D³),但现在如果我们选择一些潜在空间维度 M(<D),则计算会便宜 ∼O(MD²)。
2. PPCA 的构建块:
2.1. 假设:
PPCA 是一个潜在变量模型 (LVM),我们之前在 详细讨论过 包括期望最大化 (EM) 算法。LVM 提供了数据的低维表示。假设我们的数据 (x) 是 N×D 维的,具有 D 个特征;那么 PCA 的 LVM 寻求一个 M 维的潜在变量向量 z,它可以用来生成观察变量 (x),并且它们通过线性关系相互关联:
方程 2.1:PPCA 的生成过程;x 条件于潜在变量 z。
上述方程中的噪声 ϵ 是一个 D 维向量,具有零均值高斯分布和 σ²I 的协方差;
方程 2. 2:噪声被建模为具有零均值和协方差 σ² 的正态分布。
由于我们知道潜在空间是 M 维的,这使得我们的 W 向量是 D×M 维的。假设潜在变量 z 具有零均值、单位协方差的高斯分布:
Eq. 2.3:潜在变量的先验是均值为零、协方差为单位的正态分布。
上述两个方程导致条件分布 x 在 z 下的如下:
Eq. 2.4:根据前两个方程,我们得到条件分布 x。
这是另一个均值为 Wz(我们可以设定 μ = 0)和协方差为 σ² 的正态分布。
上述方程应让我们想起正态分布的一个基本属性:即,如果 x 服从多元正态分布 x∼N(μ, Σ),则 x 的任何线性变换也是多元正态分布 y = Ax + b ∼ N(Aμ+b, AΣAᵀ)。
给定联合分布,边际分布也将是高斯分布:
Eq. 2.5:数据分布也服从正态分布。
由于我们希望确定参数 W、μ、σ,我们可以通过 MLE 或 EM 算法来解决这个问题。这里我们将关注 EM 方法,然后是变分推断。两种方法在 Bishop 的书中都有详细描述。Bishop 认为,随着数据维度的增加,通过迭代的 EM 步骤,我们可能在计算上比 MLE 获得优势。这主要涉及协方差矩阵的计算成本,其中 D 维数据协方差矩阵的评估需要 O(ND²),N是数据点的数量。
2.2. PPCA 的 EM 步骤:
在讨论了EM 算法 参考于高斯混合模型之后,这里我将其参考于 PPCA 进行描述。EM 算法的步骤如下:
-
在期望步骤中,我们计算完整数据对数似然相对于潜在变量的后验分布 (p(z|x)) 的期望,使用的是“旧”参数。
-
最大化该数据对数似然函数将得到“新”的参数,这些参数将被插入到第 1 步中。
由于数据点是独立的,完全数据似然将是:
Eq. 2.6:包括观察变量和潜在变量在内的完全数据似然。
E 步的主要目标是计算上述表达式的期望。在这里,我们需要使用方程 3 和 4 中 p(x|z) 和 p(z)。推导在 Bishop 的书中给出,但重要的是推导需要计算 E[z_n]、E[z_n zᵀ_n],这些可以通过后验分布 p(z|x) 推导得出。
一旦 E 步完成,M 步则涉及最大化相对于参数 W、σ² 的期望对数似然。
变分推断、EM 和 ELBO:
上述 EM 步骤依赖于一个关键假设,即后验分布*p(z|x)*是可处理的(这是方程 2.6 中的 E 步所必需的)。如果不是这样呢? 如果后验没有任何解析表达式?这就是变分推断的基础。
我们现在借助变分方法。主要思想是我们尝试找到一个分布q(z),使其尽可能接近后验分布p(z|x)。这个近似分布可以有自己的变分参数:q(z|θ),我们尝试找到使q接近感兴趣后验的参数设置。q(z)应该相对简单,更易于推断。为了衡量两个分布q(z)和p(z|x)*的接近程度,常用的度量是 Kullback-Leibler (KL) 散度。变分推断中的 KL 散度自然引入了证据下界(ELBO):
方程 2.7:*q(z)和后验p(z|x)*的接近程度,ELBO 加在一起给我们数据的可能性。
其中 ELBO (q) 被定义为:
方程 2.8:ELBO 的定义
对于推导,你可以查看参考中的笔记本或其他可用的讲义。
由于 KL 散度是非负的,log p(x)* ≥ ELBO(q)。所以我们在变分推断(VI)中所做的就是最大化 ELBO。
我们还可以很容易地看到 VI 和传统 EM 算法之间的联系;当q(x)==p(z|x)*时,KL 散度项消失。
由于我们现在已经完成了 PPCA 的概述,我们将使用 TensorFlow Probability 来实现这一点,并使用 ELBO 的定义尝试最大化它,这反过来等同于最大化数据的可能性 log p(x)。
使用 TensorFlow Probability 实现 PPCA:
为了通过变分推断实现一个简单的 PPCA 示例,我将遵循 TensorFlow Probability 应用中的原始示例。
实现这一点时,我们将假设我们知道噪声的标准差(σ,我选择了 3.0),如方程(2.2)所定义,并且我们对w进行先验设定,并尝试通过变分推断来估计。让我们定义联合分布:
在这里我们使用了JointDistributionCoroutineAutoBatched
,这是JointDistributionCoroutine
的“自动批处理”版本。它根据输入参数的形状自动应用批处理语义,从而允许更灵活地处理批处理维度。我们可以将批处理或非批处理参数传递给联合分布,它会自动处理批处理语义。我们从这个联合分布中采样后,使用tf_model.sample()
在第 51 行中,我们绘制观察数据(x)的分布(2D):
图 1:数据的观测分布,给定上述代码定义的联合分布。
顺便提一下,由于这些数据点是在你运行代码时随机采样的,你可能不会得到完全相似的点。
我们尝试认为后验p(W, Z|X)可以用一个由θ参数化的更简单的分布*q(W, Z)来近似。在 VI 中,我们的目标是最小化q(W, Z)和p(W, Z|X)*之间的 KL 散度,这从方程(2.8)来看,则是最大化证据下界(ELBO)。
在这种情况下的 ELBO:
我们通过 VI 尝试最小化的 ELBO。
为了最小化 ELBO,我们将首先定义类似于定义联合分布的方式的替代分布,并使用tfp.vi
方法将替代后验拟合到目标(未归一化)对数密度。
在获得替代分布后,我们可以使用它来采样新的数据点,结果如下图所示:
图 2:采样的分布与原始分布非常相似。
结论:
我们已经探讨了 PPCA 背后的数学原理,并使用 TensorFlow 通过一个简单的玩具数据集测试了我们的理解。通过使用替代后验生成新样本的可能性,使我们能够在数据中填补缺失值、生成新样本等,这在标准 PCA 中是不可能的。许多实际数据集展示了复杂的关系、噪声污染和不完整的观测,这使得经典 PCA 的效果不佳。通过考虑不确定性、处理缺失数据,并提供概率建模框架,PPCA 为数据分析、模式识别和机器学习开辟了新的途径。如果下次你打算使用 PCA 来处理你的数据集,但你知道观测可能会有噪声且存在缺失数据,为什么不尝试 PPCA 呢?
概率 PCA 还与因子分析(FA)密切相关,后者是一种线性高斯潜变量模型。FA 与 PPCA 之间的主要区别在于,在方程 2.2 中描述噪声分布时,PPCA 假设协方差是各向同性的,而 FA 中它是对角矩阵。我会在下面留下更多参考文献,以便你可以根据这篇文章进行进一步探索。
参考文献:
[1]《概率主成分分析》;M. Tipping, C. Bishop;J.R. Statist. Soc. B (1999)。
[2]《模式识别与机器学习》;C. Bishop;第十二章:连续潜变量。
[3]《连续潜变量模型》;多伦多大学;讲义。
[5] 我的笔记本链接。GitHub
数据分析师实际生活中的概率面试问题
将概率面试问题与数据分析师的日常任务联系起来
·
关注 发表在 Towards Data Science · 5 分钟阅读 · 2023 年 10 月 22 日
–
如果你申请数据分析师或数据科学家的职位,在面试中你会经常遇到概率问题。但问题是:有些人确信这些问题与实际工作关系不大。像“为什么我们要计算掷骰子 5 次都掷出 6 的概率?”这样的问题经常出现。在这篇文章中,我将分享一些真实的例子来解释为什么理解概率比你想象的更重要。为此,我们来看看一些面试任务,并了解它们在现实世界中的应用。
Q1. 你连续抛掷 10 次硬币,所有硬币都是正面朝上的概率是多少?
想象你是一个食品配送服务的数据分析师。每次订单完成后,客户可以评分食物的质量。团队的主要目标是提供顶级服务,如果餐厅收到差评,你需要检查。所以,关键问题是——多少条差评应该触发对餐厅的检查?
有时候,一个餐厅偶尔会收到一些不太好的反馈,这并不是他们的错。如果一个餐厅处理了 1000 个订单,他们可能会因为偶然原因收到几条差评。
这样考虑:大约 5%的订单偶然会收到负面评价。然后,每个餐厅的差评数量遵循二项分布Bin(n, p),其中“n”是订单数量,“p”是负面评价的可能性(在我们这里是 5%)。
所以,如果一个餐厅有 100 个订单,他们收到至少 7 条差评的概率大约是 23.4%,而收到至少 10 条差评的概率则小得多,只有 2.8%。你可以通过计算器这里来检查,参数是n=100、x=10、p=0.05,别忘了选择选项x>=X。
作者提供的图片。
结论是:如果你将检查阈值设定为 100 个订单中的 7 条差评,你可能会过于频繁地检查餐厅,这意味着你会增加额外成本,并对餐厅施加更多压力。
Q2. 你从 52 张标准扑克牌中抽取 10 次。抽到没有红色牌的概率是多少?
现在,想象你处在电子商务网站的世界里。你和你的团队刚刚引入了一种新的支付方式,你想知道客户使用这个新功能的频率。但问题是——由于一个小 bug,大约 2%的新支付请求会失败。换句话说,客户在 98%的会话中看到这个新支付选项。为了弄清楚客户选择这种支付方式的频率,你想关注那些始终可以使用它的用户。但这就有点棘手了。
设想一个只有一个会话的用户——你以 2%的概率将他们排除在分析之外。现在,考虑一个有 25 个会话的用户。对于他们来说,至少一个会话中没有该功能的机会是 1–0.98²⁵ = 39.7%。所以,你可能会无意中遗漏一些最忠诚的客户,这可能会扭曲你的分析。
图片由作者提供。
Q3. 如果你掷骰子三次,得到两个连续的三的概率是多少?
想象你在一家像 Uber 这样的打车公司工作。在某些国家,人们仍然用现金支付车费,这对司机来说可能是个麻烦。他们需要携带零钱,处理现金交易等等。
你的团队担心如果司机连续接到三个现金订单,他们可能会感到沮丧并且零钱用完。因此,你考虑在这种情况下限制现金订单。但在此之前,你想了解这种情况发生的频率。
假设每个司机每天的平均行程数是 10,其中 10%的行程以现金支付。
因此,得到三个连续现金订单的概率是 0.10.10.1 = 0.001. 但它可以是第 1、第 2、第 3 次订单;第 2、第 3、第 4 次订单,等等。这意味着连续三个现金订单的机会仅为 80.10.1*0.1 = 0.008. 看起来相当低,你可能要考虑暂时不实现这个功能。
Q4, 一项 HIV 测试的准确率是 99%(双向)。只有 0.3%的人口是 HIV 阳性。如果一个随机人检测结果为阳性,这个人 HIV 阳性的概率是多少?
原文文章见这里。
你在银行或信用行业,建立模型来预测客户是否会归还贷款。总体来说,85%的贷款通常会被偿还。在你最新的模型中,对于那些还款的客户,预测正确率为 92%。然而,当客户没有还款时,预测的正确率仅为 60%。现在,你有一个担忧:如果你的模型表示客户不会还款,那么他们实际还款的真实概率是多少?
首先,让我们计算模型预测“客户不会还款”的可能性。这涉及两个部分:
-
从不会还款的客户那里得到这种预测的概率:0.6*(1–0.85) = 0.09
-
从会还款的客户那里得到这种预测的概率:(1–0.92)*0.85 = 0.068
-
如果我们的模型认为客户不会还款,那么客户实际还款的概率是:0.068/(0.068+0.09) = 0.43
因此,如果你认为客户不会归还贷款,实际上他们有相当高的概率会归还。
这篇文章的全部意义是什么?它强调了理解概率和组合数学对数据科学家和分析师至关重要。在你的日常生活中,你会遇到需要掌握概率的情况,否则你可能会得出错误的结论。然而,从雇主的角度来看,面试问题应更具实际性,以帮助未来的分析师认识到这些知识在工作中的实际应用。
感谢你花时间阅读这篇文章。我非常希望听到你的想法,请随时分享你的评论或问题。
探讨最小样本量公式:推导与应用
A/B 测试中的样本量公式简明指南
·
关注 发表在 Towards Data Science · 18 分钟阅读 · 2023 年 2 月 1 日
–
太长;不读
本文回答了两个围绕 A/B 测试中最小样本量计算的重要“如何做”问题:
1) 如何推导最小样本量公式 𝜨**?**
公式背后的核心思想是反转假设检验中的 p 值计算,特别关注于统计功效,即在原假设确实为假的情况下拒绝原假设的概率。
图 1:最小样本量计算公式的直观解释
2) 如何在现实场景中使用公式计算最小样本量?
我们使用历史数据,将其汇总到正确的实验单位,来计算我们感兴趣指标的标准差以及历史平均值作为基线水平。然后,我们手动确定所需的显著性水平𝛼、统计功效1-𝛽,以及最重要的,期望在处理组和对照组之间的最小可检测效果𝑑𝑚𝑖𝑛。
这些是样本量计算所需的四个输入。值得注意的是,如果我们想检测到更微妙的增量提升(小𝑑𝑚𝑖𝑛),则在更强的统计能力和更长的实验周期之间存在权衡。因此,在设计 A/B 实验时,我们通常会计算在不同𝑑𝑚𝑖𝑛水平下的多个样本量版本,以量化这些权衡。
目录
-
我们为什么要关心
-
设置
-
I. 推导最小样本量(简单公式/无𝛽 𝛽)
• 关键思想
• 公式推导
• 示例:分析芝加哥两个社区之间的下载速度差异
-
II. 推导最小样本量(标准公式/带𝛽) • 关键思想 • 公式推导
• 示例:“揭开”Evan 的 A/B 样本量计算器的神秘面纱
-
实际问题
• 如何从我们的通用公式推导经验法则样本量?
• 如何确定𝑑𝑚𝑖𝑛?
• 如果两个组的方差和样本量不相等怎么办?
• 如何使用历史数据估算标准差𝜎?
• 我们需要多少历史数据?
• 我们应该如何准备这些数据?
-
总结
为什么我们需要关心?
在 2022/2023 年,当网上有大量现成的样本量计算器(例如,Evan Miller’s Calculator、Optimizely)可用,以及大多数科技公司都有自己建立的实验平台和工具时,深入探讨公式背后的原理似乎没有必要。然而,了解这些原理仍然很重要,因为:
-
A/B 测试样本量计算是产品数据科学家****面试中最常见的问题之一,我们需要不仅能够背诵公式,还要能解释其原理。
-
在在线随机 A/B 测试和离线对照试验中,现实世界的业务背景可能很复杂,因此我们需要了解公式的来龙去脉,以正确选择或调整这些工具并避免谬误。
-
此外,我们在假设检验和功效分析中建立“统计直觉”至关重要,因为由于样本量不足,常常会遇到无显著性检验结果。我们当然希望提高对可能的第二类错误(忽视真正有效的处理策略)的敏感度,并准备估算下一步所需的样本量。这不仅仅涉及技术公司中的典型数据科学工作,还涉及学术界和日常生活中的研究问题——“女性员工的薪酬是否低于男性?”、“财富更高的社区是否比欠发达社区有更好的互联网连接和下载速度?”等。我们不仅要问自己“我们从数据中看到的是否具有统计学意义?”,还要问“我们是否有足够的样本?如果没有,我们还需要多少数据或未来需要发送多少份调查问卷?”
关于这个主题,有许多优秀的入门文章和资源,但大多数仅仅“呈现”公式。站在这些前人的肩膀上,我希望在这里进一步深入探讨直觉和推导,回答**“这些公式为何如此定义”**的问题。此外,我还将讨论使用这些公式的一些实际问题。
设置
测试
本文中样本量计算的讨论将主要基于两个样本单侧 z 检验。
原因在于,在 A/B 测试场景中,测试统计量通常是样本均值(例如实验期间每用户的平均 XXX,其中该期间特定用户的所有行为被视为一个样本,而该指标只是所有用户的样本均值统计量),并且通常根据中心极限定理呈正态分布。这里倾向于使用单侧检验,因为我们通常对处理效果有明确的方向性信念(即处理组应表现得比对照组更好或更差)。
符号和概念
-
𝜨: 最小所需样本量。
-
𝑝: p 值。
-
𝛼: 显著性水平,或称为第一类错误。
-
𝛽: 第二类错误。1-𝛽表示统计功效。
-
X̄1/X̄2/X̄d: 第 1 组(对照组)、第 2 组(处理组)和两个随机均值变量之间差异的样本均值。
-
x̄1/x̄2/x̄d: 第 1 组(对照组)、第 2 组(处理组)和两个均值统计量之间差异的样本均值。
-
z_(1-𝛼): 切割顶部**𝛼(%)**的 z 分数,或标准正态分布的临界值。例如,在单侧检验中,z_(1–5%)=z_(95%)≈1.65; z_(1–2.5%)=z_(97.5%)≈1.96; z_(2.5%)≈-1.96。
-
**𝜎:**每组观察值的标准差。我们假设两个组的标准差相同(我们还将讨论当方差不相等时如何调整公式)。
-
𝑑𝑚𝑖𝑛:最小可检测效应,即在给定样本量 𝜨 的情况下,可以检测到的统计显著差异水平。例如,如果对照组的转换率为 10%,且我们预期处理能够将转换率从 10% 实际提高到 15%(请注意这是我们的预期),则(绝对)𝑑𝑚𝑖𝑛 为 15%-10%=5%。直观地说,如果我们预期的处理效应非常小(即我们对处理策略没有信心),我们将需要更多的样本来将其与自然变异区分开,因为真实效应可能非常微弱。如果最小可检测效应较大,它应该相当显著,因此少量样本就足以告诉我们两组之间的差异是否实际具有统计学意义。
目标
给定 I 型错误𝛼、II 型错误𝛽、标准差𝜎和最小可检测效应𝑑𝑚𝑖𝑛,我们希望估计最小所需样本量 𝜨,以便我们有足够的统计功效来检测到至少与𝑑𝑚𝑖𝑛一样大的差异。
换句话说,样本量 𝜨(仅)保证我们能检测到至少𝑑𝑚𝑖𝑛的处理效应或组间差异。如果实际的处理效应实际上小于𝑑𝑚𝑖𝑛,我们可能无法将这种微小的改进与自然变异区分开。
I. 推导最小样本量(天真的公式/不考虑𝛽)
虽然统计功效1-𝛽,或者当实际上存在差异时检测差异的概率,是估计样本量的关键因素,但当我开始学习 A/B 测试时,我发现很难理解。因此,在本文中,我们首先从这种*“天真的”*最小样本量估计版本开始,仅考虑显著性水平𝛼,然后在下一节中逐步进入“正式”的推导,其中引入统计功效的概念。
关键思想
计算最小所需样本量的关键思想就是假设检验的逆过程,或者换句话说,询问样本量 “就像 p 值已知为显著性水平 𝛼”**。
给定两个样本,我们可以基于样本均值、标准差和N计算 p 值。用图 2的话来说,我们希望样本均值差异统计量X̄d 落在红色区域内,因为这是原假设 H0 的拒绝区域。
图 2:假设检验中的拒绝区域和临界值(作者提供)
现在,由于显著性水平 𝛼 本质上是 p 值的 阈值,如果我们知道历史数据中的 标准差 𝜎 并指定 𝛼 的水平以及我们希望检测的两组之间的预期差异 𝑑𝑚𝑖𝑛,我们可以计算 N。这种公式下所需的最小样本量 N 确保我们有足够的“显著性”来检测两组之间的 𝑑𝑚𝑖𝑛 量的差异。
公式推导
基于上述思想,我们可以概述以下高级公式,我将围绕这个逻辑逐步讨论 N 的推导:
公式 1: 朴素最小样本量公式的高级概念(作者提供)
首先,我想明确我们的假设和测试统计量。零假设是组 1(对照组)和组 2(处理组)的均值统计量之间没有统计学差异。备择假设则表明我们的处理组‘优于’对照组,表明这是一个单侧检验。
公式 2: 两样本单侧 T 检验/Z 检验中的 H0 和 H1(作者提供)
我们关注的是组 1 和组 2 的均值统计量的差异:
公式 3: 双样本 T 检验/Z 检验的测试统计量(作者提供)
在进一步探讨之前,让我们先探讨分布性质。我们不需要对两组 𝛸1、𝛸2 做任何分布假设,因为根据中心极限定理,它们的 样本均值统计量 以及推导出的 样本均值差异统计量 遵循正态分布。
公式 4: 关键随机变量和测试统计量的分布性质(作者提供)
记住,p 值表示在零假设为真的情况下,随机变量获得一个值与当前样本统计量一样极端或更极端的概率。因此,p 值和 z 统计量可以通过以下公式计算:
公式 5: 两样本单侧 Z 检验的 p 值公式(作者提供)
因此,如果我们进行假设检验,当我们的 p 值小于 𝛼 时,我们将以 𝛼 显著性水平拒绝零假设,或者样本 z 统计量大于临界 z 值:
公式 6: 当 H0 被拒绝时的两样本单侧 T 检验/Z 检验的假设检验标准(作者提供)
太棒了!现在我们已经得出了给定两个样本组的假设检验最终公式。你可能已经注意到分母中有𝜨。尽管很简单,计算所需的最小样本量的方法是(1)将x̄d视为预定的最小可检测效应(MDE 或 𝑑𝑚𝑖𝑛);(2)假设原假设成立(μ2-μ1=0),以及(3)通过反转不等式来求𝜨。
-
要进行假设检验,我们通过计算组 1 和组 2 之间的样本均值差异来计算x̄d。
-
要计算最小样本量,我们用预期的最小可检测效应𝑑𝑚𝑖𝑛替换基于样本的x̄d,这代表了来自组 1(对照组)和组 2(处理组)的预期差异,或者在产品分析语言中,就是我们期望治疗策略带来的潜在增量效果。
因此,给定𝑑𝑚𝑖𝑛、𝛼和𝜎的𝜨的推导如下:
公式 7:给定𝑑𝑚𝑖𝑛、𝛼和𝜎的𝜨的推导(作者)
示例:分析芝加哥两个社区之间的下载速度差异
假设你是一个研究助理,与你的教授一起工作。团队当前的研究是调查互联网不平等,这通过下载速度,在服务不足和富裕社区之间来衡量,你的教授坚信/假设存在这种不平等,因为基础设施、服务提供商等方面存在差异。
两个关注的地理社区是林肯公园(芝加哥北区,最富裕的社区之一)和南岸(芝加哥南区,通常被认为是服务不足的社区)。团队已经对每个社区的 10 个随机选择的家庭进行了调查,教授让你进行数据分析以验证他的假设。
表 1:初步小样本(N=10)(合成数据,作者编造)
你意识到t 检验可能是检查林肯公园和南岸之间每个家庭平均下载速度是否存在统计显著差异的一个好方法。然而,在整理数据后,你发现结果是不具有统计显著性,你认为这是因为样本量太小,实际差异被自然变异所掩盖。
接下来出现的问题是:你建议团队收集多少更多的样本? 我们需要知道三个数字: 𝛼、𝜎 和 𝑑𝑚𝑖𝑛来根据公式 7 计算所需的最小样本量。
我们通常指定 5% 的显著性水平,并且在假设方差齐性的情况下,标准差可以用组 1 的样本标准差进行近似。对于𝑑𝑚𝑖𝑛,我们需要与教授沟通,借用领域知识并确定*组 1(南岸)和组 2(林肯公园)*之间的预期差异。目前,假设教授建议这两个社区之间家庭下载速度的预期差异或最小可检测提升(%)应为 1%。因此,我们得到:
公式 8: 𝛼、𝜎 和 𝑑𝑚𝑖𝑛 在下载速度示例中的规格(作者提供)
插入这些数字后,我们能够计算出所需的最低样本量,为 77。这个数字意味着我们应该调查 67 户家庭(我们已经有 10 户),而这个更大的样本(N=77)将使我们能够检测到 1% 或更大的下载互联网速度差异,前提是这两个社区之间确实存在这样的差异。
公式 9: 计算给定 𝑑𝑚𝑖𝑛、𝛼 和 𝜎 的 𝜨 的示例(作者提供)
完整的 Python 实现可以在我的 GitHub Gist 上找到,或嵌入的笔记本如下:
II. 推导最低样本量(标准公式/与 𝛽)
关键思想
在科技公司设计和启动 A/B 测试时,我们通常不仅关心获得显著的 p 值以拒绝原假设,更重要的是在原假设确实错误时拒绝原假设,这表明我们的处理确实有效,而不是第一类错误。因此,我们所需的样本量需要有足够的统计功效 (1-𝛽) 来检测这种处理效应。**图形上,如图 3所示,我们最‘渴望’的是两件事:
-
样本统计量(X̄d)超出了临界值,因为这是 H0 的拒绝区域。
-
样本统计量(X̄d)落在 分布 Ha 中的蓝色阴影区域,因为蓝色区域代表了替代分布,表示处理组和对照组之间的总体差异不为零。
这是计算所需最低样本量的关键思想。
图 3: 假设检验中的拒绝区域、临界值、统计功效(作者提供)
从上述提到的朴素版本到正式表示的公式转换,反映了统计功效 1-𝛽 的考虑,如下所示。随着由 1-𝛽 指定的临界值添加到分子中,所需的样本量 增加,这很有意义,因为我们在推断中“购买”了更多的效率。
公式 10:样本大小估计中的从简单公式到标准公式(作者)
另一种看法是,A/B 测试所需的最小样本大小是上述简单版本的扩展,因为现在包含/控制了第二类错误𝛽。计算最小样本大小的逻辑仍然是假设检验的反向,只是我们需要考虑两种分布 H0 和 H1,并纳入第一类错误/显著性水平𝛼和第二类错误𝛽。
公式推导
因此,我们通过设定拒绝错误 H0 的概率小于预设的统计功效 (1-𝛽) 来建立我们的高阶公式,如下所示。
请注意,我们现在是在以 H1 作为统计量来源的实际分布进行条件化,而不是 H0,后者假设处理组和对照组之间没有差异。
公式 11:正式最小样本大小推导背后的高阶思路(作者)
计算最小所需样本大小需要4 个输入:
-
我们需要指定显著性水平𝛼
-
我们需要指定的统计功效1-𝛽
-
标准差𝜎,我们可以从历史数据中计算得出
-
期望的处理效应大小,即我们想要检测的两个组之间的差异𝑑𝑚𝑖𝑛 (μ2-μ1)。
回顾一下,样本均值差异统计量遵循正态分布,其均值为μ2-μ1,标准差为2𝜎/N的平方根。
公式 12:关键随机变量和检验统计量的分布特性(作者)
上述高阶公式的推导包括两个步骤。首先,我们使用𝛼表示x_crit,即当假设 H0 为真时拒绝原假设的临界值。其次,我们使用𝑑𝑚𝑖𝑛、𝛼和𝛽求解最小样本大小𝜨。
公式 13:给定𝑑𝑚𝑖𝑛、𝛼、𝛽和𝜎的𝜨推导(作者)
示例:“揭示”埃文的神奇 A/B 样本大小计算器
这个“简单”但优雅的 A/B 样本量计算工具,如图 4所示,是由 Evan Miller 创建并发布的。它在 Udacity 和 Google 的优秀 A/B 测试入门课程系列中被提及,因此成为了许多数据科学学生,包括我在内的首选 A/B 测试工具。虽然这些样本量计算工具非常方便使用,但人们可能会担心,并不会完全信任它们,因为通常没有具体解释输出背后的逻辑。因此,我希望提供一个简要说明,以帮助揭开帷幕,展示这些计算器是如何工作的。
图 4:Evan Miller 的样本量计算器
首先,Evan 的 A/B 测试样本量计算器专门针对z-比例检验,其中关键指标是比例(如果你对不同类型的统计测试感兴趣,我认为我写的另一篇文章 如何选择适合不同 A/B 指标的统计测试 会对你有帮助)。
值得注意的是,在估计比例指标的样本量时,我们不需要输入标准差。 这是因为像转化率这样的比例指标可以被视为所有用户转化状态的聚合。而且因为用户要么转化,要么未转化,用统计学语言来说,这是一个随机伯努利事件,记作Bernoulli§。因此,转化率的分析方差估计为𝑝∗(1-𝑝)。这一良好性质意味着我们“不需要”输入标准差来计算最小样本量,因为我们总是可以使用样本平均转化率𝑝来估计真实的分析标准差,而𝑝就是基线转化率,如 Evan 的计算器输入框中所述。
牢记这一重要概念,让我们深入了解这个工具(图 4)。该工具需要四个输入:
-
基线转化率 𝑝:假设当前转化水平为 20%。
-
最小可检测效应(MDE):这里的定义可能有些棘手,但概念很简单。有“绝对”MDE 和“相对”MDE。5%的绝对 MDE 表示我们计算的样本量只能检测 15%到 25%之间的范围,而相对 MDE 定义了可检测的范围为 20%(1–5%)和 20%(1+5%)。例如,如果我们的产品经理认为推出这个新产品功能会将转化率从 20%提高到 25%,那么这意味着我们应该将𝑑𝑚𝑖𝑛设置为 5%,或者将绝对 MDE 设置为 5%,或者将相对 MDE 设置为 25%,因为 20%*1/4=5%。
-
统计功效 1-𝛽
-
显著性水平 𝛼
将其整理成我们目前建立的公式:
公式 14:给定𝑑𝑚𝑖𝑛、𝛼、𝛽和𝜎(双侧)在 A/B 工具示例中的计算𝜨的说明(作者提供)。手动计算的样本量(1003)与 A/B 测试工具生成的结果(1030)略有不同,这可能是精度舍入差异造成的。
公式 14:给定𝑑𝑚𝑖𝑛、𝛼、𝛽和𝜎(双侧)在 A/B 工具示例中的计算𝜨的说明(作者提供)
说话便宜:
作者
实际问题
我认为公式 13是最小所需样本量的一般形式。以下实际问题部分讨论了对更具体情况的几个重要扩展。例如,当方差不等时,当指标是比例时,双侧测试使用时与经验法则公式的关系。
如何从我们的通用公式推导出经验法则样本量?
许多资源提到一个方便的公式,N等于16*𝜎2/𝑑𝑚𝑖𝑛。如果我们进行双侧测试并指定𝛼 = 0.05、𝛽 = 0.2,我们可以将两个 z 分数平方和乘以 2 的结果四舍五入到 16。有关更多详细信息,请参见*《可信的在线对照实验》第 17.6 方程或Belle(2008)*。
公式 15:从一般形式到经验法则形式(作者提供)
如何确定𝑑𝑚𝑖𝑛?
我个人认为,𝑑𝑚𝑖𝑛或最小可检测效应(MDE)应通过PM 和 DS 之间的密切讨论来确定。这是一个需要仔细沟通和商业直觉的艰难决定:
-
DS 可以提供与以下内容相关的见解:(1)类似历史实验的 MDE;(2)关键指标的波动幅度;(3)A/B 测试的可用流量/样本量。
-
PM 可以提供与以下内容相关的见解:(1)潜在处理效应的预期大小;(2)我们所拥有的信心量;(3)此测试处理的重要性/范围;(4)此测试处理的紧迫性/时间表。
可能需要反复沟通,因为更高统计功效与更长实验周期之间存在权衡。假设处理效应很微弱(即𝑑𝑚𝑖𝑛应设定为较小),且需要更高的统计功效来检测。在这种情况下,我们必须积累更多样本,因此在日常流量通常不变的情况下,预计需要更长时间。
尽管确定𝑑𝑚𝑖𝑛是棘手的,我建议将不同𝑑𝑚𝑖𝑛及其对应的样本量组织成表格,如表 2(原始 Google 表格可以通过 这里 )。我们数据科学家首先选择几个可能的 MDEs并计算相应的样本量以及所需天数。MDEs 可以随时轻松调整,表格为我们提供了一个直接的对比示例,使我们能够与产品经理就上述权衡进行沟通:
表 2:计算不同 MDEs(𝑑𝑚𝑖𝑛)下所需最小样本量的示例表(作者提供)
如果两组的方差和样本量不相等呢?
在观察性研究中,两组之间有不同方差是非常常见的。当同质方差假设不成立时,只需做一些微不足道的调整,即在调整公式时用处理组方差和对照组方差的总和替代 2*𝜎2。
公式 16:在方差不等情况下调整样本量公式(作者提供)
更重要的是,我们可以进一步释放隐含的假设,即处理组和对照组之间的流量是相等的。
正如你可能已经注意到的,以上展示的正态分布样本均值差异统计量的尺度是我们在网上帖子中经常看到的所谓汇总标准差的人口版本:
公式 17:汇总标准差公式(作者提供)
如何使用历史数据估计标准差𝜎?
在 A/B 测试场景中,为了评估潜在的流量及所需时间,样本量计算通常在实验开始前进行。计算的(唯一)资源是历史数据。因此,问题是我们需要多少历史数据,如何准备这些数据?
我们需要多少历史数据?
建议使用过去数据的大致相同时间长度作为预期的 A/B 测试周期。此外,需要注意的是,我们必须确保历史数据的范围与处理策略一致。否则,外部有效性无法保证。例如,如果产品经理想测试扩大“下一步”按钮的效果,并希望在两周内收集有关转化率是否增加的实验结果,那么我们需要收集过去两周内曾(应当)接触过该按钮的合格用户的转化率数据。
我们应如何准备这些数据?
除了选择正确的时间范围和应用适当的过滤器到历史数据中,我们还要确保数据以正确的方式进行聚合,以便分析单元与随机化单元相匹配。换句话说,如果流量是在用户层面随机分配的,而用于度量的源数据通常是用户行为层面的,我们需要按每个用户进行分组,并聚合他们的行为以计算度量,然后计算历史样本均值和标准差,这将用于计算所需的样本大小。
摘要
在这篇文章中,我们探讨了最小样本大小公式的推导,并广泛讨论了使用该公式的实际问题。你已学到:
-
如何推导 A/B 测试中的最小样本大小公式
-
如何理解许多在线样本大小计算器以及其他版本的最小样本大小公式。
-
如何使用这个公式来计算样本大小,通过将历史数据准备成正确的格式,与产品经理讨论以确定𝑑𝑚𝑖𝑛
这篇文章特别受到*<可信的在线受控实验>*(作者:Diane Tang, Ron Kohavi, 和 Ya Xu)、Emma Ding 的 YouTube 视频“如何估算 A/B 测试中的样本大小”以及我在 TikTok 工作中的经验的启发和基础。
参考文献
-
Glen, S. 统计中的样本大小(如何找到):Excel,Cochran 公式,常见技巧。StatisticsHowTo.com.
-
Kohavi, R., Tang, D., & Xu, Y. (2020). 可信的在线受控实验:A/B 测试的实用指南。在可信的在线受控实验:A/B 测试的实用指南(第 I 页)。剑桥:剑桥大学出版社。
-
Singh, A. S., & Masuku, M. B. (2014). 采样技术与应用统计研究中的样本量确定:概述。国际经济、商业与管理期刊, 2(11), 1–22.
-
Van Belle, G. (2008). 样本量。统计学经验法则, 27–51.
-
Yamane, T. (1973). 统计学:入门分析-3.
使用广义加性模型(GAMs)生成见解
学习如何解释广义加性模型(GAMs)并从数据中提取有用的见解
·
关注 发表在 Towards Data Science ·10 分钟阅读·2023 年 1 月 30 日
–
今天我们将学习如何使用广义加性模型来预测 2011 年至 2012 年间华盛顿特区的自行车租赁数量。该数据集由共享单车公司 Capital Bikeshare 提供。共享单车系统是一种新一代服务,允许用户在方便的位置取车和还车,从而促进零排放的交通,具有对交通、环境和健康问题的积极影响。
GAMs 是什么?
“广义加性模型是一个广义线性模型,其线性预测变量涉及协变量平滑函数的总和”(Wood, 2017)。
GAMs(广义加性模型)通过添加协变量的估计权重来与线性回归相同。最大的区别在于,这些权重代表的是灵活的关系,而不仅仅是线性的关系,我们使用链接函数来建模目标变量。
我们可以将这种模型用于不同的应用:使用泊松分布建模疾病的传播,基于数值/分类变量使用二项分布(逻辑回归)预测患者是否有某种疾病,使用纬度和经度数据研究物种在某区域的空间行为。GAMs 是一个多功能的框架,可以应用于几乎任何领域。
在我们的背景下,我们可以创建一个模型来解释由于时间、湿度和温度的变化,租赁自行车的数量是如何变化的。此外,租赁自行车的数量表现得像一个正态分布。
我们可以正式写出我们的模型:
示例 GAM 的数学表达式
函数 f1、f2 和 f3 允许我们建模目标变量与解释变量之间的灵活关系。最终,这些关系的估计权重之和就是预测/估计的租赁自行车数量。在训练模型之前,我们将描述数据集,并开始进行一些探索性数据分析,以决定使用哪些解释变量。
加载库
# Basic wrangling functions
library("dplyr")
# Beautiful plots and data exploration
library("ggplot2")
# Comparing variables and data exploration
library("GGally")
# Library to fit gams
library("mgcv")
# Modern library for visualizing and assessing gams
library("gratia")
数据集
我们从 Interpretable Machine Learning 书籍的 GitHub 仓库下载数据集,这些数据已被预处理,并且准备好供我们使用。让我们解释一下每个变量:
-
season: 年中的季节。
-
holiday: 那天是否为假日。
-
workingday: 这一天是否为工作日,本质上是周末与否。
-
weathersit: 那天的天气情况,有三种类别。
-
temp: 温度,以摄氏度计。
-
hum: 相对湿度百分比。
-
windspeed: 风速,以 km/h 计。
-
days_since_2011: 一个时间步长变量,用于考虑时间的流逝。
-
cnt: 租赁的自行车数量。
-
weekday: 一周中的天
我们将使用下面的函数并总结感兴趣的变量。
get.bike.data = function(){
url = "https://raw.githubusercontent.com/christophM/interpretable-ml-book/master/data/bike.csv"
# Download file and save it as bikes.csv in our current folder
file = download.file(url, "bikes.csv")
# Read the file and return it as a data frame
bike = read.csv('bikes.csv', stringsAsFactors = T) %>% data.frame()
}
# Relevant variables
variables.of.interest = c('season','holiday', 'workingday', 'weathersit', 'temp', 'hum', 'windspeed', 'days_since_2011', "cnt", "weekday")
# Read data and extract variables of interest
bikes = get.bike.data() %>% dplyr::select(variables.of.interest)
# Summarise data
summary(bikes)
summary 函数的输出
我们巧妙地用一个函数总结了我们的数据,我们拥有每天的分类和数值数据。我们的数据集包含 730 条记录,大约是两年的数据。
数据探索
让我们开始分析数值变量,以检查它们是否对租赁自行车数量有影响。
# Compare each variable against the other
ggpairs(bikes %>% select(c(temp, hum, windspeed, cnt))) +
labs(subtitle = "Numeric variable exploration") +
my_theme()
图 1. 检查变量相关性
我们使用了ggpairs函数来创建这个图形,在这里我们可以可视化变量之间的相互影响并发现模式。总体而言,我们可以看到温度对自行车租赁数量有正面影响(见左下角的图)。让我们进一步探讨温度对其他解释天气情况和日期的分类变量的影响。
图 2. 按季节划分的天气情况及其对租赁自行车数量的影响
无论季节如何,只要天气好,租赁自行车的数量就会比天气差时高。例如,在冬季有三种天气情况,但总体而言,当天气不好时,租赁自行车的数量较低。此外,与冬季相比,夏季的租赁自行车数量显著较高。也许我们应该考虑天气情况对模型的影响。
图 3. 温度在不同季节和天气情况中的影响
左图确认了我们的发现,表明良好的天气显著有助于增加租赁自行车的数量。另一方面,右图显示温度对较冷的季节如秋冬有正面影响。有趣的是,当温度接近 30 摄氏度时,夏季出现了轻微的负面趋势,表明一些骑行者不愿意在高温下骑行。
训练广义加法模型(GAM)?
在 R 中训练/拟合广义加法模型(GAM),我们将使用mgcv库。它是一个功能强大且维护良好的库。
M = gam(cnt ~ season + weathersit + s(days_since_2011, bs ="cr", k = 70) +
s(temp, bs = "cr", by = season, k = 15), data = bikes, )
我们的模型将预测租赁自行车数量,并考虑以下因素:
-
weathersit: 每种天气情况将估算一个不同的截距。什么是截距?截距是指在其他所有变量设为零的情况下,基于给定天气情况的平均租赁自行车数量。
-
s(days_since_2011, bs = “cr”, k = 70): 这个术语表示我们将使用三次回归样条(平滑函数)来估算租赁自行车数量随时间变化的情况。k值是函数的阶数,决定了它的灵活性。稍后会变得清晰。
-
s(temp, by = season, k = 15): 这个术语估算了在不同季节中温度对租赁自行车数量的影响。由于 k 值较小,它不如上述术语灵活。
-
season: 每个季节都有不同的截距,因为我们使用了一个估算季节温度的术语。
所有这些项的总和将导致我们模型的预测结果。
总结和检查我们的模型
# Summarize model
summary(M)
图 4. 模型摘要
首先,图 4. 的模型摘要告诉我们目标变量的分布和链接函数。在我们的案例中:高斯(正态)和恒等函数(变量未改变)。我们有参数系数或截距,为每个季节和每种天气情况计算了一个截距(也称为参数效应)。在下一部分中,展示了我们平滑函数的有效自由度。这些值告诉我们关系的灵活性(本质上是它们与线性关系的差异,线性关系的 edf=1)。最后,我们有解释的偏差,在我们的案例中是 88%。我们的模型解释了 88%的数据。
# Checking k-value and edf
k.check(M)
图 5. 检查模型的灵活性
图 5. 中的下一个模型检查与平滑函数的秩有关。例如,第一个解释自行车数量与时间关系的平滑函数被赋予了 70 的秩。其中一个秩用于计算截距,其余的用于建模关系。edf 告诉我们所使用函数使用了多少可用的灵活性(69 个秩)。作为一个好的经验法则,我们希望 k 大于 edf,k-index 大于 1。我们还希望 p 值较大,而不是较小,这通常是情况。在我们的案例中,我们未能获得后两者,但我们已给函数足够的灵活性,如 edf 所示。
解释平滑效应
本节将展示平滑函数如何建模解释变量与租赁自行车数量之间的复杂关系。以下图形可以使用以下函数创建:
# Plot smooth and parametric effects
draw(M, parametric = TRUE)
图 6. 时间与租赁自行车之间的非线性关系
需要注意的一点是,所有的部分效应图都以均值为中心,阴影区域表示 95%的置信区间。这意味着 y 轴上显示的增减反映了租赁自行车的平均预测值。例如,可以通过陈述在前半段时间内自行车数量低于平均值来解释图 6.。我们还可以看到,由于所使用函数的 30 个有效自由度,关系是相当灵活的。
图 7. 温度对秋季租赁自行车的影响
图 7 中温度的影响没有很大的灵活性,但我们可以清晰地看到一些*“波动”。* 一个重要方面是 x 轴上的 rug 图,它显示了数据点的数量。当没有数据点时,置信区间会显著增大,因为我们不确定关系。我们可以通过说在秋季,温度的升高会增加租赁自行车的数量来解释此图。更准确地说,当温度约为 5 摄氏度时,预测的自行车数量比平均值低——1200(低 1200)。
图 8. 春季温度对租赁自行车的影响
在春季,温度的变化方式不同,预测的自行车数量在温度从 4 摄氏度升高到 20 摄氏度时急剧上升,但之后预测的自行车数量开始下降。我们需要仔细考虑较高温度值的的不确定性,rug 图表明数据点不多。如果有更多数据,这部分图表可能会显示平滑曲线?
图 9. 夏季温度对租赁自行车的影响
夏季非常有趣,因为我们看到在温暖的日子里数量下降。在图 9 中,我们可以看到温度超过 30 摄氏度对预测的租赁自行车数量有负面影响,低于预测平均值约——2500。这可能是因为骑行者不愿意在非常热的日子里骑车。
图 10. 冬季温度对租赁自行车的影响
最后,在冬季,随着温度的升高,预测的租赁自行车数量增加。注意图 10 显示了一个几乎线性的关系,与模型总结中的 1.24 edf 相符。虽然它的灵活性不如其他平滑函数,但按季节建模温度似乎是合适的。我们可以通过说在冬季 15 摄氏度时,预测的自行车数量大致与平均值相同来解释此图。
接下来我们将解释估计的截距或参数效应。
解释参数效应
图 11. 一年四季对租赁自行车的影响
图 11. 显示了带有 95% 置信区间的估计参数效应。我们可以通过陈述春季预测的自行车数量比秋季低 1489 辆来解读此图。我们对这一陈述充满信心,因为置信区间不包括零(我们的比较点)。如果预测也包括秋季的预测数据,春季的效应将不会显著,对吗?我们可以对冬季对预测自行车数量的影响做出相同的陈述。另一方面,夏季的效应不显著,因为其置信区间相当大并且包括零。这表明,在夏季,我们的模型预测的租用自行车数量也可以在秋季预测。我们可以从中提取一个见解:租用自行车数量的增加不是因为夏季,而是因为温度逐渐变暖。
图 12. 天气情况对租用自行车的影响
最后,图 12. 告诉我们,当天气雾霾或下雨/下雪/暴风雨时,预测自行车数量的变化是显著的(因为置信区间不包括零)。在这种情况下,效应是负面的,因为我们可以说,在雾霾天气的日子里,预测的租用自行车数量比良好天气的日子少 690 辆。
结论
我们训练了一个广义加性模型,以预测华盛顿特区的租用自行车数量,基于时间变化、季节中的温度变化和整体天气情况。我们学习了如何读取模型摘要并检查有效自由度以获得正确的拟合。最后,我们解释了模型的平滑效应和参数效应,以理解是什么驱动了租用自行车的数量,并获得了以下见解:
-
骑自行车的人在夏季气温过高时(高于约 25 摄氏度)通常会避免租用自行车。
-
租车数量的上升发生在天气温和、气温开始变暖的季节(秋季和冬季)。
-
我们对雨雪是负面影响租用自行车数量的显著因素充满信心。
总之,广义加性模型是一种强大的机器学习框架,具有很好的解释性,可以用于从数据中提取见解。将噪声转化为句子。
除非另有说明,所有图片均由作者提供。
参考文献
Wood, S.N., 2017. 广义加性模型:R 语言入门. 第二版。Chapman and hall/CRC.
代码
可以在我的 GitHub 仓库中找到:
github.com/alvarofps/Producing-Insights-with-GAMs
产品经理必须决定:功能还是用户设计
原文:
towardsdatascience.com/product-managers-must-decide-features-or-user-design-e3a14a27859
观点
产品经理可以优化用户设计或功能开发。不能两者兼顾。
·发表于 Towards Data Science ·阅读时间 7 分钟·2023 年 1 月 18 日
–
图片由 Austin Distel 提供,来自 Unsplash
市场的二分化似乎在许多应用中都是一个普遍的主题。无论是在新闻报道中的左右之分,对数据科学家的需求,还是折扣与奢侈零售之间的分歧;稳定状态通常会定格在极端位置。当涉及到软件开发时,我注意到的两个极端是优化功能开发与优化用户设计。如果你的产品处于两者之间,你可能不会成功。
作者提供的图片
为什么只能选择一个?
之所以趋向极端,是因为这两个开发目标有着正交的路径。由于项目管理固有的约束(见下文),开发高质量的软件产品有一定的局限性。一个团队在给定的时间框架内只能优先处理有限数量的任务。这些约束,再加上你的定价和销售策略,将迫使开发在这两个极端之间做出选择。
作者提供的图片
让我们假设一个早期阶段的科技初创公司正在决定构建 B2B 项目管理软件的最佳开发路径。他们有 10 名员工负责在 12 个月内发布软件的 beta 版本。对于这个软件产品的最早决策之一是选择它是一个 web 应用程序、移动应用程序、两者之一,还是两者都有(即桌面 vs. 手机)。这一决定影响了技术栈,从而产生巨大的级联效应。如果你只追求 web 应用程序,但想要整合多个外部软件,这将需要更复杂的工程工作。这将减少优化用户设计的时间。另一方面,如果你希望有一个 web 应用程序和一个本地 iOS 移动应用程序,这通常需要两个 不同 的前端框架。由于前端工作量基本上翻倍,这减少了添加功能的时间。
在另一个工作流中,正在讨论最佳定价和销售策略。软件产品的复杂性,通常与功能数量相关,直接影响最佳销售策略。产品越复杂,推动以产品为主导的增长策略就越困难,这种策略通常强调自助获取用户。复杂的软件产品通常需要客户演示,定价更高,并且可能受益于第三方部署合作伙伴。优化用户设计而非功能的简单产品更容易实现以产品为主导的增长策略,因为用户会快速学会使用这些产品。虽然其中一些决策可能会被延迟,但在开发过程中,你的决策会更倾向于在功能开发与用户设计之间做出权衡。
根据个人经验,我曾积极参与选择某组织的企业实验室信息管理系统(LIMS)。这涉及了一个漫长的信息收集过程,我们研究了 LIMS 选项,接收了一些供应商的演示,并采访了用户以了解他们的需求。在所有这些工作之后,我们的决策最终选择了具有最多功能的 LIMS 供应商和具有最佳用户设计的供应商。也许了解这个框架可以节省一些时间。
为什么市场奖励极端?
在 Startup=Growth 中,Paul Graham 说道:“如果你有这样的想法,但增长不够快,竞争对手会更快地增长。在具有网络效应的业务中,增长过慢尤其危险,而最好的初创公司通常在某种程度上具备这种效应。”
我相信,由于网络效应和指数增长,软件产品市场趋向于两个极端。优化功能或用户设计将导致用户采纳的增长速度比选择两者组合更快,因为增长不是线性的。
作者插图
功能优化的产品将满足更多样化团队的需求,并增加未来需求的选择性,因为它比竞争对手有更多的功能。而简单的产品可能只适用于特定的利基市场,复杂的产品则可能满足多个部门类型或行业的需求。此外,开发更多功能通常需要依赖外部软件集成。团队的多样性,加上外部软件集成,将使产品接触到更大的潜在客户市场。每个新用户接触到产品都是免费的广告,推动增长循环。
用户设计优化的产品有几个固有优势。当与自助式免费增值模式(即产品驱动增长)结合时,简洁性将转换更多潜在客户,因为产品更容易学习。简洁性减少了用户获取过程中的摩擦。在已经使用产品的组织中,直观的产品将受益于增加的用户采纳。用户喜欢易于使用的产品。以简洁性为核心,用户池将增长,因为它也包括了较少技术用户。回想当我的技术挑战型父母准备购买他们的第一部智能手机时,我立刻推荐了 iPhone,因为史蒂夫·乔布斯对优化用户设计情有独钟。
何时应该专注于功能与用户设计?
开发策略应保持灵活,并尽可能保留选择权。当需要选择路径时,有许多因素我会考虑。
人才
您的员工在什么技术上最为擅长?如果您想开发一个本地移动应用,通常需要使用某些专门的语言。如果您想要一个具有无限外部集成功能的网页应用,那么您可能需要更多的全栈或后端工程师。您组织内的才能和/或您能负担得起的才能可能会决定您的策略。
目标消费者
您的目标消费者是谁?您是针对技术精通的千禧一代消费者,还是您的用户技术能力较差?如果您正在打造面向消费者的品牌,那么美学和简洁性可能会比 B2B 产品更为重要。
竞争
您是该行业的先行者,还是已经有几个直接和间接竞争者?竞争较少会给您更多选择,而竞争增加会促使您成为领导者的对立面以实现增长。
销售和定价策略
你打算如何获取和转化用户?你打算向用户收取多少费用?以产品为导向的增长战略对于更复杂的软件产品来说会更具挑战性,但也不是不可能的。Gitlab 前增长总监 Hila Qu 在一篇近期文章中深入探讨了这两种策略之间的权衡。
什么时候不重要?
分发能力 对于这一优化理论来说,至少有一个关键的例外——微软。当微软进入你的行业时,你可以把特性和用户设计扔到一边。看看 Teams 和 Slack 的对比。Slack 在几乎所有方面都是客观上更好的软件产品。然而,这些都无关紧要,因为 Microsoft Office/365 在组织中的嵌入程度如此之深,Teams 成为了大多数组织的预算考虑选择。分发能力让组织能够打开发展手册,选择自己的发展路径。
AI/ML 尽管大多数情况下特性开发和用户设计改进会被迫分开进行,但也有一些例外。推荐系统是许多产品中的常见 AI/ML 应用,我认为这是一个复杂的特性扩展,可以改善用户设计。一个大家熟悉的例子是 Twitter,当你选择关注某个人时,它会触发额外推荐的其他人关注。这个过程使得发现你更可能喜欢的其他账户变得更容易,这是一种用户体验的改进。然而,别搞错了,将 AI/ML 融入应用程序的后台可不是一件容易的事。
Twitter 推荐关注示例
结论
-
虽然有一些例外,但我注意到市场往往更青睐那些针对特性或用户设计进行了优化的软件产品——而不是两者兼顾。如果增长是主要目标,那么确定优化哪一条路径应当是优先考虑的事项。两者兼顾只会将用户转向那些正确优化的竞争对手。毕竟,独角兽之所以稀有是有原因的。
- 数据通才
使用无服务器容器服务将机器学习模型生产化
如何使用 Azure 容器应用程序创建无服务器容器化推理端点以用于你的机器学习模型
·发表于 Towards Data Science ·7 分钟阅读·2023 年 1 月 9 日
–
照片由 Jan Canty 提供,发布在 Unsplash 上。
介绍
无服务器容器架构是一种构建和运行容器化应用程序和服务的方法,无需管理底层基础设施。在这种架构中,容器用于打包和部署应用程序,这些容器在云服务提供商提供的完全托管环境中运行。
云服务提供商负责运行容器所需的基础设施,如硬件和操作系统。开发者无需担心设置或维护这些基础设施,可以专注于编写代码和构建应用程序。容器通常在集群中运行,云服务提供商会根据需求自动扩展容器的数量。这使得应用程序能够处理流量波动,而无需手动干预。无服务器架构可能比传统架构更具成本效益,因为用户仅需为实际使用的资源付费,而不是为可能未完全利用的固定计算能力付费。
一些无服务器容器服务的例子包括 Azure Functions、Azure Container Apps、AWS Lambda 和 Google Cloud Functions。本文将演示如何利用 Azure Container Apps,这是一个完全托管的无服务器容器服务,用于大规模构建和部署应用程序,用于生产化机器学习模型。Azure Container Apps 的常见用途包括部署 API 端点、托管后台处理应用程序、处理事件驱动的处理和运行微服务 [2]。
这些步骤将帮助我们训练并部署一个 scikit-learn 模型到 Azure Container Apps。
-
在本地训练模型
-
使用 FastAPI 创建推断 API
-
对应用程序进行 Docker 化
-
部署到 Azure Container Apps
0. 设置
这是以下示例使用的设置。
开发环境
-
Visual Studio Code
-
Azure CLI
-
Python 3.8
-
Docker
-
Python 包:有关
requirements.txt
的信息,请参阅第三部分
项目结构
项目文件夹结构如下:
FastAPI-Azure-Container-Apps
├─ fastapp
│ └─ app.py
├─ model
│ ├─ heart-disease.joblib
│ └─ train.py
├─ data
│ └─ heart-disease.csv
├─ Dockerfile
├─ requirements.txt
└─ .dockerignore
数据集
UCI 心脏病数据集 [3] 是一个公共数据集,包含关于被诊断为心脏病的患者的数据。它包括各种患者特征,如年龄、性别、血压和胆固醇水平。1
和 0
值在 condition
列中分别表示心脏病的存在与否。
1. 在本地训练模型
文件:train.py
为了演示目的,我们将使用仅 5 个特征来训练一个梯度提升分类器。
import pathlib
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.ensemble import GradientBoostingClassifier
from joblib import dump
print ('Load Data')
df = pd.read_csv(pathlib.Path('data/heart-disease.csv'))
y = df.pop('condition')
X = df.loc[:, ['age', 'sex', 'cp', 'trestbps', 'chol']].to_numpy()
print ('Train-Test Split')
X_train, X_test, y_train, y_test = train_test_split(X,y, test_size = 0.2)
print ('Train Model')
gbc = GradientBoostingClassifier()
gbc.fit(X_train, y_train)
print ('Save Model')
dump(gbc, pathlib.Path('model/heart-disease.joblib'))
使用 CLI 运行训练:
python model/train.py
2. 使用 FastAPI 创建预测端点
文件:app.py
from fastapi import FastAPI
from pydantic import BaseModel
import numpy as np
from joblib import load
import pathlib
app = FastAPI(title = 'Heart Disesase Prediction', version = '0.1.0')
model = load(pathlib.Path('model/heart-disease.joblib'))
class ModelData(BaseModel):
age:int=30
sex:int=0
cp:int=2
trestbps:int=115
chol:int=0
class ResponseData(BaseModel):
prediction_result:float=0.1
@app.post('/predict', response_model = ResponseData)
def predict(data:ModelData):
input_data = np.array([v for k,v in data.dict().items()]).reshape(1,-1)
prediction_result = model.predict_proba(input_data)[:,-1]
return {'prediction_result':prediction_result}
在 app.py
文件中我们
-
定义
app
,这是 FastAPI 的一个实例。 -
加载训练好的模型。
-
定义 API 接受的输入数据格式(
ModelData
) -
定义 API 响应格式(
ResponseData
) -
定义
/predict
路由,当对该路由发出 POST 请求时,将触发predict
函数。 -
predict
函数接收来自 POST 请求的输入数据,进行预测并返回患者患有心脏病的概率。
此时我们可以在本地测试 FastAPI 应用程序。--reload
标志有助于加速开发过程。FastAPI 会在检测到代码更改时自动重新加载,这意味着开发人员不需要手动重启 FastAPI 来测试代码更改。
# CLI
uvicorn fastapp.app:app --reload
你将在终端上看到以下消息:
INFO: Uvicorn running on <http://127.0.0.1:8000> (Press CTRL+C to quit)
给定的 URL 会将我们带到 Swagger UI,我们可以在这里测试 API。
3. 使用 Docker 对应用程序进行容器化
创建一个 **requirements.txt**
文件
requirements.txt
文件包含所有所需的 python 包。
fastapi>=0.68.0,<0.69.0
pydantic>=1.8.0,<2.0.0
uvicorn>=0.15.0,<0.16.0
numpy == 1.19.5
scikit-learn==0.23.2
joblib==1.1.0
nest_asyncio == 1.5.5
创建一个 **.dockerignore**
文件
.dockerignore
文件的目的是避免复制那些不用于推断的文件,例如训练脚本和数据。
data/
model/train.py
创建一个 **Dockerfile**
文件
FROM python:3.8
WORKDIR /app
COPY . /app
RUN pip install -r requirements.txt
CMD ["uvicorn", "fastapp.app:app", "--host", "0.0.0.0", "--port", "80"]
这是 Dockerfile 的简要描述:
-
使用
python:3.8
作为基础镜像 -
创建一个名为
/app
的工作目录 -
将项目文件夹中的所有文件复制到工作目录,除了
.dockerignore
文件中列出的文件或子目录。 -
安装
requirements.txt
中列出的 Python 包 -
CMD
在启动 Docker 容器时在 80 端口运行 FastAPI 应用。与本地测试不同,在这里运行 uvicorn 时不包括 --reload 标志。虽然 reload 标志对加快开发过程有帮助,但在生产环境中并不需要。
构建 Docker 镜像
docker build . -t heart-disease-app
启动 Docker 容器
我们将容器中 FastAPI 运行的 80 端口映射到 Docker 主机上的 8080 端口。
docker run -p 8080:80 -it heart-disease-app
测试应用
此时,我们可以通过访问以下 URL 再次通过 Swagger UI 测试应用:[
127.0.0.1:8080/docs](http://127.0.0.1:8080/docs)
4. 部署到 Azure 容器应用
在本节中,我们将把 Docker 镜像推送到 Azure 容器注册表,然后在 Azure 容器应用中部署 Docker 容器。
要将容器化应用部署到 Azure 容器应用,我们需要满足以下先决条件:
-
Azure 资源组
-
Azure 容器注册表
-
Azure 容器应用环境
以下命令将在命令行中执行。
创建资源组
资源组是用于支持应用的 Azure 服务的逻辑分组。我们在 eastus
区域创建了一个 heartdisease_rg
资源组。所有后续资源将分配给 heartdisease_rg
。
az group create --location eastus --name heartdisease_rg
创建 Azure 容器注册表
Azure 容器注册表(ACR)是一个用于存储和管理容器镜像的仓库集合。我们在 heartdisease_rg
资源组下创建了一个名为 heartdisease
的容器注册表,并选择了 Basic
SKU 定价计划。
az acr create --name heartdisease --resource-group heartdisease_rg --sku Basic
一旦容器注册表被配置,我们可以使用以下命令检查 ACR 登录服务器
az acr list --resource-group heartdisease_rg
上述命令返回一个包含登录服务器的长字符串。请注意登录服务器的详细信息,因为我们将在下一步中使用它。
...
"location": "eastus",
"loginServer": "heartdisease.azurecr.io",
"name": "heartdisease"
...
标记 Docker 镜像
要将本地 Docker 镜像推送到 ACR,Docker 镜像 heart-disease-app
被标记为 {login server}/{docker image name}/{version}
格式。
docker tag heart-disease-app heartdisease.azurecr.io/heart-disease-app:v0.1.0
登录 ACR
确保在推送镜像到 ACR 之前已登录。
az acr login -n heartdisease
将 Docker 镜像推送到 ACR
Docker push 是一个将本地 Docker 镜像上传到容器注册表的命令。这将使 Docker 镜像可供 Azure 容器应用使用。
docker push heartdisease.azurecr.io/heart-disease-app:v0.1.0
成功推送 Docker 镜像后,镜像将在 ACR 的 UI 中显示。
图片由作者提供。
创建 Azure 容器应用环境
在创建 Azure 容器应用之前,我们定义了容器应用将运行的 heartdiseaseenv
环境。
az containerapp env create --name heartdiseaseenv --resource-group heartdisease_rg --location eastus
创建 Azure 容器应用
在这一步中,我们创建了 heart-disease-container-app
Azure 容器应用,使用的是前一步中创建的 heartdiseaseenv
环境。我们还定义了应使用的 Docker 镜像:heartdisease.azurecr.io/heart-disease-app:v0.1.0
和容器注册表的登录服务器:heartdisease.azurecr.io
。ingress
设置为 external
,因为此 API 旨在公开发布到互联网。
az containerapp create --name heart-disease-container-app --resource-group heartdisease_rg --environment heartdiseaseenv --image heartdisease.azurecr.io/heart-disease-app:v0.1.0 --target-port 80 --ingress external --query properties.configuration.ingress.fqdn --registry-identity system --registry-server heartdisease.azurecr.io
如果 az containerapp create
命令成功,它将返回一个访问应用程序的 URL。
Container app created. Access your app at <https://heart-disease-container-app.nicehill-f0509673.eastus.azurecontainerapps.io/>
测试应用程序
我们可以使用 Swagger UI、curl 或 python 请求来测试应用程序。要访问 Swagger UI,只需在给定 URL 的末尾添加 docs
:[
heart-disease-container-app.nicehill-f0509673.eastus.azurecontainerapps.io/docs](https://heart-disease-container-app.nicehill-f0509673.eastus.azurecontainerapps.io/docs.)
.
使用 CURL,命令如下:
curl -X 'POST' \\
'<https://heart-disease-container-app.nicehill-f0509673.eastus.azurecontainerapps.io/predict>' \\
-H 'accept: application/json' \\
-H 'Content-Type: application/json' \\
-d '{
"age": 64,
"sex": 1,
"cp": 3,
"trestbps": 120,
"chol": 267
}'
我们也可以通过以下方式使用 Python 的请求库向预测端点 https://heart-disease-container-app.nicehill-f0509673.eastus.azurecontainerapps.io/predict
发送 POST 请求:
import requests
response = requests.post(url = '<https://heart-disease-container-app.nicehill-f0509673.eastus.azurecontainerapps.io/predict>',
json = {"age": 64,
"sex": 1,
"cp": 3,
"trestbps": 120,
"chol": 267
)
print (response.json())
# {'prediction_result': 0.8298846604381431}
结论
在本文中,我们讨论了使用无服务器容器化机器学习推理端点的优势,并演示了如何使用 FastAPI 创建 API 端点,用 Docker 容器化并使用 Azure 容器应用程序部署容器化应用程序的示例。
加入 medium 阅读更多类似的文章!
参考文献
[1] 无服务器计算和应用 | Microsoft Azure
[2] Azure 容器应用概述 | Microsoft Learn
[3] 心脏病数据集来自 UCI 机器学习库。根据 CC BY 4.0 许可协议授权。
工作效率技巧、数据职业见解及其他近期必读内容
·
关注 发表在 Towards Data Science · 作为 时事通讯 发送 · 阅读需 4 分钟 · 2023 年 10 月 26 日
–
数据科学是一个快速发展的领域,新工具不断出现,工作流程不断演变,职业路径也在迅速变化——有时仅在几周之内。
我们最受读者关注和讨论的文章反映了这些趋势,读者纷纷涌向那些由数据和机器学习专业人士撰写的优秀文章,这些专业人士根据他们的实际经验分享了深刻的见解。为了确保你不会错过我们最好的文章,我们很高兴分享一些过去一个月的亮点故事。这些文章涵盖了从编程到 LLM,再到数据讲述的广泛内容,但都专注于可操作的、第一手的建议。请欣赏!
-
编程曾经很难,直到我学会了这两件事 如何从“有志编程者”转变为能够实际竞争优秀编程职位的人?Natassha Selvaraj 的热门文章探讨了培养成长心态和建立日常编程习惯的实际方法。
-
6 种扼杀数据科学生产力的坏习惯 Donato Riccio 指出,提高生产力不仅仅是关于学习和做更多的事情;避免或打破那些对你的工作有害的习惯同样重要。Donato 关注的这些习惯尤其与数据科学家的日常工作流程密切相关。
-
忘掉 RAG,未来是 RAG-Fusion 检索增强生成已经成为优化大型语言模型的常见方法,但它也存在重大缺陷。Adrian H. Raudaschl 提出了 RAG-Fusion,这是一种经过修改的技术,通过将互反排名融合和生成查询纳入过程来解决这些挑战。
图片由 engin akyurt 提供,来源于 Unsplash
-
介绍 KeyLLM — 使用 LLM 进行关键词提取 在提高 LLM 效率的话题上,Maarten Grootendorst 最近分享了 KeyLLM 的发布消息,这是他对 KeyBERT 包的扩展,旨在大规模进行关键词提取。他随后通过一个基于开源 Mistral 7B 模型的示例为我们演示了其使用方法。
-
如何成为数据工程师如果你是一个初级 IT 从业者或中级软件工程师,想要职业转型,💡迈克·沙霍米罗夫的实践指南是一个很好的资源来探索数据工程角色的转变。
-
在远程工作的时代培养新的数据科学家远程和混合工作模式的转变对早期职业数据科学家产生了什么影响?斯蒂芬妮·基尔默提供了对雇主和员工在这一(相对)新领域面临挑战的深思,及他们可以采取什么措施以确保下一代数据专业人士仍能从经验丰富的前辈那里获益。
-
TimesNet:时间序列预测的最新进展跟上时间序列分析领域的最新前沿研究:马尔科·佩谢罗的最新解释聚焦于今年早些时候发布的论文中揭示的 TimesNet。这个模型利用基于 CNN 的架构在不同任务中实现了最先进的结果,“使其成为时间序列分析的基础模型的优秀候选者。”
-
公司今天可以实施的 5 种生成 AI 用例有时存在热议,有时存在实际价值——而在生成 AI 工具方面,商业领导者很难分辨两者的不同。巴尔·摩西为此提供了解救,概述了五个有前景的用例,展示了生成 AI 方法在公司中可能进行实验的意义。
-
Excel 中的互动仪表板如果你在寻找新的创意方式来以引人入胜和易于访问的方式展示你的数据,为什么不尝试一下 Excel?杰克·邱的逐步教程解释了如何充分利用“在‘非技术’世界中最广泛使用的数据工程和分析软件”来创建时尚的互动仪表板。
-
战略数据分析(第一部分) 在她最近推出的系列中,Viyaleta Apgar 提供了一个结构化、详细的概述,介绍了数据分析师需要回答的问题以及他们可以使用的各种有效方法。如果你还没有阅读过,我们推荐从头开始:第一部分概述了数据分析师可能处理的四种基本问题类型。(或者也可以跳过到第二部分,该部分侧重于描述性问题。)
我们最新的一批新作者
每个月,我们都很高兴看到一批新的作者加入 TDS,他们各自以独特的声音、知识和经验与我们的社区分享。如果你在寻找新的作家来探索和关注,只需浏览我们最新加入的作者的作品,包括 Daniel Warfield、Satwiki De、Samuel Montgomery、Alexander Nikitin、Aman Steinberg、Hamed Seyed-allaei、Matheus Cammarosano Hidalgo、Malte Bleeker、Christopher Karg、Akif Mustafa、Gabriel Moreira、Jake Teo、Ilia Teimouri、Jeremie Charlet、Ed Izaguirre、Silvia Onofrei、Markus Stadi、Kairo Morton、Josu Diaz de Arcaya、Deepsha Menghani、Jon Flynn、Lennart Langouche、Guillaume Colley、Angjelin Hila、Emmanouil Karystinaios、Sofia Rosa、Anthony Alcaraz、Kseniia Baidina、Kenneth Ball和 Nicholaus Lawson。
感谢您支持我们作者的工作!如果您喜欢在 TDS 上阅读的文章,可以考虑成为 Medium 会员——这将解锁我们所有的档案(以及 Medium 上的其他所有帖子)。
直到下一个 Variable,
TDS 编辑部
使用 cProfile 对 Python 代码进行性能分析
原文:
towardsdatascience.com/profiling-python-code-with-cprofile-328ae152fdfc
在这篇文章中,我们将探索如何使用 cProfile 模块对 Python 代码进行性能分析
·发表于 Towards Data Science ·阅读时长 6 分钟·2023 年 2 月 10 日
–
图片由 Arnold Francisca 提供,来源于 Unsplash
目录
-
介绍
-
什么是代码性能分析?
-
cProfile 基本使用
-
从终端使用 cProfile 对 Python 代码进行性能分析
-
使用 cProfile 对 Python 代码进行性能分析
-
使用 cProfile 在 Python 中对函数进行性能分析
-
导出 cProfile 数据
-
结论
介绍
现在程序员可以在几天内编写成千上万行代码。新程序和应用程序的复杂性不断演变,代码库包含多个函数,其中一些可能会减慢整个程序的性能。
专注于代码性能分析,特别是对你的 Python 代码进行性能分析以识别性能瓶颈,可以显著提高软件的性能,并改善用户体验。
在本教程中,我们将使用 Python 内置的 cProfile 模块,该模块提供 Python 程序的确定性性能分析。
什么是代码性能分析?
代码性能分析是分析程序性能的过程,特别是分析代码性能以识别潜在的瓶颈。
确定代码中运行缓慢的部分并优化这些代码可以显著提高软件的性能,减少内存使用和资源消耗。
在 Python 中,cProfile 性能分析工具可以跟踪 Python 脚本的执行时间和内存使用情况,帮助识别运行缓慢和高资源消耗的代码部分。
cProfile 基本使用
这是一个使用 cProfile 进行性能分析和解释输出的快速示例:
import cProfile
import re
cProfile.run("re.compile('Python')")
你应该得到:
Python
4 function calls in 0.000 seconds
Ordered by: standard name
ncalls tottime percall cumtime percall filename:lineno(function)
1 0.000 0.000 0.000 0.000 :1()
1 0.000 0.000 0.000 0.000 {built-in method builtins.exec}
1 0.000 0.000 0.000 0.000 {built-in method builtins.print}
1 0.000 0.000 0.000 0.000 {method 'disable' of '_lsprof.Profiler' objects}Now let’s interpret the output:
现在让我们来解释输出结果:
-
ncalls — 调用次数
-
tottime — 在给定函数中花费的总时间
-
percall — tottime 和 ncalls 的比率
-
cumtime — 当前函数及其子函数的累计时间
-
percall — cumtime 和原始调用的比率
-
filename — 每个函数的数据
从终端使用 cProfile 对 Python 代码进行性能分析
在开始 Python 代码性能分析之前,我们需要一些示例 Python 代码来进行测试。
让我们创建一个简单的脚本,它将“Python Programming”打印 5 次,并将其命名为main.py:
i=0
for i in range(5):
print('Python Programming')
i+=1
一旦你运行它,你应该得到:
Python Programming
Python Programming
Python Programming
Python Programming
Python Programming
使用 cProfile 性能分析 Python 代码最简单的方法之一是从终端运行 cProfile。
打开终端,导航到你的 Python 脚本 (main.py) 所在的文件夹,并运行:
python -m cProfile main.py
你应该得到:
Python Programming
Python Programming
Python Programming
Python Programming
Python Programming
8 function calls in 0.001 seconds
Ordered by: standard name
ncalls tottime percall cumtime percall filename:lineno(function)
1 0.000 0.000 0.001 0.001 main.py:1()
1 0.000 0.000 0.001 0.001 {built-in method builtins.exec}
5 0.001 0.000 0.001 0.000 {built-in method builtins.print}
1 0.000 0.000 0.000 0.000 {method 'disable' of '_lsprof.Profiler' objects}
使用 Python 的 cProfile 对 Python 代码进行性能分析
使用 cProfile 直接在 Python 脚本中分析 Python 代码是另一种方法。
你需要将 cProfile 模块导入到 Python 环境中,并显式调用性能分析函数,同时将 Python 代码作为字符串传递给该函数作为参数:
import cProfile
cProfile.run("print('Python Programming')")
你应该得到:
Python Programming
4 function calls in 0.000 seconds
Ordered by: standard name
ncalls tottime percall cumtime percall filename:lineno(function)
1 0.000 0.000 0.000 0.000 :1()
1 0.000 0.000 0.000 0.000 {built-in method builtins.exec}
1 0.000 0.000 0.000 0.000 {built-in method builtins.print}
1 0.000 0.000 0.000 0.000 {method 'disable' of '_lsprof.Profiler' objects}
在这个示例中,我们将代码简化为仅打印“一行‘Python Programming’”,因为如你所见,将代码作为字符串传递给**cProfile.run()**并不是最方便的选项。
应该有更好的方法,对吧?当然!在下一节中,我们将探讨如何使用函数和cProfile对 Python 代码进行性能分析。
使用 Python 的 cProfile 对 Python 函数进行性能分析
让我们重用前面部分的 Python 代码,现在将其放入函数**my_func()**中:
def my_func():
i=0
for i in range(5):
print('Python Programming')
i+=1
现在我们可以通过将该函数作为参数传递给**cProfile.run()**来轻松地对其进行性能分析:
import cProfile
def my_func():
i=0
for i in range(5):
print('Python Programming')
i+=1
cProfile.run('my_func()')
你应该得到:
Python Programming
Python Programming
Python Programming
Python Programming
Python Programming
9 function calls in 0.001 seconds
Ordered by: standard name
ncalls tottime percall cumtime percall filename:lineno(function)
1 0.000 0.000 0.001 0.001 :1()
1 0.000 0.000 0.001 0.001 main.py:3(my_func)
1 0.000 0.000 0.001 0.001 {built-in method builtins.exec}
5 0.001 0.000 0.001 0.000 {built-in method builtins.print}
1 0.000 0.000 0.000 0.000 {method 'disable' of '_lsprof.Profiler' objects}
导出 cProfile 数据
在前面的部分中,我们对 Python 代码进行了性能分析,结果在终端中打印出来。
那我们可以提取并保存分析数据吗?
是的!使用内置的pstats模块和cProfile,我们可以将分析结果提取并保存到一个简单的.txt 文件中。
请注意,profile.run() 功能定义如下:
profile.run(command, filename=None, sort=- 1)
当 filename 设置为 None 时,它会自动打印出性能分析报告。
现在,让我们将 filename 设置为一些示例文件名,如‘results’,并运行上一节的代码:
import cProfile
def my_func():
i=0
for i in range(5):
print('Python Programming')
i+=1
cProfile.run('my_func()', 'results')
你会注意到终端中的输出仅仅是函数应该生成的内容,但没有性能分析报告:
Python Programming
Python Programming
Python Programming
Python Programming
Python Programming
现在,你会在项目目录中创建一个名为‘results’的新文件:
图片由作者提供
我们有了包含分析报告的文件,但我们还不能打开它。现在我们需要使用pstats模块将其转换为.txt 文件,然后才能访问报告。
这段额外的代码将把‘results’文件转换为‘results.txt’:
import pstats
with open('results.txt', 'w') as file:
profile = pstats.Stats('results', stream=file)
profile.print_stats()
file.close()
你应该在目录中看到一个新文件:
作者提供的图片
现在我们已经成功创建了一个包含分析报告的.txt 文件。它应包含以下数据:
作者提供的图片
提取 cProfile 数据的完整代码
import cProfile
import pstats
def my_func():
i=0
for i in range(5):
print('Python Programming')
i+=1
cProfile.run('my_func()', 'results')
with open('results.txt', 'w') as file:
profile = pstats.Stats('results', stream=file)
profile.print_stats()
file.close()
结论
在这篇文章中,我们探讨了如何使用cProfile模块来分析 Python 代码。
代码分析帮助识别代码中的瓶颈,并帮助理解哪些部分的代码应该优化以提高整体性能。
如果您有任何问题或对某些编辑有建议,请随时在下面留言,并查看我的更多Python 编程教程。
最初发表于 https://pyshark.com 2023 年 2 月 10 日。
程序辅助语言模型
原文:
towardsdatascience.com/program-aided-language-models-93d226c7d9a0
大型语言模型(LLMs)可以编写代码,但如果它们能执行程序呢?
·发表于 Towards Data Science ·18 分钟阅读·2023 年 8 月 22 日
–
(照片由 Florian Olivo 提供于 Unsplash)
尽管大型语言模型(LLMs)被用于多种应用,但它们通常在解决基于推理的任务时遇到困难。随着链式思维和从少到多提示等提示技术的出现,这一问题显著减轻了。从高层次来看,这些技术通过在模型的提示中提供问题解决的例子来鼓励 LLMs 进行推理行为。然后,模型可以学会输出这些推理过程,并逐步解决潜在的问题。值得注意的是,这是一种仅依赖提示的方法,无需微调,显示了 LLMs 在给定足够上下文的提示时具备推理能力。
尽管链式思维提示等技术效果显著,但 LLM 预计需要生成问题解决的思维链和最终答案。有趣的是,这种方法可能导致一些特殊的失败情况,其中 LLM 可能生成准确的解决问题的推理,但仍会给出错误的答案。通常,这些错误是由于简单的错误(例如,计算错误)造成的。为了解决这个问题,最近的研究探讨了一种程序化的方法,鼓励 LLM 生成包含自然语言和代码组件的思维链。然后,LLM 可以通过外部解释器运行这些代码,以获得所需的输出。
为了理解这种方法为什么有用,我们应该注意到许多 LLM 难以解决的问题(例如,算术错误、无法评估复杂表达式等)可以在程序内部轻松表达和解决。因此,使用链式思维风格的提示在具有编码能力的 LLM(例如,Codex)上,可以将 LLM 的优势与任意 Python 程序的计算能力相结合!更具体地说,LLM 可以被鼓励生成包含自然语言和代码组件的问题解决理由,生成一个可以由外部解释器运行的脚本,以计算问题的最终输出。我们将在本概述中探讨这种方法,这对 LLM 在解决推理任务中的准确性和可靠性大有裨益。
(来自 [1, 2])
背景信息
预训练一个大语言模型(LLM)
尽管现代大语言模型具有令人难以置信的能力,但这些模型都基于一个简单的预训练程序,该程序对大量未标记的文本数据执行下一个词预测。虽然我们可以调整这个过程的细节(例如,数据的类型或混合),但大多数 LLM 的基本预训练方法保持不变。我们只需* i)* 从预训练语料库中采样一些文本,* ii)* 教会模型准确预测语料库中的下一个词/标记。就是这样! 这个简单而深刻的方法为现代语言建模奠定了基础。
但……还有一些从多年的研究中学到的技巧和经验,使我们能够让语言模型变得像ChatGPT或GPT-4一样强大。大多数模型使用相同的仅解码器架构,但仅仅通过预训练无法创建高性能的语言模型。我们需要:
-
足够的规模(即,大模型和预训练数据集)。
-
通过监督微调(SFT)和来自人类反馈的强化学习(RLHF)进行行为调整[11, 12]。
-
[可选] 领域专业化(即,在特定类型的数据上微调模型,例如代码或对话)。
如果我们正确执行所有这些步骤,我们可以创建一个强大的基础模型,它能够通过文本提示解决各种任务。值得注意的是,大多数语言模型的知识和信息都是通过预训练获得的(见“训练过程”部分 这里),但在预训练之后进行的这些额外的精细化步骤使 LLMs 变得更具可引导性和更有趣;见下文。
(来自 [11])
LLMs 在什么方面表现不佳? 语言模型在各种不同的应用中取得了令人印象深刻的表现,但它们并不完美。这些模型有已知的局限性,例如:
-
添加大数的困难
-
无法评估/解决复杂方程
-
对迭代过程的推理困难
例如,如果我们给大型语言模型(LLM)提供一个关于Fibonacci 数列的描述,然后要求它计算第 100 个数字,那么它很可能会失败!为什么会这样? 好吧,我们知道 LLMs 在进行算术运算时表现不佳,而解决 Fibonacci 数列(除非模型使用暴力记忆)需要在两个数字之间进行多次迭代加法。如果模型在每次迭代中有 95%的概率正确执行加法,那么第 100 个 Fibonacci 数正确的概率不到 1%!
快速免责声明。 最近发布的 GPT-4 使得关于 LLM 局限性的声明变得更加困难。例如,GPT-4 完全能够解决第 100 个 Fibonacci 数,甚至可以在最小提示努力下评估一些(相对)复杂的方程;见下文。
(来自 ChatGPT Plus)
鉴于此,对 LLM 能力的任何声明都需要保持一定的怀疑态度。这个领域迅速发展,模型每天都变得越来越强大和令人印象深刻(字面意义上)。
教授 LLMs 如何编码
如上所述,创建高性能 LLM 的一个(可选)部分是领域专业化。在预训练之后,LLMs 相当通用,仅能完成单一任务——下一个标记预测。如果我们想要一个在某个特定领域专业化或擅长执行特定任务(例如,信息检索对话或编写剧本)的 LLM,我们需要在大量展示该任务正确行为的数据上进行微调。这个技术的一个最成功的应用,特别是与此概述相关,是创建可以编写代码的语言模型。
(来自 [4])
类似于如何从互联网下载大量文本数据用于预训练语言模型,我们可以从公共来源(例如 GitHub)下载大量代码用于训练 LLMs,这使得编码成为专门化 LLMs 的一个特别完美的应用。例如,Codex [4]就是一个显著的模型,它使用从互联网下载的未标记文本数据和代码的组合进行训练。给定一个 Python 文档字符串,Codex 的任务是生成一个有效的 Python 函数,以执行文档字符串中概述的任务;见上文。
(来自 [4])
Codex 在人工策划的编码任务上表现极佳(见上文),甚至被用来驱动GitHub Copilot编码助手,揭示了 LLMs 不仅仅可以应用于自然语言!我们也可以将它们应用于许多其他具有类似结构的问题。在这种情况下,我们使用进一步的语言模型预训练来适应预训练 LLM 到新领域。值得注意的是,Codex 能够生成代码和自然语言输出,使其成为一个特别多用途且有用的 LLM。而且,创建这种领域特定的模型相对简单——我们只需要大量的代码进行训练。
思维链 (CoT) 提示
超越之前概述的限制,LLMs 最初因无法解决推理任务而受到批评。然而,该领域的研究带来了突破性的技术,如CoT 提示 [3],使 LLMs 能够相当准确地解决基于推理的任务。CoT 提示的理念很简单。我们只需使用少量样本学习来教 LLM 如何输出详细解释其答案的解决方案——适用于任何推理任务;见下文。这种方法极其实用,因为我们只需要生成少量解决方案示例来包含在提示中,而之前的工作则编纂了整个数据集用于微调。
(来自 [3])
与教 LLM 如何编码不同,我们通过 CoT 提示发现,这些模型能够在无需任何微调的情况下解决推理任务!相反,我们只需采用一种更好的提示方法来“解锁”LLM 解决复杂推理任务的能力。
“大型预训练语言模型具备内置的推理能力,但它们需要特定的提示才能释放其威力。” — 来自 [13]
鉴于我们在之前的概述中已经了解了很多关于 CoT 提示及其许多变体,我不会在这里深入探讨这个概念。然而,有一个显著的方面我们应该注意——语言模型被期望同时i) 生成思维链和ii) 从这个思维链中提取最终答案。尽管 CoT 提示是有效的,但我们可能会开始怀疑:依赖语言模型准确解决这两个步骤是否真的是一个好主意?
在语言模型中解耦推理和计算
我们知道,语言模型(在正确的提示方法下)能够提供准确的问题解决理由或详细的输出解释。然而,生成正确的理由并不意味着语言模型会正确解决问题!如果语言模型在给出最终答案时出现一个小的算术错误怎么办? 由于语言模型的基本局限性,像 CoT 提示这样的技术通常会遇到令人沮丧的失败案例,其中模型生成了准确的理由,但输出了错误的最终答案。这类错误通常被称为语言模型的组合性差距。
“我们衡量模型可以正确回答所有子问题但未能生成整体解决方案的频率,这一比率称为组合性差距。” —— 引自 [16]
在本节中,我们将探讨最近的研究,这些研究尝试通过利用已在代码上进行训练的语言模型的独特技能(例如,Codex [4])来编写连贯且功能性强的程序来解决这个问题。我们可以依靠语言模型生成问题解决的理由。但是,我们不是要求语言模型给出实际的答案,而是提示模型生成一个与理由相关的程序,这个程序在使用单独的代码解释器执行时,可以生成最终答案。因此,我们的理由变成了代码和语言的混合体——基本上是一个带有说明性评论的 Python 脚本!
程序辅助语言模型(PaL)
(来自 [1])
在 [1] 中,作者提出了一种受 CoT 启发的技术,称为程序辅助语言模型(PaL),该技术使用语言模型将基于推理的问题分解为逐步的问题解决理由。然而,这种理由包含了自然语言和(基于 Python 的)编程组件。生成这种混合理由后,我们可以通过 Python 解释器执行程序化部分来解决问题。这种方法的目标是消除语言模型生成正确推理链但仍产生错误最终答案的情况。
“这弥合了链式思维方法中的一个重要差距,即推理链可能是正确的,但产生了错误的答案。” —— 引自 [1]
使用 PaL,我们可以利用 LLM 生成解决问题的推理,但计算最终解决方案的过程(即,模型通常在这一部分挣扎的地方!)被委托给代码解释器,从而消除了算术或逻辑错误的潜在可能性。因此,LLM 只需学习如何生成解决问题的推理——解决方案是程序化得出的。我们可以通过少量学习教导 LLM 生成这种混合推理。然而,为了实现这一点,我们需要一个在自然语言和代码上都经过预训练的 LLM(例如,Codex [4])。
理解 PaL。 从高层次来看,PaL 采用的方法与 CoT 提示非常相似。我们使用一种少量提示的方法,提供几个将问题分解为相关推理的示例。CoT 和 PaL 之间的主要区别在于,PaL 使用的推理是由交错的自然语言和程序语句组成的;见下文。
(来自 [1])
PaL 中的每一步推理过程都附加了程序语句。然后,当这些程序语句被综合时,它们可以通过单独的 Python 解释器执行,以生成最终答案(即,通过单次、事后的执行完成)。PaL 正在通过少量学习教导 LLM 生成一个逐步解决所需问题的程序。有趣的是,[1]中的作者鼓励 LLM 通过利用 Python 注释语法(即 #
字符)生成基于自然语言的中间步骤,这使得语言组件能够插入到生成的程序中。换句话说,我们正在教导 LLM 通过逐步的程序和信息性注释来解决推理任务!
(来自 [1])
与 CoT 提示不同,PaL 使用的少量示例不包含最终解决方案。相反,示例仅仅是交错了自然语言语句的程序(没有其他东西!)。最终解决方案的生成委托给 Python 解释器,因此 LLM 不需要学习如何执行这一步骤;见上文。
更进一步,[1]中的作者观察到,为程序中使用的变量提供有意义的名称是有益的。这一发现表明,PaL 提出的推理过程是一种真正的混合方法,它融合了语言和程序组件。在编程和语言模式之间形成符号链接是重要的;见下文。
(来自 [1])
这效果好吗? PaL 在各种符号、数学和算法推理任务中进行了评估,结果显示它能够减轻许多与 CoT 提示相关的常见问题。所提出的方法与标准的 少量学习(在 [1] 中称为“直接”提示)以及 CoT 提示进行了比较。在数学推理任务中,PaL 与 Codex [4] 相结合,轻松超越了之前的提示方法,适用于各种不同的模型。值得注意的是,PaL 甚至超越了 Minerva [5],这是一种专门针对大量定量推理数据进行微调的 LLM;见下文。
(来自 [1])
从上表中,我们还应注意到,使用 Codex 的 PaL 在 GSM8K 上达到了最先进的性能,超越了 PaLM-540B(即更大的模型!)的 CoT 性能 15% 的绝对 top-1 准确率。有趣的是,[1] 中的作者指出,GSM8K
主要集中在较小数字的数学词题上(即,50% 的数字在 0-8 之间),并提出了 GSM-Hard
——这是一个包含更大数字的数据集版本。在更困难的数据集上,PaL 相较于 PaLM 的 CoT 提示在绝对 top-1 准确率上提高了近 40%,揭示了程序辅助提示对于需要复杂算术运算的大数字问题更具优势。
(来自 [1])
在符号和算法推理任务中,PaL 再次提供了显著的好处;见上文。实际上,PaL 在这个类别中的五个数据集中接近完全解决四个,达到了 >90% 的准确率。此外,PaL 似乎随着问题复杂性的增加而保持一致的表现;见下文。在这里,我们可以看到,大量数字或推理任务中的更多对象所带来的复杂性在程序处理上很简单,尽管直接用 LLM 处理这样的复杂性可能会引发问题。
(来自 [1])
思维程序(PoT)提示
如前所述,CoT 提示的推理过程有两个不同的步骤:
-
生成基于语言(或程序)的解决方案 rationale
-
根据这一 rationale 计算最终答案
LLM 擅长执行上述第一步,但可能在计算最终答案时遇到困难。通常,这个问题是由于算术错误或无法评估复杂表达式所致。简单来说,LLM 在处理复杂的数字任务时会遇到困难。在 [2] 中,作者旨在利用一种称为思维程序(PoT)提示的代码增强提示方法来缓解这个问题,并使 LLM 能够准确地解决复杂的数字任务。
“在 PoT 中,计算可以委托给程序解释器,该解释器用于执行生成的程序,从而将复杂的计算与推理和语言理解解耦。” — 摘自 [2]
正如我们所猜测的那样,PoT 提示与 PaL 非常相似。这两种技术都使用代码增强的提示技术来解决复杂的推理任务,并将推理过程的必要部分委托给代码解释器。更具体地说,PoT 提示利用基于代码的 LLM 的少量学习(例如,Codex [4])生成包含自然语言声明和代码(用 Python 编写)的混合推理。然后,将输出的代码部分卸载到解释器进行评估,从而将推理和计算解耦。
(摘自 [2])
相比之下,CoT 提示直接在 LLM 上进行推理和计算。这是一个问题,因为 LLM 在以下方面存在困难:
-
执行基本的算术运算(特别是大数运算)
-
评估复杂的数学表达式(例如,多项式或微分方程)
-
解决需要迭代的问题
这些问题通过上图展示,其中一个使用 CoT 提示的 LLM 无法评估一个简单的立方方程或在斐波那契序列的迭代计算中进行推理。幸运的是,我们可以用程序轻松解决这些问题!例如,我们可以使用 for 循环计算斐波那契序列,而立方方程可以轻松地用 Python 语法表示。然后,我们可以运行这个程序来生成正确的输出,从而消除对 LLM 的不必要依赖。
(摘自 [2])
PoT 的详细信息。 类似于 PaL,PoT 提示生成的解决问题的推理包含语言和代码组件。LLM 通过一系列包含问题对及相关“思维程序”(即包含解释计算的多步骤程序和自然语言声明)的少量示例来学习生成这种推理;见上文。
使用 SymPy 的符号数学编程示例
与 PaL 不同,PoT 编写的代码依赖于一个名为SymPy的符号数学库。这个包允许用户定义数学“符号”,然后将它们组合在一起形成复杂的表达式。为了评估这些表达式,我们可以将它们传递到 SymPy 的solve
函数中;见上文。更多细节,请查看这里的教程。尽管使用了符号数学,PoT 提示与尝试直接生成数学方程的 LLM 不同,之前的工作表明这非常困难[3]。这是因为 PoT 提示:
-
通过一个多步骤、基于理由的过程生成符号方程。
-
将符号变量与语义上有意义的名称关联起来。
与 PaL 类似,[2]中的作者指出,为程序中的变量分配有意义的名称确实对 LLM 的性能产生了可测量的影响。
结果。 PoT 提示在多个数学文字问题和金融问答数据集(例如,FinQA [8]和 ConvFinQA [9])上使用 Codex [4]和 GPT-3 [7]进行了评估。使用少量学习和 CoT 提示(包括一个可以访问外部计算器的 CoT 提示变体)的多个不同 LLM 被用作基准。如下表所示,PoT 提示在所有情况下都显著优于基准,这强调了将推理与计算分离的价值。
(来自[2])
有趣的是,[2]中的作者还发现零-shot PoT 提示(即类似于零-shot CoT 提示 [10])效果非常好。即使没有为 LLM 策划几个程序融入的理由,我们也可以通过 PoT 提示在数值任务中实现合理的性能。此外,作者对使用 PoT 提示提出了一个有趣的实际注意事项。为了避免生成完全基于语言的理由(即一个带有所有注释的程序),他们不得不手动抑制#
符号的概率。虽然这是一个小细节,但值得记住——我们不希望生成的程序仅仅是注释!此外,这也表明,使这种技术在实践中有效往往是脆弱和困难的。
我们能做得更好吗?
(来自[14])
PaL 和 PoT 在大多数实验中采用了贪婪解码策略,这意味着 LLM 会通过迭代选择下一个概率最高的 token 来生成输出序列。然而,我们可以使用各种更好的解码策略!一个值得注意的(且超级简单的)策略是自一致性[14]。该技术使用相同的 LLM 和提示来为一个问题生成多个不同的输出。然后,通过对生成的所有输出进行多数投票来得出最终答案;见上文。
(来自 [2])
当将自一致性应用于 PoT 提示时,我们会看到立即且显著的好处!如上所示,带有自一致性的 PoT 在几乎所有考虑的数据集中都达到了新的最先进性能。同样,PaL [1] 也受益于自一致性的使用,甚至用于探索更复杂的解码/提示策略,如最少到最多提示 [15](即 CoT 提示的一种变体,显式地逐步解决推理任务)。与这种更复杂的提示风格结合时,PaL 变得更加有效;见下文。
(来自 [1])
尽管 PaL 和 PoT 表现得相当好,我们可以通过对其提示技术进行一些易于实现的补充来使其更上一层楼。这些发现激发了进一步的实验。也许我们可以通过利用其他有用的技术,如提示集成,来获得额外的性能提升。
主要结论
尽管 LLM 本身非常有用,但在本综述中我们看到,当 LLM 能够访问有用的工具时,它们可以变得更加出色。特别是,我们了解到将 LLM 连接到外部代码解释器对于推理任务的性能极其有利。然而,为了使其效果良好,我们需要访问能够编写代码的 LLM。以下是一些主要结论。
为什么会有效? PaL 和 PoT 的有效性源于 LLM 能够生成准确的解决问题的理由,但往往在简单任务如算术和迭代中遇到困难。幸运的是,这些概念可以轻松地在程序中建模,使得将 LLM 连接到外部代码解释器成为一种直观且强大的解决推理问题的技术。简而言之,我们通过依赖 LLM 擅长的领域,并将剩余的解决问题的组件委托给能够更可靠地生成解决方案的代码解释器,获得了很多收益。
我们应该如何解决 LLM 的弱点? 正如这篇文章简要提到的那样,许多已知的 LLM 缺点正随着更强大的模型(如 GPT-4)的发布而得到解决。然而,我们在这份概述中看到,解决这些问题的替代方法可能更加可靠。特别是,依靠外部代码解释器可以解决由于 LLM 在解决基于推理的任务时所遇到的局限性问题。赋予模型执行代码的能力无疑扩大了其能力范围,这激发我们思考其他可能对 LLM 有用的 工具。
将思想表达为程序。 这项工作真正突显了程序可以被解释为表达个人思想的结构化语言这一事实。与自然语言相比,编程语言的约束更多,这使得它们能够轻松表达迭代、建模复杂方程等。然而,程序的形式化性质也限制了表达能力——用自然语言写诗要比在 Python 脚本中写诗容易得多(假设没有调用 GPT-4 API)! 在我看来,考虑自然语言和代码之间的差异是相当有趣的。我们在这里看到,将它们结合在一起可以发挥两者的优势。
结束语
非常感谢你阅读这篇文章。我是 Cameron R. Wolfe,Rebuy 的人工智能总监。我研究深度学习的实证和理论基础。你还可以查看我在 medium 上的 其他写作!如果你喜欢这篇文章,请在 twitter 上关注我,或者订阅我的 Deep (Learning) Focus newsletter,在这里我通过对流行论文的易懂概述帮助读者建立对人工智能研究主题的更深入理解。
参考文献
[1] Gao, Luyu 等. “PAL:程序辅助语言模型。” arXiv 预印本 arXiv:2211.10435 (2022)。
[2] Chen, Wenhu 等. “思维程序提示:将计算与推理解耦以进行数值推理任务。” arXiv 预印本 arXiv:2211.12588 (2022)。
[3] Wei, Jason 等. “思维链提示引发大语言模型的推理。” arXiv 预印本 arXiv:2201.11903 (2022)。
[4] Chen, Mark 等. “评估训练有素的代码大型语言模型。” arXiv 预印本 arXiv:2107.03374 (2021)。
[5] Lewkowycz, Aitor 等. “用语言模型解决定量推理问题。” arXiv 预印本 arXiv:2206.14858 (2022)。
[6] Chen, Wenhu. “大型语言模型是少量(1)次表格推理者。” arXiv 预印本 arXiv:2210.06710 (2022)。
[7] Brown, Tom 等. “语言模型是少量学习者。” 神经信息处理系统进展 33 (2020): 1877–1901。
[8] 陈志宇 等. “Finqa:一个关于金融数据的数字推理数据集。” arXiv 预印本 arXiv:2109.00122 (2021)。
[9] 陈志宇 等. “Convfinqa:探索对话金融问答中的数字推理链。” arXiv 预印本 arXiv:2210.03849 (2022)。
[10] 小岛武志 等. “大型语言模型是零-shot 推理器。” arXiv 预印本 arXiv:2205.11916 (2022)。
[11] 欧阳龙 等. “通过人工反馈训练语言模型以遵循指令。” 神经信息处理系统进展 35 (2022): 27730–27744。
[12] 托皮兰·罗马尔 等. “Lamda:用于对话应用的语言模型。” arXiv 预印本 arXiv:2201.08239 (2022)。
[13] 李一飞 等. “关于提升语言模型推理能力的进展。” arXiv 预印本 arXiv:2206.02336 (2022)。
[14] 王学智 等. “自我一致性提高语言模型中的思维链推理。” arXiv 预印本 arXiv:2203.11171 (2022)。
[15] 周登嵘 等. “从最少到最多的提示使大型语言模型能够进行复杂推理。” arXiv 预印本 arXiv:2205.10625 (2022)。
[16] 普雷斯·奥菲尔 等. “测量并缩小语言模型中的组合性差距。” arXiv 预印本 arXiv:2210.03350 (2022)。
通过 Go 和 Metal Shading Language 编程苹果 GPU
研究 Go、Cgo、Metal Shading Language、Metal Performance Shaders,以及对矩阵乘法的不同方法进行基准测试
·
关注 发表在 Towards Data Science ·12 分钟阅读·2023 年 12 月 4 日
–
图片由 Etienne Martin 在 Unsplash 上拍摄
下面我将描述如何使用 cgo 在 Go 和原生 C 之间进行接口,如何将其用于与 Apple 的 Metal Performance Shaders 框架的 Objective-C 绑定进行接口,如何与用 Metal 着色语言 编写的 自定义 GPU 代码(着色器)进行接口,最后将所有这些与手写的和 OpenBLAS 的基于 Go 的矩阵乘法操作进行基准测试。这是为我的 M2 MacBook 编写的。
源代码的布局,在 GitHub 上可用,如下所示:
高级源代码、库和设备布局
这量很大,所以我将其分解为这些部分,或者可以直接跳到 基准测试。
-
GPU 和浮点并行性
-
Metal GPU 基础
-
Metal 着色语言
-
Objective-C 绑定
-
Metal Performance Shaders 框架
-
Go 和 cgo
-
Go 实现基线和 OpenBLAS
-
结果
GPU 和浮点并行性
我假设大多数人此时直观上已经熟悉 GPU 在某些计算任务中的强大性能,尤其是一些支持机器学习的任务。直到我开始尝试 Metal,我才亲身理解它们比 CPU 强大 多少。
GPU 设计上极其高效于大规模并行浮点运算,这要求高内存带宽。我的 MacBook M2 有 8 个 CPU 核心和 8 个 GPU 核心,但为了对比,Nvidia RTX 4090 拥有 16384 核心,而 H100 拥有 16896 CUDA 核心及数百个额外的专用张量核心。GPU 通常支持 SIMD 处理,使其能够在多个数据点上同时执行相同的指令。
除了图形处理外,矩阵乘法和线性代数任务一般都受益于这种并发,得益于其高度可并行的算法。这反过来支持了核心机器学习工作负载,如训练和推断 [1] [2]。
CUDA 可能是最著名的 GPU 编程平台,专门针对 Nvidia 硬件。也有数学框架可用于 OpenGL。像 TensorFlow 和 PyTorch 这样的框架可以与 GPU 硬件 轻松集成,且透明度相当高。这篇文章 对将支持 Metal 的 GPU 框架集成到 spaCy NLP 库 中的性能提升做了有趣的分析。
Metal GPU 基础知识
直接编程 GPU 计算并不像编写设备 CPU 代码那样简单。当使用 Apple 的 Metal 框架时,执行 GPU 上代码的大致操作步骤如下:
-
寻找适当的 GPU 设备
-
创建一个用于执行命令的队列(即 MTLCommandQueue)
-
将数据数组的指针封装到结构化缓冲区中;如果数据是可执行代码,则使用 管道状态,否则使用 常规缓冲区。Apple 的 GPU 使用 统一内存空间,这意味着我们不需要实际 复制 任何数据到特定于 GPU 的物理内存中。
-
提交命令缓冲区以进行执行,并等待结果或在完成时设置事件处理程序
-
从响应缓冲区中提取字节,并使用 CPU 程序代码在本地格式化
原始 GPU 编程使用异步模型。
Metal 着色语言
Metal 着色语言 是 C++14 的一种衍生语言,可用于编写自定义逻辑(称为“着色器”),以在兼容 Metal 的 GPU 上运行。一般来说,如果可能的话,使用 MPS 框架(稍后讨论)来实现等效功能可能更好——它通常针对常见的 GPU 对齐用例(如矩阵乘法或 神经网络)进行了高度优化。
MSL 代码的调试相当困难。你可以通过 Xcode 使用着色器调试器,但如果你想在没有 Xcode 的情况下检查或打印中间值,你需要将数据写入响应调试缓冲区,并在你的 C++或 Objective-C 包装器中解析这些原语。
MSL 函数通过kernel
标识公开为公共接口。Metal 框架传递当前调用线程上下文或线程组的 ID,这些 ID 可以用来确保非重叠写入。线程可以通过三维 ID 系统表示;这个线程空间的维度在包装器代码中配置。
以下是原始矩阵乘法算法的实现,结合了一些循环展开,令人惊讶地显著提高了性能。这只是为了比较;通常,MPS 的MPSMatrixMultiplication
功能会更合适。
kernel void matrix_multiply_naive(
device const MatrixParams *params,
constant float *A,
constant float *B,
device float *C,
// Indicates the thread's unique position within the entire grid of
// threads being executed. The uint2 type is a 2D coordinate, with
// fields x and y representing its indices on each axis.
// This parameter is not directly provided from the calling code,
// but provided by the Metal framework
uint2 gid [[thread_position_in_grid]]
) {
if (gid.x >= params->a_rows || gid.y >= params->b_cols) {
return; // This thread is out of matrix dimensionality range, do nothing
}
float sum = 0.0;
int k;
// Loop unrolling; improves performance by a notable margin
for (k = 0; k <= params->a_cols - 4; k += 4) {
sum += A[gid.x * params->a_cols + k]
* B[k * params->b_cols + gid.y];
sum += A[gid.x * params->a_cols + k + 1]
* B[(k + 1) * params->b_cols + gid.y];
sum += A[gid.x * params->a_cols + k + 2]
* B[(k + 2) * params->b_cols + gid.y];
sum += A[gid.x * params->a_cols + k + 3]
* B[(k + 3) * params->b_cols + gid.y];
}
// Handle any remaining elements
for (; k < params->a_cols; ++k) {
sum += A[gid.x * params->a_cols + k] * B[k * params->b_cols + gid.y];
}
C[gid.x * params->b_cols + gid.y] = sum;
}
我还在 MSL 中实现了一个原始转置函数以供比较。给定一个转置矩阵,这是对上述逻辑的一个微不足道的调整,其内部循环遍历 B 的行而不是列:
// Loop unrolling; improves performance by a notable margin
for (k = 0; k <= params->a_cols - 4; k += 4) {
sum += A[gid.x * params->a_cols + k]
* B[gid.y * params->b_cols + k]; // Note this is gid.y * cols plus k
sum += A[gid.x * params->a_cols + k + 1]
* B[gid.y * params->b_cols + k + 1];
sum += A[gid.x * params->a_cols + k + 2]
* B[gid.y * params->b_cols + k + 2];
sum += A[gid.x * params->a_cols + k + 3]
* B[gid.y * params->b_cols + k + 3];
}
// Handle any remaining elements
for (; k < params->a_cols; ++k) {
sum += A[gid.x * params->a_cols + k] * B[gid.y * params->b_cols + k];
}
我在早期的博客文章中讨论了这种方法,这是一种相当简单的方法,可以提高原始算法的标量性能,至少在 CPU 上是如此。更多内容稍后会讨论。
Objective-C 绑定
Metal 框架提供了从 Metal 源代码编译库的能力。一旦文件内容被加载,绑定代码会按名称查找内核函数,并初始化一个新的MTLComputePipelineState
,表示编译后的函数代码。
id<MTLDevice> device = MTLCreateSystemDefaultDevice();
// Compile and initialize a new library located at the provided source path.
MTLCompileOptions *compileOptions = [MTLCompileOptions new];
compileOptions.languageVersion = MTLLanguageVersion3_0;
// Wrap input source path string
NSString *ss = [NSString stringWithUTF8String:source_path];
// Initialize new library containing compiled shader functions
id<MTLLibrary> lib = [device newLibraryWithSource:ss
options:compileOptions
error:&error];
// Create a representation of the naive multiplication public shader function in
// the Metal library created above
id<MTLFunction> naiveFunction =
[lib newFunctionWithName:@"matrix_multiply_naive"];
// Create the new compute pipeline state
id<MTLComputePipelineState> pipelineStateNaive = [device newComputePipelineStateWithFunction:naiveFunction
error:&error];
为了实际调用原生 Metal 代码,线程配置需要设置,并初始化 GPU 缓冲区。
[computeEncoder setComputePipelineState:pipelineStateNaive];
MTLSize threadsPerGrid = MTLSizeMake(params->a_cols, params->a_rows, 1);
// Calculate a threadgroup size.
// https://developer.apple.com/documentation/metal/calculating_threadgroup_and_grid_sizes?language=objc
NSUInteger w = pipelineStateNaive.threadExecutionWidth;
NSUInteger h = pipelineStateNaive.maxTotalThreadsPerThreadgroup / w;
MTLSize threadsPerThreadgroup = MTLSizeMake(w, h, 1);
// Encode kernel function inputs
[computeEncoder setBytes:params length:16 atIndex:0];
[computeEncoder setBuffer:bufferA offset:0 atIndex:1];
[computeEncoder setBuffer:bufferB offset:0 atIndex:2];
[computeEncoder setBuffer:bufferC offset:0 atIndex:3];
// Encode the compute command.
[computeEncoder dispatchThreads:threadsPerGrid
threadsPerThreadgroup:threadsPerThreadgroup];
// End the compute pass.
[computeEncoder endEncoding];
// Execute the command.
[commandBuffer commit];
这内容比较多,我在这里阐明一下关系:
Objective-C 包装器中的概念、类型和硬件的高级布局
Metal 性能着色器框架
MPS 框架是苹果公司提供的高性能库,用于其Metal GPU 系列。它提供从图像任务到神经网络支持的功能。
API 主要通过 Swift 或 Objective-C 提供,尽管也有一个 Metal-cpp 库可供使用。
MPSMatrixMultiplication API 相对容易使用。与上述 MSL 代码一样,MPS 命令仍需编码到 MTLCommandBuffer
中,并异步提交执行。
// Define Matrix "descriptions", accounting for matrix dimensionality and byte size
MPSMatrixDescriptor *descriptorA = [MPSMatrixDescriptor matrixDescriptorWithDimensions:a_rows
columns:a_cols
rowBytes:a_cols * sizeof(float)
dataType:MPSDataTypeFloat32];
MPSMatrixDescriptor *descriptorB = [MPSMatrixDescriptor matrixDescriptorWithDimensions:b_rows
columns:b_cols
rowBytes:b_cols * sizeof(float)
dataType:MPSDataTypeFloat32];
// Output matrix
MPSMatrixDescriptor *descriptorC = [MPSMatrixDescriptor matrixDescriptorWithDimensions:a_rows
columns:b_cols
rowBytes:b_cols * sizeof(float)
dataType:MPSDataTypeFloat32];
// Initialize matrix representations using above descriptions and matrix buffers
MPSMatrix *matrixA = [[MPSMatrix alloc] initWithBuffer:bufferA descriptor:descriptorA];
MPSMatrix *matrixB = [[MPSMatrix alloc] initWithBuffer:bufferB descriptor:descriptorB];
MPSMatrix *matrixC = [[MPSMatrix alloc] initWithBuffer:bufferC descriptor:descriptorC];
// Creates the multiplication instance
MPSMatrixMultiplication *matrixMultiplication = [[MPSMatrixMultiplication alloc] initWithDevice:device
resultRows:a_rows
resultColumns:b_cols
interiorColumns:a_cols];
// Encodes the multiplication command into the command buffer for the GPU
id<MTLCommandBuffer> commandBuffer = [commandQueue commandBuffer];
[matrixMultiplication encodeToCommandBuffer:commandBuffer
leftMatrix:matrixA
rightMatrix:matrixB
resultMatrix:matrixC];
Go 和 cgo
我不特别喜欢使用 Objective-C,这个程序的重点是运行源自 Go 程序的 GPU 代码。
Cgo 是一种 Go 语言功能,允许 Go 编译器理解与本地 C 代码相关的注释中的编译指令。它支持一种 外部函数接口。
指令配置有点脆弱,但任何紧接着 import “C”
行的注释(称为“前言”)在编译引用的 C 代码时将被解释为头文件导入或编译参数。例如:
/*
#cgo LDFLAGS: -framework Foundation -framework CoreGraphics -framework Metal -framework MetalPerformanceShaders -L/opt/homebrew/opt/openblas/lib -lopenblas
#include <stdlib.h>
#include "metal.h"
*/
import "C"
-
通过命令行
LDFLAGS
传递链接标志给链接器 -
使用标准头文件
stdlib.h
编译 C 代码 -
使用本地项目头文件
metal.h
编译 C 代码
需要一些反复试验才能找到适用于 MacOS 的正确链接器标志。
-
Foundation
:基础库 -
CoreGraphics
:在 MacOS 上与 GPU 接口时必需 -
Metal
:用于 Metal 的库和语言支持,包括 MSL -
MetalPerformanceShaders
:上述讨论的 MPS 库
事实证明,Apple 在其 Accelerate
框架中捆绑了一个 BLAS 实现,因此除了通过 brew
安装 OpenBLAS 外,还需要在链接时提供库的位置:
-L/opt/homebrew/opt/openblas/lib -lopenblas
go:embed
指令允许 Go 程序在编译时包含文件,这在我们希望将 MSL 源文件(mm.metal
)的内容传递给 Metal 框架时是 非常有用的,如上所述,用于编译。
//go:embed mm.metal
var source string
// Compile the shader source code and initialize pipelines. The metalSource
// param contains the contents of an embedded Metal Shading Language file.
func Compile (metalSource string) {
// Wrap string in a C string
src := C.CString(metalSource)
// Free the above string after command queue is initialized
defer C.free(unsafe.Pointer(src))
// Compile the source, initialize pipelines and command queue
C.initializePipelineAndCommandQueue(src)
}
上述对 C
的引用是通过 cgo 与 C API 接口,例如:
// Calls initializeMTLBuffers from Obj-C bindings
C.initializeMTLBuffers(
a_data, // Input opaque pointer for A
b_data, // Input opaque pointer for B
C.int(4), // Converts 4 into C integer type
C.int(a.Size()),
C.int(b.Size()),
C.int(a.Rows * b.Cols))
params := MatrixParams{
a_rows: int32(a.Rows),
a_cols: int32(a.Cols),
b_rows: int32(b.Rows),
b_cols: int32(b.Cols),
}
// Return an unsafe pointer to this MatrixParams struct, cast to
// the native C representation defined in the shared header file
return (*C.MatrixParams)(unsafe.Pointer(¶ms));
注意,这意味着 C
是一个保留关键字,不能用作变量名。
Go 实现基线和 OpenBLAS
我想将基于 GPU 的矩阵乘法性能与更高级的实现(如 Gonum library)以及直观的手写(且相对低效)实现进行比较。
我在 Go 中实现了几种不同的算法,包括这个并行转置的简单算法,它将乘法工作天真地划分到 N 个 goroutine 中:
func (a Matrix[T]) TransposeMultParallel(b *Matrix[T]) *Matrix[T] {
if a.Cols != b.Rows {
panic("matrices are the wrong size for multiplication")
}
c_data := make([]T, a.Rows*b.Cols)
t := b.Transpose()
var wg sync.WaitGroup
for i := 0; i < a.Rows; i++ {
wg.Add(1) // Add a count to the WaitGroup for the new goroutine
go func(i int) { // Kick off goroutine
defer wg.Done() // Decrease the count when the goroutine completes
ptr := i * b.Cols
for j := 0; j < b.Cols; j++ {
var sum T = 0.0
for k := 0; k < a.Cols; k++ {
sum += a.At(i, k) * t.At(j, k)
}
c_data[ptr+j] = sum
}
}(i)
}
wg.Wait() // Wait for all goroutines to complete
return InitMatrixWithData(a.Rows, b.Cols, c_data)
}
Gonum BLAS
是一个纯 Go 库,它实现了 BLAS 接口。然而,它也可以配置为将代数运算转发到本地代码 BLAS 实现,例如通过netlib的OpenBLAS。
我上面展示了如何配置cgo
以正确链接到 MacOS 上的 OpenBLAS 安装。在应用程序代码中,可以直接设置首选的 BLAS 实现。从基准测试代码:
// Convert primitive arrays into gonum dense matrix types
gonum_a := mat.NewDense(a_rows, a_cols, a64_data)
gonum_b := mat.NewDense(b_rows, b_cols, b64_data)
gonum_c := mat.NewDense(a_rows, b_cols, nil)
gonum_d := mat.NewDense(a_rows, b_cols, nil)
// Configure Gonum to use Gonum-default Go implementation
blas64.Use(gonum.Implementation{})
// Run a multiplication using Gonum BLAS impl
start = time.Now()
gonum_c.Mul(gonum_a, gonum_b)
bdata.TimeGonumNative(start)
// Configure Gonum to use Netlib which forwards operations to a
// native C-code BLAS implementation (OpenBLAS in our case)
blas64.Use(netlib.Implementation{})
// Run a multiplication using OpenBLAS impl through Gonum API
start = time.Now()
gonum_d.Mul(gonum_a, gonum_b)
bdata.TimeGonumOpenBLAS(start)
结果
我的基准测试代码运行了几次以下矩阵乘法实现的试验,并报告了每次乘法两个逐渐增大的方阵所花费的平均时间:
- Naive multiplication, in Go
- Transposed naive multiplication, in Go
- Goroutine-parallelized transposed naive multiplication, in Go
- Gonum pure Go-based BLAS multiplication
- Gonum-wrapped OpenBLAS multiplication, written in C
- Hand-implemented naive multiplication, in MSL, on GPU
- Hand-implemented transposed naive multiplication, in MSL, on GPU
- Metal Performance Shaders framework, called from Objective-C, on GPU
基准测试输出如下(浮点数为毫秒):
2023-12-01 11:12:51.644 go-mm[75818:22427382] Using default device Apple M2
elements naive transpose transpose_parallel metal_naive metal_transpose mps gonum openblas
160000 196.00 201.00 42.00 8.00 9.67 0.33 4.67 6.00
250000 381.33 387.67 80.67 11.00 11.67 0.00 8.33 21.00
360000 801.00 789.33 159.33 19.00 16.33 0.00 14.33 4.67
490000 1228.00 1075.00 411.00 23.67 24.33 1.00 26.67 16.33
...
一些快速绘图通过matplotlib
所有方法的性能图
正如预期,我手写的 Go 实现相对失控。实际上,其他方法速度如此之快,以至于在图中无法区分它们。以下是这次运行的 GPU 使用滑动直方图
活动监视器 GPU 历史可视化 — 所有方法(Y 轴为使用百分比)
你可以看到 GPU 并不是特别忙碌,因为时间主要花在了 CPU 操作上。以下是另一轮测试,排除了最慢的三种乘法技术:
排除我手写的 Go 变体的各种方法性能图
大约 16M 元素(4k x 4k),Gonum
开始下降。可以清楚地看到,基于 GPU 的和OpenBLAS
操作优于纯 Go 实现。仅看基于 GPU 的方法:
仅在 GPU 上运行的矩阵乘法操作性能图
这里有几个有趣的笔记:
-
Metal Performance Shaders 库的速度惊人
-
天真的方法和转置天真的方法之间没有实际性能差异
对于第二点:这与上述 Go 基础的实现对比性能特性不同。结果表明,对 CPU 有利的缓存访问模式在 GPU 上效果不同,尤其是它们的 SIMD 组(或 warps)访问内存的方式。见 GPU 利用率以便比较:
活动监视器 GPU 历史可视化 — 仅 GPU 操作
现在仅查看OpenBLAS
和MPS
— 这两种最快的方法:
OpenBLAS 与 Apple 的 Metal Performance Shaders MPSMatrixMultiplication API 性能对比图
在大约 35M 元素时,OpenBLAS
实现开始下降,而 MPS
则保持稳定。这里的差异相当显著,后者完成相同的 35M 元素矩阵乘法操作的时间少于 15%。可以合理地假设,随着矩阵规模的增长,这种差异会继续扩大。
当然,这两种方法之间可能存在算法差异,因此这不是一个公平的 CPU 与 GPU 比较。如果我绘制我两个手工编码实现的性能差异图,它看起来是这样的:
我的 MSL 编写的矩阵乘法代码与 Go 编写的代码的性能比率图
这意味着,基于 MSL 的简单实现完成 5M 元素的乘法操作仅需我 Go 实现的 1%时间,而这种比率似乎随着时间的推移对 GPU 更有利。