TowardsDataScience 2023 博客中文翻译(三百五十七)

原文:TowardsDataScience

协议:CC BY-NC-SA 4.0

使用 Gumbel Softmax 的离散分布变分自编码器(VAE)

原文:towardsdatascience.com/variational-autoencoder-vae-with-discrete-distribution-using-gumbel-softmax-b3f749b3417e

理论与 PyTorch 实现

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Alexey Kravets

·发表于 Towards Data Science ·17 分钟阅读·2023 年 8 月 9 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

unsplash.com/photos/sbVu5zitZt0

由于这篇文章将会很详尽,我将为读者提供一个索引以便更好地导航:

  1. 介绍

  2. 变分自编码器(VAEs)简要介绍

  3. Kullback–Leibler (KL) 散度

  4. VAE 损失

  5. 重参数化技巧

  6. 从分类分布中采样与 Gumbel-Max 技巧

  7. 实现

介绍

生成模型如今变得非常流行,这要归功于它们能够通过学习和捕捉训练数据的基础概率分布来生成具有固有变异性的全新样本。

我们可以识别出两大主要的生成模型家族:生成对抗网络(GANs)、变分自编码器(VAEs)和扩散模型。在这篇文章中,我们将深入探讨 VAEs,特别是关注具有分类潜在空间的 VAEs。

变分自编码器(VAEs)简要介绍

变分自编码器(VAEs)是一种用于无监督机器学习的深度神经网络。它们属于自编码器家族,自编码器是设计用于通过压缩然后重构数据来学习高效数据表示的神经网络。

VAEs 的主要思想是学习潜在空间中的数据概率分布。这个潜在空间是输入数据的低维表示,其中每个点对应于一个特定的数据样本。例如,给定一个维度为 3 的潜在空间中的向量,我们可以认为第一个维度表示眼睛的形状,第二个维度表示胡须的多少,第三个维度表示生成的人的脸上的肤色。

VAEs 具有两个关键组件:

  1. 编码器:编码器网络接收输入数据,并将其映射到潜在空间中一个概率分布的参数(通常是高斯分布)。编码器不是直接在潜在空间中产生一个单一的点,而是输出分布的均值和方差。

    输出一个分布而不是潜在空间中的单个点作为正则化,这样当我们在潜在空间中选择一个随机点时,解码这个数据点后我们总能得到一个有意义的图像。

  2. 解码器:解码器网络从潜在空间中采样,并将其重建回原始数据空间。它使用类似于编码器的过程但相反,将潜在表示转换回数据空间。

让我们来说明这个过程:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

VAE 编码器-解码器图示,图片由作者提供 (1)

其中 x 是输入图像,z 是潜在空间中的一个采样向量,μ 和 σ 是潜在空间参数,其中μ是均值向量,σ是标准差向量。最后,x’ 是从潜在变量重建的图像。

我们希望这个潜在空间具备两个特性

  1. 潜在空间中接近的点应输出相似的图片。

  2. 从潜在空间中采样的任何点都应产生与训练数据相似的东西,即,如果我们训练的是人的面孔,它不应产生任何有 3 只眼睛或 4 只耳朵的面孔。

要实现第一个目标,我们需要让编码器将相似的图片映射到接近的潜在空间参数,然后解码器将它们映射回看起来相似的图片——这通过图像重建损失来实现。为了实现第二个目标,我们需要添加一个正则化项。这个正则化项是编码器返回的参数与均值为 0、方差为 1 的标准高斯分布——N(0,1)之间的 Kullback–Leibler(KL)散度。通过保持潜在空间接近 N(0,1),我们确保编码器不会为每个样本产生相距过远的分布(通过使均值非常不同和方差非常小),这会导致过拟合。如果发生这种情况,在潜在空间中采样一个远离任何训练点的值将无法产生有意义的图像。

Kullback–Leibler (KL) 散度

KL 散度,简称 Kullback-Leibler 散度,是衡量一个概率分布与另一个分布的不同程度的指标。给定两个概率分布 P(X)和 Q(X),其中 X 是随机变量,KL(Q || P)表示从 Q 到 P 的 KL 散度,是一个非负值,表示使用 Q 来近似 P 时信息的丧失程度。它不是对称度量,即 KL(Q || P)通常不同于 KL(P || Q)。连续和离散变量的公式如下:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

KL 散度,离散情况 (2)

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

KL 散度,连续情况 (3)

但是这个公式的直觉是什么?它是如何推导出来的?

假设我们有一个包含 从分布 P(x) 中抽样得到的观察数据 — {x1, x2, …, xn} 的数据集,我们想要比较这些观察数据在真实分布 P(x) 和近似分布 Q(x) 下的生成可能性。在概率分布下观察整个数据集的可能性可以通过每个观察数据的个体概率的乘积来计算:

  • 在 P(x) 下的数据似然:L_P = P(x1) * P(x2) * … * P(xn)

  • 在 Q(x) 下的数据似然:L_Q = Q(x1) * Q(x2) * … * Q(xn)

通过比较比率 L_P / L_Q,我们可以比较它们的相似度。如果比率接近 1,则近似分布与真实分布相似;而如果这个比率很高,意味着根据近似分布从真实分布中抽样的序列的可能性显著较低,则两个分布不同。显然,这个比率不能小于 1,因为数据是从真实分布 P(x) 中抽样的。

对这个比率两边取对数,我们得到:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

比率 L_P / L_Q 的对数 (4)

现在,如果我们对数据集上真实分布 P(x) 的对数进行期望计算,我们得到期望对数似然比:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

比率 L_P / L_Q 的期望对数 (5)

这不过是 KL 散度!作为额外的内容,让我们深入了解 KL 散度如何与交叉熵相关联。细心的读者可能已经认识到:

公式中的 Σ P(x) * log(P(x)) 是 P(x) 的熵的负值,而 Σ P(x) * log(Q(x)) 是 P(x) 和 Q(x) 之间的交叉熵。所以,我们有:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

KL 散度作为交叉熵和熵之间的差异 (6)

现在,真实数据分布 P(x) 的熵是一个不依赖于近似分布 Q(x) 的常数。因此,最小化期望对数似然比 E[log(L_P / L_Q)] 等同于最小化真实分布 P(x) 和近似分布 Q(x) 之间的交叉熵 H(P, Q)

VAE 损失

在“变分自编码器(VAEs)简介”部分,我们提供了关于如何优化 VAEs 的一些直觉,并且潜在空间应该满足 2 个属性,以在从潜在空间抽样 任何 随机数据点时生成有意义的图像,这由重构损失和 KL 散度正则化强制执行。在本节中,我们将深入探讨这两个方面的数学。

给定一些从潜在变量 z 生成的训练数据 x = {x1, x2, …, xn},我们的目标是最大化这些数据的似然,以训练我们的变分自编码器模型。数据的似然由以下公式给出:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

数据似然 (7)

我们将潜在变量积分出去,因为它是不可观察的。

现在,p(x|z)可以通过解码器网络轻松计算,而p(z)被假定为高斯分布。然而,我们面临一个大问题——在有限的时间内计算这个积分实际上是不可能的,因为我们需要在所有潜在空间上进行积分。因此,我们使用贝叶斯规则以不同的方式计算我们的p(x)

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

p(x)的贝叶斯规则 (8)

现在,p(z|x)是难以处理的。p(zx)的难处理性源于我们需要对每个数据点x的所有可能值z计算p(zx)的积分。形式上,这个积分可以表示为:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

p(z|x)的贝叶斯规则 (9)

由于这种难处理性,在 VAE 中,我们 resort 使用一个近似分布(在我们情况下是高斯分布)q(zx),这更容易处理且计算上可行。这个近似分布是通过编码器网络学习的:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

p(z|x)的近似分布 (10)

现在我们已经准备好所有元素,可以用解码器网络计算的p(x|z)来近似p(x),以及由编码器q近似的p(z|x)。对方程 9 的两边应用对数并进行一些代数变换,我们得到:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

p(x)的对数概率 (11)

现在,对两边应用期望算子:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

p(x)的对数概率的期望 (12)

这等于:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

p(x)的对数概率的期望 — 不同形式 (13)

在上图中,第一个项是重建项,即我们的模型从潜在变量重建训练数据x的效果。第二个项是z的先验——*N(0,1)*与来自编码器的样本之间的 KL 散度。第三个项是编码器和解码器后验之间的 KL 散度,这是难以处理的。如果我们忽略最后一项,我们得到数据似然的下界,因为 KL 总是≥0,这称为证据下界(ELBO)。因此,我们最终得到:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

证据下界(ELBO) (14)

因此,在训练 VAE 时,我们尝试最大化 ELBO,这等同于最大化我们数据的概率。

重参数化技巧

让我们先了解重参数化技巧,因为理解这一点对于理解 Gumbel-Softmax 使用类似的东西至关重要。

正如我们在第一部分中所看到的,编码器输出正态分布的均值和方差参数,然后我们从具有这些参数的正态变量中抽取一个随机向量,并通过解码器传递这个潜在向量以重建初始图像。为了最小化重建损失并使网络学习,我们需要从这个重建损失中进行反向传播,但存在一个问题——潜在变量 Z,即从高斯中抽样的变量,是不可微分的。想一想——你如何对一个样本进行求导?因此,我们不能使用反向传播。解决方案是使用重新参数化技巧。

为了使随机变量Z可微分,我们需要将其分为一个可微分的确定性部分和一个不可微分的随机部分。任何来自随机正态分布的样本 Z ~ N(μ, σ) 可以写成:

Z = μ + N(0,1) = σ = μ + ε σ 其中 ε ~ N(0,1)

所以μ和σ是确定的,我们可以对其进行反向传播,而ε是随机部分,我们不能对其进行反向传播。因此,我们可以对μ和σ进行求导:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

随机变量 Z 对均值和标准差的导数 (15)

…以学习潜在空间中正态分布的均值和标准差

从分类分布中进行抽样 & Gumbel-Max 技巧

如果我们希望将潜在空间建模为分类分布,而不是具有连续潜在分布的情况,会怎么样?你会问,为什么有人要这样做?好吧,离散表示在许多情况下是有用的,例如在强化学习问题中采样离散动作、生成离散文本标记等等。

那么我们如何从分类分布中进行抽样并学习其参数,使其可微分?我们可以重复使用重新参数化技巧的想法,将其调整到这个问题上!

首先,让我们尝试理解如何从分类分布中进行抽样。假设我们有以下概率向量:

theta = [0.05, 0.25, 0.7] 代表以下类别——[红色, 蓝色, 白色]。为了进行抽样,我们需要一个随机源,通常使用 0 到 1 之间的均匀分布。请记住,在均匀分布中,0 到 1 之间的抽样是同样可能的。因此,我们从均匀分布中抽样,并将其转换为分类分布,我们可以根据我们的概率theta进行切片。我们定义一个累计和向量 theta_cum = [0.05, 0.3, 1],它代表下面的图。

给定来自均匀分布的样本,例如 0.31,我们选择累计概率超过生成随机数的类别。

argmax(theta_cum ≥ U(0,1)) = argmax([False, True, True]) 这对应于示例中的“蓝色”,因为 argmax 选择第一个对应于True的索引。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

累积概率分类分布,图像作者提供 (16)

现在,我们可以用另一种方式从分类分布中采样 — 不是使用均匀分布,而是使用定义为的 Gumbel 分布:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

Gumbel 分布 (17)

假设我们有一个 (log) 概率向量,如之前所示

theta = [log(alpha1), log(alpha2), log(alpha3)], 这些是我们希望通过反向传播估计的参数。为了使用反向传播,我们复现了重新参数化技巧部分中所做的 — 拥有一个确定性部分,即作为我们参数的类别对数概率和一个由随机标准 Gumbel 噪声给出的随机部分。

要使用 Gumbel 从分类分布中采样,我们可以按以下步骤操作:

argmax([log(alpha1) + G1, log(alpha2) + G2, log(alpha3) + G3])

其中 theta 是确定性部分,Gumbel 噪声是随机部分。我们可以通过这两个部分的和进行传播。然而,argmax 不是一个可微的 函数。因此,我们用具有温度 τSoftmax 替代它,以使一切可微。于是,类别的概率 yi 变成:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

使用 Gumbel-Softmax 分布采样 (18)

τ 会使 Softmax 更接近 argmax,而较高的 τ 会使其更接近均匀分布。实际上,当我们将温度降低到如 1e-05 这样的非常低的值时,概率几乎像选择 argmax,即我们基本上是从离散分布中采样。

实现

我们以 MNIST 数据集为例 (许可:公共领域 / 来源:yann.lecun.com/exdb/mnist/,也可以在 torchvision.datasets 中找到),目标是学习一个生成模型,假设图像是二值的。潜在变量大小假设为 20,包含 10 个分类变量(10 个数字)。先验是一个包含 10 个类别的分类分布,均匀先验概率为 1/10。

1. 首先实现 Gumbel softmax 函数 gumbel_softmax。如前所述,这由每个类别的对数概率(logits)之和加上 Gumbel 分布给出的随机性构成。在 3 个类别的情况下,我们有:

softmax([log(alpha1) + G1, log(alpha2) + G2, log(alpha3) + G3]) 使用 Softmax 替代 argmax 以实现可微性。

def sample_gumbel(shape, eps=1e-20):
    # sample from a uniform distribution
    U = torch.rand(shape)
    if is_cuda:
        U = U.cuda()
    return -torch.log(-torch.log(U + eps) + eps)

def gumbel_softmax_sample(logits, temperature):
    y = logits + sample_gumbel(logits.size())
    return F.softmax(y / temperature, dim=-1)

def gumbel_softmax(logits, temperature, hard=False):
    y = gumbel_softmax_sample(logits, temperature)

    if not hard:
        return y.view(-1, latent_dim * categorical_dim)

    shape = y.size()
    _, ind = y.max(dim=-1)
    y_hard = torch.zeros_like(y).view(-1, shape[-1])
    y_hard.scatter_(1, ind.view(-1, 1), 1)
    y_hard = y_hard.view(*shape)
    # skip the gradient of y_hard
    y_hard = (y_hard - y).detach() + y 
    return y_hard.view(-1, latent_dim * categorical_dim)

附注:

我们可以注意到 gambel_softmax 函数中的一个小技巧——如果参数 hard 为 True,我们使用 argmax 而不是 softmax。在评估时,我们通常使用 argmax(这是我们在 model.sample_img 中所做的),而在训练期间,我们使用 softmax,因为 argmax 操作是不可微分的。然而,这不是必需的,我们也可以在训练期间使用 argmax,通过 跳过 y_hardgumbel_softmax 函数中的梯度并对 softmax y 进行微分。一个简短的示例会有所阐明:

skip_d = False

a = torch.Tensor([1])
a.requires_grad = True

b = torch.Tensor([2])
b.requires_grad = True

c = 2 * (a + b)

if skip_d:
    d = c ** 2
    d = (d - c).detach() + c
else:
    d = c ** 2

f = d * 4
f.retain_grad()
d.retain_grad()
c.retain_grad()

loss = f * 3
loss.backward()

print(loss)
print(a.grad, b.grad, c.grad, d.grad, f.grad)
# Loss value: tensor([432.])
# (tensor([288.]), tensor([288.]), tensor([144.]), tensor([12.]), tensor([3.]))

# Running the same with skip_d = True we get:
# tensor([432.])
# (tensor([24.]), tensor([24.]), tensor([12.]), tensor([12.]), tensor([3.]))

skip_d = False 时,我们有:

dl/df = 3

dl/dd = dl/df * df/dd = (3) * (4) = 12

dl/dc = dl/df * df/dd * dd/dc = (3) * (4) * (2 * c) = 144

dl/da = dl/df * df/dd * dd/dc * dc/da = (3) * (4) * (2 * c) * (2) = 288

dl/db = dl/df * df/dd * dd/dc * dc/db = (3) * (4) * (2 * c) * (2) = 288

skip_d = True: dl/df = 3

dl/dd = dl/df * df/dd = (3) * (4) = 12

dl/dc = dl/df * df/dd = (3) * (4) = 12

从现在开始,我们跳过 dd/dc,即我们将梯度 dl/dc = dl/dd。

dl/da = dl/df * df/dd * dc/da = (3) * (4) * (2) = 24

dl/db = dl/df * df/dd * dc/db = (3) * (4) * (2) = 24

在上述示例中,loss 的值是相同的,但梯度却不同。在我们的模型中,值不会相同,因为当 hard=True 时我们将 latent_z 设置为 y_hard,而当 hard=False 时设置为 softmax y,但 y 的反向传播梯度在两种情况下都是相同的。

2. 现在让我们定义我们的 VAE 模型。编码器将图像映射到分类变量的对数概率,由 3 个线性层和 ReLU 非线性层组成。解码器将潜在空间向量映射回图像空间,由 3 个线性层、2 个 ReLU 非线性层和最后一个 sigmoid 非线性层组成。Sigmoid 直接输出概率,这很方便,因为我们将 MNIST 图像(每个像素)建模为 Bernoulli 变量。

class VAE_model(nn.Module):
    def __init__(self):
        super(VAE_model, self).__init__()
        self.fc1 = nn.Linear(784, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, latent_dim * categorical_dim)
        self.fc4 = nn.Linear(latent_dim * categorical_dim, 256)
        self.fc5 = nn.Linear(256, 512)
        self.fc6 = nn.Linear(512, 784)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def encode(self, x):
        h1 = self.relu(self.fc1(x))
        h2 = self.relu(self.fc2(h1))
        return self.relu(self.fc3(h2))

    def decode(self, z):
        h4 = self.relu(self.fc4(z))
        h5 = self.relu(self.fc5(h4))
        return self.sigmoid(self.fc6(h5))

在前向函数中,我们首先通过编码器计算 logits,使用 Gumbel Softmax:

logits_z = self.encode(data.view(-1, 
logits_z = logits_z.view(-1, latent_dim, categorical_dim)
latent_z = gumbel_softmax(logits_z, temp)
latent_z = latent_z.view(-1, latent_dim * categorical_dim)

然后,我们解码它们,给出每个像素的 Bernoulli 概率。我们可以从中采样,以生成具有概率参数的图像:

probs_x = self.decode(latent_z)
# we assumed distribution of the data is Bernoulli
dist_x = torch.distributions.Bernoulli(probs=probs_x, validate_args=False)

接下来,让我们计算 ELBO 损失

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

EBLO 损失 (19)

对于第一项(重建损失),我们需要计算在我们估计的模型下真实数据的对数似然,这告诉我们真实图像在我们模型下的可能性。我们之前从解码器计算了 dist_x,这就是我们用来估计该概率的:

# reconstruction loss - log probabilities of the data
rec_loss = dist_x.log_prob(data.view(-1, 784)).sum(dim=-1)

然后,我们计算由 KL 散度给出的正则化,该散度是由 10 类别的分类分布与均匀先验概率 1/10 之间的差异和潜在空间的分类参数给出的:

# KL divergence loss
KL = (posterior_distrib.probs * (logits_z_log - prior_distrib.probs.log())).view(-1, latent_dim * categorical_dim).sum(dim=-1)

包括训练函数和绘图工具在内的完整代码如下:

torch.manual_seed(0)

batch_size = 100
temperature = 1.0
seed = 0
log_interval = 10
hard = False
is_cuda = torch.cuda.is_available()
torch.manual_seed(seed)
if is_cuda:
    torch.cuda.manual_seed(seed)
kwargs = {'num_workers': 1, 'pin_memory': True} if is_cuda else {}
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./data/MNIST', train=True, download=True,
                   transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=True, **kwargs)

def sample_gumbel(shape, eps=1e-20):
    # sample from a uniform distribution
    U = torch.rand(shape)
    if is_cuda:
        U = U.cuda()
    return -torch.log(-torch.log(U + eps) + eps)

def gumbel_softmax_sample(logits, temperature):
    y = logits + sample_gumbel(logits.size())
    return F.softmax(y / temperature, dim=-1)

def gumbel_softmax(logits, temperature, hard=False):
    y = gumbel_softmax_sample(logits, temperature)

    if not hard:
        return y.view(-1, latent_dim * categorical_dim)

    shape = y.size()
    _, ind = y.max(dim=-1)
    y_hard = torch.zeros_like(y).view(-1, shape[-1])
    y_hard.scatter_(1, ind.view(-1, 1), 1)
    y_hard = y_hard.view(*shape)
    # skip the gradient of y_hard
    y_hard = (y_hard - y).detach() + y 
    return y_hard.view(-1, latent_dim * categorical_dim)

class VAE_model(nn.Module):
    def __init__(self):
        super(VAE_model, self).__init__()
        self.fc1 = nn.Linear(784, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, latent_dim * categorical_dim)
        self.fc4 = nn.Linear(latent_dim * categorical_dim, 256)
        self.fc5 = nn.Linear(256, 512)
        self.fc6 = nn.Linear(512, 784)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def sample_img(self, img, temp, random=True):
        # evaluation
        with torch.no_grad():
            logits_z = self.encode(img.view(-1, 784))
            logits_z = logits_z.view(-1, latent_dim, categorical_dim)
            if random:
                latent_z = gumbel_softmax(logits_z, temp, True)
            else:
                latent_z = logits_z.view(-1, latent_dim * categorical_dim)
            logits_x = self.decode(latent_z)
            # probs instead of logits because we have sigmoid activation 
            # in the decoder
            dist_x = torch.distributions.Bernoulli(probs=logits_x)
            sampled_img = dist_x.sample()
        return sampled_img

    def encode(self, x):
        h1 = self.relu(self.fc1(x))
        h2 = self.relu(self.fc2(h1))
        return self.relu(self.fc3(h2))

    def decode(self, z):
        h4 = self.relu(self.fc4(z))
        h5 = self.relu(self.fc5(h4))
        return self.sigmoid(self.fc6(h5))

    def forward(self, data, temp, hard):
        logits_z = self.encode(data.view(-1, 784))
        logits_z = logits_z.view(-1, latent_dim, categorical_dim)

        # estimated posterior probabiity coefficients
        probs_z = F.softmax(logits_z, dim=-1)
        posterior_distrib = torch.distributions.Categorical(probs=probs_z)
        # categorical prior
        probs_prior = torch.ones_like(logits_z)/categorical_dim
        prior_distrib = torch.distributions.Categorical(probs=probs_prior)

        latent_z = gumbel_softmax(logits_z, temp)
        latent_z = latent_z.view(-1, latent_dim * categorical_dim)

        probs_x = self.decode(latent_z)
        # we assumed distribution of the data is Bernoulli
        dist_x = torch.distributions.Bernoulli(probs=probs_x, validate_args=False)
        # Losses
        # reconstruction loss - log probabilities of the data
        rec_loss = dist_x.log_prob(data.view(-1, 784)).sum(dim=-1)
        logits_z_log = F.log_softmax(logits_z, dim=-1)
        # KL divergence loss
        KL = (posterior_distrib.probs * (logits_z_log - prior_distrib.probs.log())).view(-1, latent_dim * categorical_dim).sum(dim=-1)
        elbo = rec_loss - KL
        loss = -elbo.mean()
        return loss

def train(epoch, model, optimizer):
    model.train()
    train_loss = 0
    temp = temperature
    for batch_idx, (data, _) in enumerate(train_loader):
        if is_cuda:
            data = data.cuda()
        optimizer.zero_grad()
        loss = model(data, temp, hard)
        loss.backward()
        train_loss += loss.item() * len(data)
        optimizer.step()
        if batch_idx % 100 == 1:
            temp = np.maximum(temp * np.exp(-ANNEAL_RATE * batch_idx), temp_min)
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                       100\. * batch_idx / len(train_loader),
                       loss.item()))
            print("Temperature : ", temp)

            sampled = model.sample_img(data[0].view(-1, 28*28), temp).view(28, 28).detach().cpu()
            fig, axs = plt.subplots(1, 2, figsize=(6,4))
            fig.suptitle('Reconstructed vs Real')
            axs[0].imshow(sampled.reshape(28,28))
            axs[0].axis('off')
            axs[1].imshow(data[0].reshape(28,28).detach().cpu())
            axs[1].axis('off')
            plt.show()
    print('====> Epoch: {} Average loss: {:.4f}'.format(
        epoch, train_loss / len(train_loader.dataset)))

### Train
temp_min = 0.5
ANNEAL_RATE = 0.00003
latent_dim = 20
categorical_dim = 10
my_model = VAE_model()
my_model.to('cuda:0')
optimizer = optim.Adam(my_model.parameters(), lr=1e-3)
for epoch in range(3):
    train(epoch, my_model, optimizer)

在训练开始时,我们有较高的损失和糟糕的重建效果:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

重建与真实,训练的开始,作者图像(20)

在训练接近尾声时,我们得到了相当好的重建效果和显著降低的损失。显然,我们可以继续训练更长时间,以获得更好的重建效果。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

重建与真实,训练的结束,作者图像(21)

结论

在这篇文章中,我们发现 VAE 也可以用分类潜在空间来建模。当我们想在强化学习问题中采样离散动作或生成文本的离散标记时,这非常有用。在尝试对 argmax 操作进行微分以选择分类变量时,我们遇到了一个问题,因为 argmax 是不可微分的,但 thanks to Gumbel Softmax 和重新参数化技巧的启发,解决了这个问题。

[## 使用我的推荐链接加入 Medium — Alexey Kravets

作为 Medium 的会员,你的部分会员费会分配给你阅读的作者,并且你可以完全访问所有故事…

medium.com](https://medium.com/@alexml0123/membership?source=post_page-----b3f749b3417e--------------------------------)

参考资料

[1] jhui.github.io/2017/03/06/Variational-autoencoders/

[2] blog.evjang.com/2016/11/tutorial-categorical-variational.html

[3] www.youtube.com/watch?v=Q3HU2vEhD5Y&list=PL5-TkQAfAZFbzxjBHtzdVCWE0Zbhomg7r&index=19

[4] arxiv.org/pdf/1611.01144.pdf

[5] github.com/shaabhishek/gumbel-softmax-pytorch

变分推断:基础知识

原文:towardsdatascience.com/variational-inference-the-basics-f70ac511bcea?source=collection_archive---------1-----------------------#2023-06-16

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Hylke C. Donker

·

关注 发表在 数据科学前沿 ·9 分钟阅读·2023 年 6 月 16 日

我们生活在量化的时代。然而,严格的量化比说起来容易做起来难。在生物等复杂系统中,数据的收集可能既困难又昂贵。而在医疗和金融等高风险应用中,考虑不确定性至关重要。变分推断——一种处于人工智能研究前沿的方法——是一种解决这些问题的方式。

本教程将介绍变分推断的基础知识:何时、为何以及如何使用变分推断。

变分推断何时有用?

变分推断在以下三个密切相关的用例中非常有吸引力:

1. 如果你拥有少量数据(即观察值较少),

2. 如果你关心不确定性,

3. 用于生成建模。

我们将在我们的实例中讨论每种用例。

1. 变分推断与少量数据

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 1:变分推断允许你在领域知识与样本信息之间进行权衡。图像由作者提供。

有时候,数据收集是昂贵的。例如,DNA 或 RNA 测量每次观察可能会花费几千欧元。在这种情况下,你可以用领域知识代替额外的样本进行硬编码。变分推断可以帮助你在收集更多样本时系统性地“减少”领域知识,并更多地依赖于数据(见图 1)。

2. 不确定性的变分推断

对于安全关键型应用,如金融和医疗保健,确定性很重要。不确定性可以影响模型的所有方面,最明显的是预测输出。模型的参数(例如权重和偏置)则不那么明显。你可以将参数赋予一个分布,使其变得模糊,而不是通常的数字数组——权重和偏置。变分推断允许你推断出合理值的范围。

3. 用于生成建模的变分推断

生成模型提供了数据生成的完整规范。例如,如何生成猫或狗的图像。通常,有一个潜在表示 z 具有语义意义(例如,z 描述了一只暹罗猫)。通过一系列(非线性)变换和采样步骤,z 被转换为实际的图像 x(例如,暹罗猫的像素值)。变分推断是一种推断和从潜在语义空间 z 进行采样的方法。一个著名的例子是 变分自编码器

变分推断是什么?

从本质上讲,变分推断是一种贝叶斯方法[1]。从贝叶斯的角度来看,你仍然让机器像往常一样从数据中学习。不同的是,你给模型一个提示(先验),并允许解(后验)变得更加模糊。更具体地说,假设你有一个训练集 X = [x₁, x₂,…,x]ᵗ,共有 m 个样本。我们使用贝叶斯定理:

p(Θ|X) = p(X|Θ)p(Θ) /p(X*),*

推断一个范围——一个分布——的解决方案Θ。这与传统的机器学习方法形成对比,后者通过最小化损失 ℒ(Θ, X) = ln p(X|Θ)来寻找一个特定的解决方案Θ*。贝叶斯推断的核心在于找到一种方法来确定 p(Θ|X): 给定训练集 X 的参数 Θ后验分布。一般来说,这是一个困难的问题。实际上,有两种方法用于求解 p(Θ|X): (i) 使用模拟(马尔科夫链蒙特卡罗)或 (ii) 通过优化。

变分推断涉及选项 (ii)。

证据下界(ELBO)

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 2: 变分推断的示意图。 我们寻找一个接近 p(Θ|X) 的分布 q(Θ)。图像由作者提供。

变分推断的核心思想是寻找一个分布 q(Θ),作为 p(Θ|X) 的替代(代理)。然后我们尝试通过改变 Φ 的值,使 q[Θ|Φ(X)] 看起来类似于 p(Θ|X)(见图 2)。这通过最大化证据下界(ELBO)来完成:

(Φ) = E[ln p(X,Θ) — ln q(Θ|Φ)],

其中期望 E[·] 是对 q(Θ|Φ) 进行的。 (注意 Φ 隐式依赖于数据集 X,但为了方便书写,我们将忽略这一显式依赖。)

进行基于梯度的优化时,乍一看,我们必须在对 Φ 求导时小心,因为 E[·] 对 q(Θ|Φ) 的依赖。幸运的是,像 JAX 这样的自动梯度包支持重参数化技巧 [2],允许你直接从随机样本(例如伽马分布的样本)中进行求导,而无需依赖高方差的黑箱变分方法 [3]。简而言之:使用一批 [Θ₁, Θ₂,…] ~ q(Θ|Φ) 来估计 ∇ℒ(Φ),然后让你的自动梯度包处理细节。

从头开始进行变分推断

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 3: 来自 scikit-learn 的手写“零”的示例图像。图像由作者提供。

为了巩固我们的理解,让我们从头开始使用 JAX 实现变分推断。在这个例子中,你将对来自 scikit-learn 的手写数字进行生成模型训练。你可以按照 Colab notebook 进行操作。

为了简单起见,我们将只分析数字“零”。

from sklearn import datasets

digits = datasets.load_digits()
is_zero = digits.target == 0
X_train = digits.images[is_zero]

# Flatten image grid to a vector.
n_pixels = 64  # 8-by-8.
X_train = X_train.reshape((-1, n_pixels))

每张图像是一个 8x8 的离散像素值数组,范围从 0 到 16。由于像素是计数数据,我们使用 泊松分布 和伽马 先验 来对像素 x 进行建模,其中 Θ 是速率参数。速率 Θ 决定了像素的平均强度。因此,联合分布 为:

p(x,Θ) = 泊松(x|Θ)伽马(Θ|a, b),

其中 ab伽马分布 的形状参数和速率参数。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 4: 使用数字“零”的领域知识作为先验。图像由作者提供。

先验——在这种情况下是 Gamma(Θ|a, b)——是你注入领域知识的地方(用例 1)。例如,你可能对“平均”的数字零是什么样子有一些想法(见图 4)。你可以使用这些 a priori 信息来指导你选择 ab。为了使用图 4 作为先验信息——我们称之为 x₀——并将其重要性作为两个例子来加权,然后设置 a = 2x₀; b = 2。

用 Python 写出来的代码如下:

import jax.numpy as jnp
import jax.scipy as jsp

# Hyperparameters of the model.
a = 2\. * x_domain_knowledge
b = 2.

def log_joint(θ):
  log_likelihood = jnp.sum(jsp.stats.gamma.logpdf(θ, a, scale=1./b))
  log_likelihood += jnp.sum(jsp.stats.poisson.logpmf(X_train, θ))
  return log_likelihood

请注意,我们使用了 JAX 实现的 numpy 和 scipy,以便我们可以进行求导。

接下来,我们需要选择一个替代分布 q(Θ|Φ)。提醒一下,我们的目标是改变 Φ 使得替代分布 q(Θ|Φ) 匹配 p(Θ|X)。因此,q(Θ) 的选择决定了近似的水平(我们在上下文允许的地方省略对 Φ 的依赖)。为了说明问题,我们选择一个由 gamma 分布组成的变分分布:

q(Θ|Φ) = Gamma(Θ|α,β),

其中我们使用了简写 Φ = {α,β}。

接下来,为了实现证据下界 (Φ) = E[ln p(X,Θ) — ln q(Θ|Φ)],首先写下期望括号内的项:

@partial(vmap, in_axes=(0, None, None))
def evidence_lower_bound(θ_i, alpha, inv_beta):
  elbo = log_joint(θ_i) - jnp.sum(jsp.stats.gamma.logpdf(θ_i, alpha, scale=inv_beta))
  return elbo

在这里,我们使用了 JAX 的 vmap 来矢量化函数,以便我们可以在批量 [Θ₁, Θ₂,…,Θ₁₂₈]ᵗ 上运行它。

为了完成 (Φ) 的实现,我们对变分分布 Θ ~ q(Θ) 的实现进行平均。

def loss(Φ: dict, key):
  """Stochastic estimate of evidence lower bound."""
  alpha = jnp.exp(Φ['log_alpha'])
  inv_beta = jnp.exp(-Φ['log_beta'])

  # Sample a batch from variational distribution q.
  batch_size = 128
  batch_shape = [batch_size, n_pixels]
  θ_samples = random.gamma(key, alpha , shape=batch_shape) * inv_beta

  # Compute Monte Carlo estimate of evidence lower bound.
  elbo_loss = jnp.mean(evidence_lower_bound(θ_samples, alpha, inv_beta))

  # Turn elbo into a loss.
  return -elbo_loss

关于这些参数,有几个要点需要注意:

  • 我们将 Φ 打包为一个字典(或者技术上说是一个 pytree),其中包含 ln(α), 和 ln(β)。这个技巧确保了 α>0 和 β>0——这是 gamma 分布施加的一个要求——在优化过程中。

  • loss 是 ELBO 的随机估计。在 JAX 中,我们每次采样时都需要一个新的伪随机数生成器(PRNG)key。在这种情况下,我们使用 key 采样 [Θ₁, Θ₂,…,Θ₁₂₈]ᵗ。

这完成了模型 p(x,Θ)、变分分布 q(Θ) 和损失 (Φ) 的规范说明。

模型训练

接下来,我们通过改变 Φ = {α,β} 来最小化损失 (Φ),使 q(Θ|Φ) 匹配后验 p(Θ|X)。怎么做?使用传统的梯度下降法!为了方便,我们使用了 Optax 中的 Adam 优化器,并用先验 α = aβ = b 初始化参数 [记住,先验是 Gamma(Θ|a, b) 并且编码了我们的领域知识]。

# Initialise parameters using prior.
Φ = {
    'log_alpha': jnp.log(a),
    'log_beta': jnp.full(fill_value=jnp.log(b), shape=[n_pixels]),
}

loss_val_grad = jit(jax.value_and_grad(loss))
optimiser = optax.adam(learning_rate=0.2)
opt_state = optimiser.init(Φ)

在这里,我们使用value_and_grad同时评估 ELBO 及其导数。这对于监控收敛非常方便!然后我们使用jit) 即时编译结果函数,使其更加高效。

最终,我们将训练模型 5000 步。由于损失是随机的,对于每次评估,我们需要为其提供一个伪随机数生成器(PRNG)密钥。我们通过分配 5000 个密钥来实现这一点,使用随机拆分

n_iter = 5_000
keys = random.split(random.PRNGKey(42), num=n_iter)

for i, key in enumerate(keys):
  elbo, grads = loss_val_grad(Φ, key)
  updates, opt_state = optimiser.update(grads, opt_state)
  Φ = optax.apply_updates(Φ, updates)

恭喜!你已经成功地使用变分推断训练了第一个模型!

你可以通过这里在 Colab 上访问包含完整代码的笔记本。

结果

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 5:变分分布与精确后验分布的比较。图片来源:作者。

让我们退一步欣赏一下我们所构建的(图 5)。对于每个像素,替代的 q(Θ) 描述了关于平均像素强度的 uncertainty(用例 2)。特别地,我们选择的 q(Θ) 捕捉了两个互补的元素:

  • 典型的像素强度。

  • 图像间强度的变化程度(变异性)。

结果表明,我们选择的联合分布 p(x,Θ) 有一个精确的解:

p(Θ|X) = Gamma(Θ|a + Σxᵢ, m + b),

其中 m 是训练集中样本的数量 X。在这里,我们可以明确看到领域知识——以 ab 形式体现——在我们收集更多样本 xᵢ 时如何被调节。

我们可以轻松比较学到的形状 α 和速率 β 与真实值 a + Σxᵢ 和 m + b。在图 5 中,我们比较了两个特定像素的分布——q(Θ|Φ) 与 p(Θ|X) —。结果令人惊叹,完美匹配!

附加:生成合成图像

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 6:使用变分推断生成的合成图像。图片来源:作者。

变分推断非常适合生成建模(用例 3)。有了替代后验 q(Θ),生成新的合成图像是简单的。两步过程是:

  • 样本像素强度 Θ ~ q(Θ).
# Extract parameters of q.
alpha = jnp.exp(Φ['log_alpha'])
inv_beta = jnp.exp(-Φ['log_beta'])

# 1) Generate pixel-level intensities for 10 images.
key_θ, key_x = random.split(key)
m_new_images = 10
new_batch_shape = [m_new_images, n_pixels]
θ_samples = random.gamma(key_θ, alpha , shape=new_batch_shape) * inv_beta
  • 使用 x ~ 泊松(x|Θ) 采样图像。
# 2) Sample image from intensities.
X_synthetic = random.poisson(key_x, θ_samples)

你可以在图 6 中看到结果。请注意,“零”字符的锐度稍逊于预期。这是我们建模假设的一部分:我们将像素建模为相互独立而非相关。要考虑像素相关性,你可以扩展模型以聚类像素强度:这称为泊松分解[4]。

摘要

在本教程中,我们介绍了变分推断的基础知识,并将其应用于一个玩具示例:学习手写数字零。得益于自动求导,从头实现变分推断只需几行 Python 代码。

变分推断在数据较少的情况下特别强大。我们展示了如何融合和权衡领域知识与数据中的信息。推断的替代分布 q(Θ) 提供了模型参数的“模糊”表示,而不是一个固定值。如果你处于一个不确定性重要的高风险应用中,这种方法是理想的!最后,我们展示了生成模型。只要你能从 q(Θ) 中采样,生成合成样本就很容易。

总结来说,通过利用变分推断的力量,我们可以解决复杂问题,使我们能够做出明智的决策、量化不确定性,并最终释放数据科学的真正潜力。

致谢

我想感谢 Dorien Neijzen 和 Martin Banchero 的校对。

参考文献:

[1] Blei, David M., Alp Kucukelbir, 和 Jon D. McAuliffe. “变分推断:统计学家的综述.美国统计协会杂志 112.518 (2017): 859–877.

[2] Figurnov, Mikhail, Shakir Mohamed, 和 Andriy Mnih. “隐式重新参数化梯度.” 《神经信息处理系统进展》 31 (2018).

[3] Ranganath, Rajesh, Sean Gerrish, 和 David Blei. “黑箱变分推断.” 人工智能与统计学. PMLR, 2014.

[4] Gopalan, Prem, Jake M. Hofman, 和 David M. Blei. “可扩展推荐与泊松分解.arXiv 预印本 arXiv:1311.1704 (2013).

机器学习中的各种部署类型

原文:towardsdatascience.com/various-types-of-deployment-in-machine-learning-b503017e6bae?source=collection_archive---------17-----------------------#2023-01-06

学习各种部署策略,以成功构建端到端的机器学习管道

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Suhas Maddali

·

关注 发表在 Towards Data Science · 6 分钟阅读 · 2023 年 1 月 6 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由 Jiawei Zhao 提供,来源于 Unsplash

机器学习有很大的范围需求,尤其是在最新的自动驾驶行业中,驾驶员通过 AI 的帮助获得辅助。此外,还有其他行业受益,如制药行业,它们开始使用 AI 来开发有趣的产品,这些产品本质上用于预测性医疗。其他行业还包括电子商务,在这些行业中,最相关的产品被推荐给用户,提高了客户购买产品的倾向。

通常,关于机器学习的能力以及它们如何在大量任务中取得最先进结果以实现高精度的讨论很多。然而,最少讨论的话题是如何在实时中进行部署,以及在生产阶段进行持续监控和评估。这是许多在线机器学习和深度学习课程中被忽视的关键因素。一个机器学习模型只有在我们能够将其作为应用提供给最终用户时才算优秀。

查看所有依赖机器学习的不同类型行业会让很多人倾向于这个领域,并使公司取得成功。有大量的在线课程突出机器学习的关键领域,如特征工程、数据准备、模型构建、超参数调整等等。然而,这些课程中缺少一个重要的元素:部署。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

照片由Mediamodifier提供,来源于Unsplash

在这篇文章中,我们将详细了解各种部署策略,这些策略对于希望通过建立AI 能力来给团队留下深刻印象的人来说是至关重要的。现在,让我们详细探讨一下机器学习的部署策略。

批量推理

现在你已经训练并进行了机器学习模型的超参数调整,是时候将最佳模型投入生产。批量推理是一种部署策略,其中机器学习模型以实时方式部署,只接受数据批次。采用这种策略,模型通常能够离线工作或处理周期性任务,例如生成报告或预测。

批量推理可以在我们希望对客户对各种产品的情感进行分类的场景中非常有用。换句话说,他们可能会给出评论,如果我们想了解客户对产品的整体情感,批量推理可以是一种很好的机器学习模型部署策略。

实时推理

这是一种将机器学习模型在接收数据时进行实时运行的部署方式。因此,它们会准备好以某种格式接收数据并提供实时预测,以便可以采取相应的行动。此外,根据项目和团队目标,可能还需要满足实时推理的要求,如低延迟系统或更高的预测准确性。

实时推理的一个经典例子是在进行交易时检测欺诈活动的可能性。机器学习模型最初会使用包含欺诈和非欺诈交易的数据进行训练。在选择出最佳模型后,它会进行实时推理,以便客户能够了解是否发生了欺诈活动。

本地部署

在产品实时部署之前,机器学习团队通常需要进行高安全性措施和数据合规检查。在这种情况下,数据和生产中的机器学习代码的重要性更高。

本地部署涉及在组织设施内的物理设备或服务器上部署机器学习模型。因此,它可以提供对数据和模型的高安全性和控制。

本地部署在预测性维护中可能非常有用,其中使用机器学习模型来确定各种制造设备的故障可能性。我们不依赖互联网提供实时预测,而是使用我们自己的一组服务器和机器,这些设备能够提供机器学习预测所需的计算能力。每当模型预测制造材料存在各种缺陷时,人们可以通过更换这些产品来采取行动。

云部署

这是一种将我们的机器学习服务提供到云中的部署方式。因此,我们利用集群设备的计算资源和内存。因此,我们应能根据用户执行的各种机器学习操作的流量来扩展我们的应用程序。

云部署在我们不确定训练和部署模型所需资源数量时可能会很有用。此外,这些服务仅在用户使用我们的预测时才会初始化。

使用云部署的一个流行示例是预测客户在特定服务集上流失的可能性。如果我们建立了一个服务,让订阅者使用,我们将基于一组预测特征预测客户是否会离开服务。由于我们无法完全了解可能注册服务和同时离开服务的客户总数,因此在中部署训练模型是一种好的方法,因为这会根据流量需求简化扩展。

移动部署

这是一种在移动设备(如智能手机和平板电脑)上部署机器学习模型的部署方式。这种类型的部署示例包括个人助理、图像识别和语言翻译应用程序。

由于我们在资源受限的环境中部署模型,这与在服务器上部署的环境不同,因此必须在最终形成 ML 产品之前考虑硬件因素。机器学习应用可能非常有用,并且可以具有合理的准确性。然而,如果模型在硬件资源较少的情况下无法生成预测,那么它可能不适合用于移动应用程序。

在尝试实时在移动设备上部署这些产品时,必须考虑低延迟要求、偏差-方差权衡以及其他因素。

边缘部署

这是一种在边缘设备(如物联网 (IoT))上部署机器学习模型的方式。这些设备位于网络的边缘,并依赖于稳定的互联网连接来进行预测。尽管如此,也有一些物联网设备不需要互联网连接,而是拥有能够生成预测的硬件。

在尝试使用这种类型的部署时,必须考虑一些要求。重要的考虑因素包括处理能力、内存容量和连接性。这些因素对用于机器学习的物联网设备的性能有很大影响。

另一个重要的考虑因素是优化模型以适应边缘部署,并考虑在云端运行模型的可行性。这可能涉及减少需要在边缘设备上使用的 ML 模型的复杂性或大小。因此,部署类型将取决于用于提供机器学习能力的设备类型。

结论

总的来说,我们已经探讨了很多机器学习和深度学习模型的部署选项。在许多在线课程中,重点强调了机器学习模型及其内部工作原理。这些课程很好地突出了一些关于这些模型的细微差别,这些差别可以在测试之前进行深入解读。然而,也应该对部署方面给予足够重视,因为由于众多考虑因素,实时部署这些模型可能具有挑战性。

阅读完这篇文章后,希望你对机器学习的常见部署选项有了一些了解。谢谢。

以下是你可以联系我或查看我工作的方式。

GitHub: suhasmaddali (Suhas Maddali ) (github.com)

YouTube: www.youtube.com/channel/UCymdyoyJBC_i7QVfbrIs-4Q

LinkedIn: (1) Suhas Maddali, Northeastern University, Data Science | LinkedIn

Medium: Suhas Maddali — Medium

机器学习中的向量表示

原文:towardsdatascience.com/vector-representations-for-machine-learning-5047c50aaeff

数据科学家如何将现实世界的对象转换为数值表示以开发机器学习模型

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 安德烈亚·达戈斯提诺

·发表于 Towards Data Science ·8 分钟阅读·2023 年 4 月 25 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

照片由 Sigmund 提供,来自 Unsplash

机器学习工程师利用世界的数值表示来构建和训练预测算法。

在监督学习的背景下,这些表示方式使计算机能够学习它们与目标变量之间的关系。

假设一个向量仅仅是一个数字列表。

X = [1, 2, 3, 4, 5]

这个列表与目标变量 y 相关。

X = [1, 2, 3, 4, 5]; y = 1

机器学习模型学习特征与目标之间的关系,并输出预测——在这种情况下,是一种将一个类别标识为数字 1 的分类。

在这篇文章中,我将写关于如何使用向量以数值格式表示复杂概念。

其理论基础是机器学习模型无法从以非数值格式提供的观察数据中学习。

文本、图像、声音和其他输入观察数据必须首先转化为适合学习的数值格式。

有多种技术可以将现象转化为向量,这取决于我们处理的数据类型。

  • 我们将从介绍One-Hot 编码的概念开始,这是一种用于将单词表示为数值向量的技术。

  • 接下来,我们将探讨这种技术的局限性,并介绍嵌入的概念,这是一种可以将单词、图像、声音等表示为比 One-Hot 编码所需的成千上万个类别更小的数值向量的技术。

  • 我们还将提到TF-IDF 和词袋模型,它们在文本向量化中至关重要。

我们如何将现象编码为向量?

我们将以文本为例继续讨论。这个例子非常明显,因为正如我们所猜测的那样,机器学习模型不能直接使用文本进行学习。我们需要首先将每个字符或词汇转换为数字。

假设我们想创建词汇的数值表示

  • 国王

  • 皇后

  • 王子

  • 公主

对这些词汇进行编码的最简单方法是依次为每个词汇分配一个数字。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图片

词汇已正确转换为数值格式,按照映射关系

map = {
 "King": 1,
 "Queen": 2,
 "Prince": 3,
 "Princess": 4
}

但有一个问题。如果我们将这些数据提供给任何预测模型,它会为王子和公主分配更高的数学值,使它们比国王和皇后更重要。

显然,这将为模型提供错误的信息,导致模型学习错误的关系。我们需要使我们的数值表示更为精确。

One-Hot Encoding

为了解决上述数值表示问题,可以使用One-Hot Encoding技术。

在这种情况下,每个词将由一个数值向量表示,向量的大小等于需要表示的词总数。该向量的所有值都为零,只有一个值表示特定的词。

例如,对于“国王”、“皇后”、“王子”和“公主”这四个词,每个词将由一个包含四个元素的数组表示,其中值“1”在对应于该词的位置,所有其他位置的值为“0”。

该技术解决了在数值表示中为词汇分配更高数学值的问题,这些词汇在数量上并不比其他词汇更重要。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图片

现在,我们的模型为数据集中每个词(在此案例中仅包含 4 个词)具有“平衡”的向量表示。

但是……如果我们的词汇表由成千上万甚至数百万个词汇组成呢?考虑到意大利词典中大约有 270,000 个词汇,应用 One-Hot Encoding 至少会有问题。

执行这种编码所需的计算资源将是相当可观的,最终的表示将是“仅仅”平衡的:没有关于词汇之间关系的信息。

嵌入

为了克服 One-Hot Encoding 的局限性,可以使用称为嵌入的技术。这允许将词汇表示为可控大小的数值向量,相较于 One-Hot Encoding 所需的成千上万的类别。

其想法是创建一个数值表示词汇的方式,这种表示方式考虑了词汇之间的语义关系。

实际上,每个词汇都表示为一个实数向量,其中每个维度表示词义的不同方面。

理解嵌入很简单:相关的单词应该在向量空间中靠得很近,而不相关的单词则应该相距较远。

让我们尝试创建一个图表,捕捉前面提到的一些单词特征。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图像

我们可以看到王子公主的词语彼此多么接近,就像国王女王一样。

假设性别变量只能取两个值,M 和 F(我们使用 0 和 1),而年龄变量只能取三个值[年轻、中年、老年](我们使用 0、1、2),我们可以看到嵌入如何表示这些关系。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图像

这种表示通过使用性别和年龄维度,成功捕捉了个体的贵族地位。

在 X 轴上移动,我们可以观察到这两个贵族如何在表示性别差异的维度上等距(0:男性,1:女性)。而在 Y 轴上移动,我们可以观察到年龄是如何通过嵌入距离 Y 轴的方式来表示的。

这样,词嵌入可以作为机器学习模型的输入,使得复杂的概念能以数字格式更准确地表示

在这个例子中我们只有两个维度。实际上,神经网络是通过特定任务来训练的,以在多个维度上找到这些表示。

为了更好地理解,像 GPT-3 这样的模型使用了超过 12,000 维

行业的一个里程碑

嵌入表示不仅可以用于单词,还可以用于表示图像、声音等。

向量表示的使用在当今的机器学习中至关重要。深度学习领域的各种创新和技术源于向量化的概念。

像 GPT-3.5 这样的模型通过结合向量表示、经过充分研究的优化算法和大量计算资源而产生。

理论上,这种方法没有限制。

更多的数据 → 更高质量的向量 → 使用这些向量的模型进行更好的训练。

嵌入的局限性

尽管嵌入是一种非常有用的技术,用于以数字格式表示复杂概念,但它们也有局限性。

特别重要的是要强调,嵌入是从训练数据中构建的,因此可能会受到数据中任何偏见的影响

如前所述,嵌入的质量取决于训练数据的质量。如果训练数据不代表模型将使用的领域,嵌入可能无法捕捉概念之间所有的语义关系。

此外,嵌入可能需要大量的内存来存储,特别是当维度数量很大时。这对于需要在资源受限的设备上运行的机器学习模型,尤其是移动设备,可能特别有问题。

其他文本表示方式

由于文本是我们周围最常见的数据格式(只需想到互联网上大量的文本数据),一些文本向量化技术是常见且众所周知的。

其中之一是TF-IDF 转换,这是一种文本向量化技术,根据词在文档中的频率及其在语料库中的总体频率,为每个词分配一个权重。

这样,文档中出现频率较高但在语料库中出现频率较低的词汇将比那些在各处频繁出现的词汇具有更高的权重。这种技术在自然语言处理领域用于文本分析中被广泛使用。

我邀请有兴趣的读者通过阅读以下文章来了解更多关于 TF-IDF 模型的知识

## Text Clustering with TF-IDF in Python

文本聚类的简单管道解释。完整示例和代码

medium.com

TF-IDF 基于词袋模型,该模型将文档表示为一个无序的词汇集合,忽略句子结构和词序。

通过这种方式,词袋模型可以用来表示任何文档为一组数值,其中每个数值代表一个词在文档中的频率。当然,这并不能充分表示词汇之间的关系,这一点由词嵌入(embeddings)提供。

结论

在这篇文章中,我们已经看到如何使用向量以数字格式表示复杂的概念。

对数据科学家来说,以向量化的方式思考是重要的。像这样的疑问

  • 我如何将这个刺激转换成一个数字?

  • 神经网络如何解释这些数据?

  • 我如何改善这个表示?

这些问题至关重要,能够充分回答这些问题的团队将创造出更好的系统。

数据科学家从向量的角度看待世界。

如果你想支持我的内容创作活动,请随意通过下面的推荐链接加入 Medium 的会员计划。我将获得你投资的一部分,你将能够无缝访问 Medium 上丰富的数据科学及更多领域的文章。

## Join Medium with my referral link - Andrea D’Agostino

阅读 Andrea D’Agostino(以及 Medium 上的其他数千位作家的)每个故事。你的会员费直接……

medium.com

推荐阅读

对于有兴趣的人,这里是我推荐的每个机器学习相关主题的书单。这些书籍在我看来是必读的,并且对我的职业生涯产生了深远的影响。

免责声明:这些是亚马逊的附属链接。我将从亚马逊那里获得少量佣金作为推荐费。您的体验不会改变,您不会被收取额外费用,但这将帮助我扩大业务并制作更多有关人工智能的内容。

有用的链接(由我编写)

向量搜索并不是你所需的一切

原文:towardsdatascience.com/vector-search-is-not-all-you-need-ecd0f16ad65e

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 安东尼·阿尔卡拉兹

·发表于数据科学前沿 ·阅读时间 6 分钟·2023 年 9 月 18 日

人工智能软件被用来提升本文文本的语法、流畅性和可读性。

引言

检索增强生成(RAG)已经革新了开放领域问答,使系统能够对各种查询生成类似人类的响应。RAG 的核心是一个检索模块,它扫描大量语料库以找到相关的上下文段落,这些段落随后由神经生成模块(通常是预训练的语言模型,如 GPT-3)处理,以制定最终答案。

尽管这种方法非常有效,但也不是没有局限性。

最关键的组成部分之一,即对嵌入段落的向量搜索,存在固有的限制,这可能会妨碍系统以细致的方式进行推理。这在需要跨多个文档进行复杂的多跳推理时尤为明显。

向量搜索指的是使用数据的向量表示来搜索信息。这涉及两个关键步骤:

  1. 将数据编码为向量

首先,被搜索的数据会被编码为数值向量表示。对于像段落或文档这样的文本数据,这通过诸如 BERT 或 RoBERTa 的嵌入模型来完成。这些模型将文本转换为表示语义意义的连续数字的稠密向量。图像、音频和其他格式也可以使用适当的深度学习模型编码为向量。

2. 使用向量相似性搜索

一旦数据被编码为向量,搜索涉及到找到与搜索查询的向量表示相似的向量。这依赖于距离度量,如余弦相似度,以量化两个向量的接近程度并对结果进行排序。距离最小(相似度最高)的向量被返回为最相关的搜索结果。

向量搜索的主要优势在于能够搜索语义相似性,而不仅仅是字面上的关键词匹配。向量表示捕捉了概念意义,从而能够识别出更相关但语言上不同的结果。这使得搜索质量比传统的关键词匹配更高。

然而,将数据转换为向量并在高维语义空间中进行搜索也存在局限性。平衡向量搜索的权衡是一个活跃的研究领域。

在本文中,我们将剖析向量搜索的局限性,探讨它为何难以捕捉文档之间的多样关系和复杂的相互联系。我们还将深入研究如知识图谱提示等替代技术,这些技术有望克服这些不足之处。

随着我们的生活中越来越多地整合 AI 工具,了解当前 AI 工具的优缺点是至关重要的。本文旨在提供对向量搜索在增强大语言模型推理能力方面的优缺点的全面视角。

问题与答案之间的语义差距

在向量搜索中,输入问题和语料库中的段落都被编码为密集的向量表示。通过找到与问题向量具有最高语义相似性的段落来检索相关的上下文。

然而,问题往往与它们寻求的实际答案存在间接关系。

“法国的首都是什么?”的向量可能不一定与陈述“巴黎是法国人口最多的城市”的段落具有高度相似性。

这种语义差距意味着包含答案的段落可能会被忽视。

嵌入无法捕捉问题与答案之间的推理联系。

段落粒度的重要性

在向量搜索系统中,段落通常由单一的嵌入向量表示。这些段落的粒度可以有所不同。

如果段落非常大,例如整个文档,它可能包含多个概念。段落的某些部分可能相关,而其他部分则不相关。

但由于单个向量代表整个段落,因此无法区分相关部分和无关部分。整个段落可能与问题向量只有微弱的相似性。

相反,使用句子级别的块可以帮助隔离概念。但这会增加索引中的向量数量,增加计算开销。

选择段落大小时存在精度与可处理性之间的固有权衡。

对复杂推理的挑战

有些问题需要综合多个文档中的事实。

例如,“酿酒的最早历史记录是什么?” 可能需要从不同来源拼凑日期。

向量搜索对于这种多跳推理能力不足。每个段落独立地对问题进行评分。没有机制可以共同分析或连接不同结果中的信息。

随着问题变得越来越复杂,简单的相似性搜索达到了极限。系统难以从不同的段落中收集和上下文化事实。

黑箱模型工作原理

在标准向量搜索流程中,如何选择初始检索的段落是不透明的。排名取决于语义相似性模型的内部工作。

这种缺乏透明性使得结果难以解释、验证和改进,也限制了在业务关键应用中的部署。

为了增加监督,排名算法应提供一些可解释性,以说明为什么某些段落被认为是相关的。

建模多样关系

标准向量搜索的核心限制在于其单一关注语义相似性。

然而,现实世界的推理需要对内容之间的多样关系进行建模。

## 知识图谱提示:一种用于多文档问答的新方法

多文档问答(MD-QA)涉及回答需要综合多篇信息的问题…

blog.gopenai.com

知识图谱通过明确编码各种连接到互联图结构中来克服这一点。具体而言:

  • 主题关系 — 如果段落共享稀有或关键关键词,则这些段落会被链接。这捕捉到讨论主题的相似性。

  • 语义关系 — 段落嵌入被比较以连接那些语义上接近的段落,即使它们不共享相同的术语。

这超越了表面层次的主题匹配。

  • 结构关系 — 段落与它们出现的特定部分、页面或文档相连。

这编码了上下文层次结构。

  • 时间关系 — 讨论时间顺序事件的段落按时间顺序链接在一起。

这代表了事件的流动。

  • 实体关系 — 在引用相同现实世界实体的段落之间添加了指代链接。

这允许以实体为中心的推理。

通过结合这些超越语义相似性的多样信号,知识图谱提示(KGP)提供了一个更丰富的推理基础,用于关于互联信息的推理。

结构关系

相比之下,标准向量搜索没有这些结构关系的概念。段落被视为原子,没有任何周围的上下文。

知识图谱对结构关系的建模值得进一步讨论。通过将段落链接到它们出现的特定文档或部分,信息的上下文层次被编码。

这使得可以明确推理某一事实所在的部分、其来源的文档以及发布的网站。

对层次文档结构的编码为确定重要性、有效性和相关性提供了有用的归纳偏差,这在跨段落推理时尤为重要。

时间关系

在孤立的向量搜索中完全不存在这种归纳偏差。向量相似性评分没有考虑时间动态。检索到的段落是断裂的快照,缺乏叙事流。

KGP 中时间关系的明确建模也带来了显著的优势。根据所描述事件的时间顺序排列段落,使得对展开的叙事和时间线进行推理成为可能。

知识图谱通过根据相对时间链接事件来克服这一局限性。这解锁了更丰富的推理能力。

实体关系

在标准的向量搜索中,这些实体链接没有直接建模。有关实体的宝贵知识在段落嵌入中丢失。

知识图谱连接实体引用的能力是一项强大的资产。链接讨论相同现实世界实体、概念或人物的段落,可以围绕这些共享元素进行重点推理。

KGP 保留了这一信号,使得可以以实体为中心探索知识图谱。这在跨文档聚合关于特定实体的事实时提供了结构性优势。

结论

向量搜索基于语义相似性实现了高效的近似匹配。然而,在 RAG 系统的检索步骤中,单独使用时存在明显的局限性。

采用结合向量搜索与基于图谱的知识表示、多步推理模块和透明排名算法的混合方法可以帮助克服这些弱点。

一如既往,没有单一的解决方案——利用多样化的技术工具包是实现现实世界问答系统的强大检索的关键。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

此图像是使用 AI 图像生成模型创建的。

来源:

medium.com/thirdai-blog/understanding-the-fundamental-limitations-of-vector-based-retrieval-for-building-llm-powered-48bb7b5a57b3

labelbox.com/blog/how-vector-similarity-search-works/

www.elastic.co/what-is/vector-search

kaushikshakkari.medium.com/open-domain-question-answering-series-part-7-the-rise-of-vector-databases-in-the-world-of-9d848a3f47d5

medium.com/@PolonioliAI/limitations-of-vectors-and-neural-search-4d81fd64482f

medium.com/vector-database/frustrated-with-new-data-our-vector-database-can-help-e5c430b29be7

www.singlestore.com/blog/why-your-vector-database-should-not-be-a-vector-database/

clickhouse.com/blog/vector-search-clickhouse-p1

www.searchenginejournal.com/semantic-search-with-vectors/467574/

www.usenix.org/system/files/osdi23-zhang-qianxi_1.pdf

www.infoworld.com/article/3651360/solving-complex-problems-with-vector-databases.html

people.eecs.berkeley.edu/~matei/papers/2020/sigir_colbert.pdf

blog.futuresmart.ai/gpt-4-semantic-search-and-vector-databases-revolutionizing-question-answering

blog.vespa.ai/constrained-approximate-nearest-neighbor-search/

www.pinecone.io/learn/vector-search-filtering/

向量化:是什么以及它是如何工作的?

原文:towardsdatascience.com/vectorisation-what-is-it-and-how-does-it-work-1dd9cef48407?source=collection_archive---------6-----------------------#2023-04-13

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由Mariana Beltrán提供,来自Unsplash

O(n)比 O(1)更快,缓存行,Pandas 2.0 和列的持续增长

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Mark Jamison

·

关注 发布在 Towards Data Science ·10 min read·Apr 13, 2023

这是本文的第二次迭代。在完成第一次迭代后,我让它静置一段时间进行编辑,因为标题看起来不太好——一个关于向量化的 13 分钟长篇,其中包含与数据库理论和历史趋势的松散联系。

在等待重新草拟时,我发现了几个关于新版本 Pandas 2.0 的性能比较——尤其是与 Polars 的比较。此时我必须坦白——Pandas 对我而言是起点,我甚至还没有 pip install Polars 进行测试。我总是犹豫在尚未得到广泛支持的新工具和流行工具之间进行替换,直到:

  • 现有工具开始失败(我在 SQL 聚合后使用的数据不够大)。

  • 有一些其他明显有力的证据支持采用。

然而,新版 Pandas 2.0 性能相对较差确实让我产生了疑问——如果 Polars 在内存操作中如此快速,它是如何实现的?

Polars 的作者写了我原来的文章(但更好)。

向量化。Polars 之所以快速,是因为原作者设计了整个系统以向量化为核心。在这篇‘hello world’文章中,Polars 作者 Ritchie Vink 通过清晰简洁的语言和简单的视觉效果解释了 Polars 如何实现其目标——因为它不仅仅是以向量化为理念进行构建,而是完全围绕这些原则构建的。

这篇文章的其余部分并非仅仅是重述那些内容,而是通过回顾一些想法和历史背景,阐明我们如何将基于列(或数组)的计算变得更加‘主流’,以及这如何开始渗透到现代数据科学工具包中。

“我不在乎你的花哨数据结构是什么,但我知道数组会胜过它。”

上述内容出自这次讲座由 Scott Meyers 主讲,他将上述引言归于一位支持数组的算法交易公司 CTO。这个想法在 Polars 文章中也有所提及,但概念相同——在实际应用中,有时你需要抛弃基本的时间复杂度分析,因为 O(n) 算法可以超越 O(1) 算法。

我来自非计算机科学背景,但大学(特别是美国大学)提供的大量在线材料使我能够学习一些基础的‘算法’和‘数据结构’课程。根据我所见,联合目标可以(可能总结得很糟糕)如下:

  • 使用逻辑来提出一个涉及最少步骤(算法)的过程。

  • 组织你的数据,以便选择具有最小步骤的算法(数据结构)。

理解两者核心概念的重要性在于众多公司用大量的Leetcode动态规划或二叉搜索树等与日常无关的题目来考察他们的候选人。一个典型的例子是——在查找东西时使用哈希映射而不是数组。这是因为:

  • 哈希映射的查找是 O(1)

  • (未排序的)数组是 O(n),因为你可能需要遍历整个数组

为什么上述情况并不总是成立?因为*“O(1)算法比 O(n)算法更快”是不完整的。真正的说法应该是“O(1)算法比 O(n)算法更快,前提是起始点相同*”*。

在纽约驾驶法拉利

现代 CPU 被描述为类似于在纽约驾驶一辆法拉利。毋庸置疑,这个比喻显然来自于一个认为开车唯一目的就是从 A 点到 B 点的计算机科学家,但这一点依然成立。如果你只是不断地停车和启动(尽管停车和启动非常非常快),为什么需要这样一辆快车呢?

将‘车’替换为‘处理器’,我们就看到了现代 CPU。这是因为处理器速度与内存速度之间的相对速度改进(尽管近年来速度提升变得较慢)。这一点在 Herb Sutter 的C++导向讲座中得到了很好的阐述(如果你只想了解概述,可以从 12:00 开始观看约 20 分钟),下面的图表也很好地展示了这一点:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

新的向量时代的机会和选择——ResearchGate 上的科学图

简单来说,我们有一个“摩尔定律引发的”处理器操作速度和将数据送到处理器的速度之间的差距。如果处理器大部分时间都在空闲等待新数据进行操作,那么就没有必要拥有如此快速的处理器

我们如何规避这个问题?缓存、缓存和更多缓存

下面是许多人对计算机架构的标准印象(跳到 Herb Sutter 讲座的 22:30,观看这个有趣的例子):

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图片

换句话说——我们有:

  • 我们的 CPU 执行任务

  • 我们的内存(或 RAM)‘快速’访问

  • 我们的磁盘访问速度慢

如上图所示,内存可能相对于从硬盘检索数据是快的——但相对于我们当前处理器的数量和速度,检索速度确实不快。为了解决这个问题,硬件开发者决定将内存放置到CPU(或‘芯片’)上。这些被称为缓存,每个处理核心都有多个。 每个核心包括:

  • L1 缓存:这是最快的,并分为指令缓存(存储你代码转化成的指令)和数据缓存(存储你操作的变量)

  • L2 缓存:比 L1 大,但较慢

然后,机器上的所有核心共享一个 L3 缓存——同样比 L1 和 L2 大,但再次较慢。

下图基于来自 Intel 白皮书的放大照片(Intel 架构基础知识(v.1, Jan. 2014)),展示了 Intel i7 处理器的 4 核心 CPU 布局——L1 和 L2 缓存每个‘核心’部分内:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由作者提供——原图见此文档:Intel 架构基础知识(v.1, Jan. 2014)

但这同样适用于更新的 PC。我在 Macbook Air M2 上写这篇文章, 在维基百科上快速查阅表明 M2 芯片包含 L1、L2 和一个共享的‘最后一级缓存’(L3)。

那么,从缓存中检索数据比从内存中快多少呢?

对此——一图胜千言——或者更具体地说,是一个动画。来自游戏软件优化公司 Overbyte 的以下链接展示了相对性能的极端差异:www.overbyte.com.au/misc/Lesson3/CacheFun.html

速度在机器之间有所不同,但——以标准的‘时钟周期’单位来衡量,我们大致有:

  • L1:~1–3x 时钟周期

  • L2:~10x 时钟周期

  • L3:~40x 时钟周期

  • 主内存:~100–300x 时钟周期

换句话说:如果我们的数据在 L1 内存中,我们的计算机会比从主内存中提取数据快 30–300 倍。主内存指的是 RAM,而非硬盘。当你的 CPU 在 L1 缓存中找不到数据时,它会搜索 L2,然后是 L3,最后到主内存每次这些搜索失败都会标记为‘缓存未命中’。 缓存未命中越少,你的代码执行越快。

那么,我如何将数据加载到缓存中?缓存行

你的 CPU 为你完成了这一切。基于你的代码,这些代码被翻译成指令集(最低级的命令——即使汇编也被汇编成指令),你的 CPU 将数据加载到缓存中。然后它对数据进行操作并将结果存储在缓存中。

特别是,如果你定义以下内容:

x = 1
y = 2
x + y

你的 CPU 不仅仅是将整数 xy 加载到缓存中——它实际上加载的是‘缓存行’。 缓存行 是 CPU 处理数据的最底层单位。一个缓存行将包含你需要的数据,但也包括内存中周围的其他数据,这些数据构成了整个缓存行——通常是一个 64 字节的连续内存块。

不仅如此,CPU 还被设计来进行聪明的数据优化提取。为什么?因为这很慢——所以我们越早将数据加载到缓存中进行操作,就越早解决数据存储便宜而处理器快速之间的瓶颈。

  • 处理器对数据进行操作的速度

  • 将数据提供给处理器所需的时间——在缓存中

为了做到这一点,CPU 实现了类似于预取的功能——这意味着识别程序内存访问中的模式,并预测你将会使用哪些内存。

总结

可能最好的总结方式是回答(最后)我们最初的问题——什么是向量化?

  • 向量化是充分利用你超快处理器的手段

  • 它通过将数据组织成可预测的连续块(数组)来完成这一点,这些数据块可以一起加载(缓存行)并进行预取

  • 这可以防止你的 CPU 在等待从 L2/L3 缓存或更糟——主内存中加载数据时无所事事

但为什么是现在?如果这看起来如此明显,为什么这不是一直以来的基础?

发生了什么变化?相对的历史进展

之前的计算机科学家并不是愚蠢的(正好相反)而错过了这个显而易见的想法。相反,我们现在正在开发最适合的工具包:

  • 我们当前的硬件格局

  • 我们当前的使用案例

由于历史进展(上面的摩尔定律图表),我们正处于这样的情况:

  • 数据存储便宜而处理器迅速

  • 我们有广泛的数据分析使用案例来推动决策制定(无论是自动化还是非自动化)

结果是,由于相对的进步,我们现在面临一个瓶颈——将所有这些数据输入到我们超快的处理器中。而且为了清楚起见,这个问题并不新鲜。如上图所示,这已经是一个多年来逐渐扩大的问题,但我们应用了 L1/L2 缓存的权宜之计;因为这意味着大多数人不必担心这个问题。

但是现在差距 如此 巨大,数据规模也在不断增长,以至于问题需要在更接近源头的地方解决。换句话说,如果你想让你的数据以闪电般的速度被操作,那么不要把工作推给 CPU 设计师,而是自己动手,并将数据存储在数组中。

Kdb+、Dremel 和 BigQuery,PyArrow

一般来说,当我们处理数据时,通常有两种方式:

  • 在内存中

  • 在磁盘上

对我来说,这仅仅意味着数据是否一般存储在一个数据 或者数据 框架 中。Polars,以及最近的 Pandas 2.0,未必做了完全新的事情,而是更多地重新设计了我们内存中数据的表示方式,使其更接近于在磁盘上通常的存储方式。

为什么?因为我们在如何以可以快速过滤和汇总的数据存储方式上取得了显著进展——那么为什么不在内存中以能够利用这些进展的方式来表示数据呢?我们完全可以采纳像 Kdb+ 这样的技术驱动思想,并将其实现于内存数据存储方式中。

迈向一致的方法

Polars 基于 PyArrow——这是 Apache Arrow 内存列式数据格式 的 Python 实现。新版本的 Pandas 2.0 也是如此。PyArrow 在从磁盘加载 Apache Parquet 格式的数据时表现特别好。Parquet 基本上是原始 Dremel 论文 中描述的列格式。Dremel 是支撑 Google Big Query 的分析引擎。

要点是:这些都是相互关联的概念,最新的内存数据科学工具并不一定是一个彻底的变化,而是朝着一个日益一致的理念迈进,这个理念支撑着我们的数据分析工具。我们将数据存储在数组/列中——既在磁盘上 在内存中

为什么?因为这是考虑到处理硬件、数据存储成本和内存数据检索速度之间的相对速度改进后,做事的最佳方式。

结论:Wes McKinney 的简历

这可能看起来是一个奇怪的结论话题,但通过 Pandas 创始人的职业生涯可以看出上述向语言无关、一致的数据操作和建模方法的轨迹。他最初在 AQR Capital Management 使用电子表格处理数据,然后创建并推广了 Pandas(利用 NumPy 的向量友好型 ndarray),现在他 深度参与 Apache Arrow

他的职业生涯与现代“数据科学堆栈”(至少在 Python 中)朝着更加面向列的内存数据表示方式的演变紧密相连。看起来数组的推广势不可挡,并且我个人认为这种趋势不会放缓。请原谅这句双关语。

使用 JAX 向量化和并行化 RL 环境:以光速进行 Q 学习⚡

原文:towardsdatascience.com/vectorize-and-parallelize-rl-environments-with-jax-q-learning-at-the-speed-of-light-49d07373adf5?source=collection_archive---------1-----------------------#2023-10-15

在这篇文章中,我们学习如何向量化一个 RL 环境,并在 CPU 上并行训练 30 个 Q 学习代理,每秒进行 180 万次迭代。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Ryan Pégoud

·

关注 发表在 Towards Data Science · 11 分钟阅读 · 2023 年 10 月 15 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片来源于 Google DeepMindUnsplash

在前面的故事中,我们介绍了时序差分学习,特别是Q 学习,并将其应用于 GridWorld 的背景中。

## 时间差分学习和探索的重要性:图解指南

在动态网格世界中比较无模型(Q-learning)和有模型(Dyna-Q 和 Dyna-Q+)的 TD 方法。

[towardsdatascience.com

虽然这个实现用于展示这些算法在性能和探索机制上的差异,但速度非常慢

实际上,环境和代理主要使用Numpy编写,这在强化学习中并非标准,尽管它使代码易于理解和调试。

在这篇文章中,我们将看到如何通过向量化环境和无缝并行化数十个代理的训练来扩展强化学习实验。特别地,本文涵盖了:

  • JAX 基础和强化学习的有用功能

  • 向量化环境及其高速原因

  • 在 JAX 中实现环境、策略和 Q-learning 代理

  • 单代理训练

  • 如何并行化代理训练,以及这有多简单!

本文中展示的所有代码均可在 GitHub上找到:

[## GitHub - RPegoud/jym: JAX 实现的 RL 算法和向量化环境

JAX 实现的 RL 算法和向量化环境 - GitHub - RPegoud/jym: JAX 实现的 RL…

github.com](https://github.com/RPegoud/jym?source=post_page-----49d07373adf5--------------------------------)

JAX 基础

JAX 是 Google 开发的另一种Python 深度学习框架,被 DeepMind 等公司广泛使用。

“JAX 是Autograd(自动微分)和XLA(加速线性代数,TensorFlow 编译器)的结合,旨在实现高性能数值计算。” — 官方文档

与大多数 Python 开发人员习惯的不同,JAX 不采用面向对象编程(OOP)范式,而是采用函数式编程(FP)[1]

简而言之,它依赖于纯函数确定性无副作用)以及不可变的数据结构不是在原地修改数据,而是创建具有所需修改的新数据结构)作为主要构建块。因此,FP 鼓励一种更具函数性和数学性的编程方法,使其非常适合于数值计算和机器学习任务。

让我们通过查看 Q 更新函数的伪代码来说明这两种范式之间的差异:

  • 面向对象 方法依赖于一个包含各种 状态变量(如 Q 值)的 类实例。更新函数被定义为一个类方法,它 更新实例的 内部状态

  • 函数式编程 方法依赖于 纯函数。实际上,这个 Q 更新是 确定性的,因为 Q 值作为参数传递。因此,对这个函数的任何调用只要 输入相同 就会产生 相同的输出,而类方法的输出可能依赖于实例的内部状态。此外,数据结构 如数组在 全局范围 内被 定义修改

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

面向对象编程函数式编程 中实现 Q 更新(作者制作)

因此,JAX 提供了各种 函数装饰器,在 RL 的上下文中尤为有用:

  • vmap (向量化映射):允许作用于单个样本的函数应用于一个 批次。例如,如果 env.step() 是一个在单个环境中执行一步的函数,那么 vmap(env.step)() 是一个在 多个环境 中执行一步的函数。换句话说,vmap 为函数添加了一个 批次维度

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

使用 vmap 向量化的 step 函数示例(作者制作)

  • jit (即时编译):允许 JAX 执行 “JAX Python 函数的即时编译” 使其 兼容 XLA。本质上,使用 jit 允许我们 编译函数 并提供 显著的速度提升(以在首次编译函数时的一些额外开销为代价)。

  • pmap (并行映射):类似于 vmap,pmap 实现了简便的并行化。然而,它不是为函数添加批次维度,而是复制函数并在 多个 XLA 设备 上执行它。注意:应用 pmap 时,jit 也会被 自动 应用

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

使用 pmap 并行化的 step 函数示例(作者制作)

既然我们已经掌握了 JAX 的基础知识,我们将探讨如何通过向量化环境获得巨大的速度提升。

向量化环境:

首先,什么是向量化环境,它解决了什么问题?

在大多数情况下,RL 实验由于 CPU-GPU 数据传输变慢。深度学习 RL 算法如 近端策略优化(PPO)使用神经网络来近似策略。

像深度学习中的常规做法一样,神经网络在训练推理时使用GPU。然而,在大多数情况下,环境运行在CPU上(即使在使用多个环境并行的情况下也是如此)。

这意味着,通过策略(神经网络)选择动作并从环境中接收观察和奖励的常规 RL 循环需要不断的来回交换,这影响了性能

此外,使用诸如 PyTorch 的框架而不进行*“jitting”*可能会导致一些开销,因为 GPU 可能需要等待 Python 将观察和奖励从 CPU 发送回来。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

通常的 RL 批量训练设置在PyTorch中(由作者制作)

另一方面,JAX 使我们能够轻松地在 GPU 上运行批量环境,消除由 GPU-CPU 数据传输引起的摩擦。

此外,随着 jit 将我们的 JAX 代码编译为 XLA,执行不再(或至少减少)受到 Python 低效的影响。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

RL 批量训练设置在JAX中(由作者制作)

有关元学习 RL 研究的更多细节和令人兴奋的应用,我强烈推荐Chris Lu的这篇博客文章。

环境、代理和策略实现:

让我们查看 RL 实验中不同部分的实现。以下是我们需要的基本函数的高级概述:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

简单 RL 设置所需的类方法(由作者制作)

环境

该实现遵循Nikolaj Goodger在其关于在 JAX 中编写环境的精彩文章中提供的方案。

[## 在 JAX 中编写 RL 环境

如何以 1.25 亿步/秒运行 CartPole

medium.com](https://medium.com/@ngoodger_7766/writing-an-rl-environment-in-jax-9f74338898ba?source=post_page-----49d07373adf5--------------------------------)

让我们从环境及其方法的高级视图开始。这是实现 JAX 环境的一般计划:

让我们更详细地查看类方法(作为提醒,函数以“_”开头的是 私有的 ,不应在类的作用域之外调用):

  • _get_obs:此方法将环境状态转换为代理的观察。在部分可观察随机环境中,应用于状态的处理函数将在这里。

  • _reset:由于我们将并行运行多个代理,因此我们需要一个方法来在完成一个回合后进行单独的重置。

  • _reset_if_done:此方法将在每一步调用,并在“done”标志设置为 True 时触发 _reset。

  • reset:此方法在实验开始时被调用,以获取每个代理的初始状态以及相关的随机密钥。

  • 步骤:给定一个状态和一个动作,环境返回一个观察(新状态)、一个奖励和更新后的“done”标志。

实际上,GridWorld 环境的通用实现如下:

请注意,如前所述,所有类方法都遵循函数式编程范式。实际上,我们从未更新类实例的内部状态。此外,类属性都是常量,在实例化后不会被修改。

让我们更仔细地看一下:

  • init: 在我们的 GridWorld 中,可用的动作是**[0, 1, 2, 3]**。这些动作通过自我移动转化为二维数组,并在步骤函数中添加到状态中。

  • _get_obs: 我们的环境是确定性完全可观察的,因此代理直接接收到状态,而不是处理后的观察。

  • _reset_if_done: 参数env_state对应于(state, key)元组,其中 key 是一个jax.random.PRNGKey. 如果done标志被设置为 True,该函数会返回初始状态,然而,我们不能在 JAX jitted 函数中使用传统的 Python 控制流。使用jax.lax.cond,我们实际上得到一个等效的表达式:

def cond(condition, true_fun, false_fun, operand):
  if condition: # if done flag == True
    return true_fun(operand)  # return self._reset(key)
  else:
    return false_fun(operand) # return env_state
  • step: 我们将动作转换为移动,并将其添加到当前状态中(jax.numpy.clip确保代理保持在网格内)。然后我们更新env_state元组,然后检查环境是否需要重置。由于步骤函数在训练中频繁使用,对其进行 jitting 可以显著提高性能。*@partial(jit, static_argnums=(0, )*装饰器表示该类方法的“self”参数应被视为静态。换句话说,类属性是常量,在对步骤函数的连续调用中不会改变。

Q-Learning 代理

Q-learning 代理由update函数定义,以及一个静态的学习率折扣因子

再次强调,当对更新函数进行 jitting 时,我们将“self”参数传递为静态。同时,请注意,q_values矩阵是就地修改的,使用set(),其值未作为类属性存储。

Epsilon-Greedy 策略

最后,本实验中使用的策略是标准的epsilon-greedy 策略。一个重要细节是它使用随机平局,这意味着如果最大 Q 值不是唯一的,动作将从最大 Q 值中进行均匀采样使用 argmax 会始终返回具有最大 Q 值的第一个动作)。如果 Q 值被初始化为零矩阵,这一点尤其重要,因为动作 0(向右移动)将始终被选择。

否则,策略可以通过这段代码总结:

action = lax.cond(
            explore, # if p < epsilon
            _random_action_fn, # select a random action given the key
            _greedy_action_fn, # select the greedy action w.r.t Q-values
            operand=subkey, # use subkey as an argument for the above funcs
        )
return action, subkey

注意,当我们在 JAX 中使用key时*(例如这里我们采样了一个随机浮点数并使用了 random.choice),通常的做法是之后拆分 key(即“转到新的随机状态”,更多细节见这里)。*

单代理训练循环:

现在我们有了所有必要的组件,让我们训练一个单一的代理。

这是一个Pythonic的训练循环,正如你所见,我们基本上是使用策略选择一个动作,在环境中执行一步,并更新 Q 值,直到一个回合结束。然后我们重复这个过程N回合。正如我们稍后会看到的,这种训练代理的方式相当低效,但它以一种可读的方式总结了算法的关键步骤:

在单个 CPU 上,我们在 11 秒内完成了 10,000 个回合,以每秒 881 个回合和 21,680 步的速度。

100%|██████████| 10000/10000 [00:11<00:00, 881.86it/s]
Total Number of steps: 238 488
Number of steps per second: 21 680

现在,让我们使用 JAX 语法重复相同的训练循环。以下是rollout函数的高级描述:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

使用JAX 语法的训练 rollout 函数(作者制作)

总结一下,rollout 函数:

  1. 初始化 观察值奖励完成标志为空数组,维度等于时间步的数量,使用jax.numpy.zeros. Q 值被初始化为一个形状为**[timesteps+1**, grid_dimension_x, grid_dimension_y, n_actions**]**的空矩阵。

  2. 调用***env.reset()***函数来获取初始状态

  3. 使用jax.lax.fori_loop()函数调用fori_body()函数N次,其中Ntimestep参数

  4. fori_body()函数的行为类似于之前的 Python 循环。在选择一个动作、执行一步并计算 Q 更新后,我们在原地更新 obs、rewards、done 和 q_values 数组(Q 更新目标是时间步t+1**)*。

这种额外的复杂性导致了85 倍加速,我们现在以大约183 万步每秒的速度训练我们的代理。请注意,这里训练是在单个 CPU上进行的,因为环境较为简单。

然而,端到端的向量化应用于 复杂环境受益于多 GPU 的算法效果更佳Chris Lu 的文章报告了 CleanRL PyTorch PPO 实现与 JAX 复现之间惊人的4000 倍加速)。

100%|██████████| 1000000/1000000 [00:00<00:00, 1837563.94it/s]
Total Number of steps: 1 000 000
Number of steps per second: 1 837 563

在训练我们的代理后,我们绘制了 GridWorld 中每个单元格(即状态)的最大 Q 值,并观察到它已经有效地学会了从初始状态(右下角)到目标(左上角)的路径。

GridWorld 中每个单元格的最大 Q 值的热图表示(作者制作)

并行代理训练循环:

如承诺的那样,现在我们已经编写了训练 单个代理 所需的函数,剩下的工作就是在批处理环境中训练 多个代理,几乎没有其他工作!

由于 vmap 的帮助,我们可以快速将之前的函数转换为处理数据批次。我们只需指定预期的输入和输出形状,例如对于 env.step:

  • in_axes = ((0,0), 0) 表示输入形状,由 env_state 元组(维度 (0, 0))和一个 observation(维度 0)组成。

  • out_axes = ((0, 0), 0, 0, 0) 表示输出形状,输出为 ((env_state), obs, reward, done)。

  • 现在,我们可以在一个 arrayenv_statesactions 上调用 v_step,并接收一个处理后的 array,其中包含 env_statesobservationsrewardsdone flags

  • 注意,我们还对所有批处理函数进行了 jit 优化以提高性能(可以说,对 env.reset() 进行 jit 优化是多余的,因为它在我们的训练函数中只调用一次)。

我们必须做的最后一个调整是 为我们的数组添加批处理维度,以考虑每个代理的数据。

通过这样做,我们获得了一个函数,可以在 并行 训练 多个代理,与单个代理函数相比,只需最小的调整:

使用这个版本的训练函数,我们得到了类似的性能:

100%|██████████| 100000/100000 [00:02<00:00, 49036.11it/s]
Total Number of steps: 100 000 * 30 = 3 000 000
Number of steps per second: 49 036 * 30 = 1 471 080

就这些了!感谢你读到这里,希望这篇文章为你提供了有关在 JAX 中实现矢量化环境的有用介绍。

如果你喜欢这篇文章,请考虑 分享 这篇文章并 收藏 我的 GitHub 仓库,谢谢你的支持! 🙏

[## GitHub - RPegoud/jym: JAX 实现的 RL 算法和矢量化环境

JAX 实现的 RL 算法和矢量化环境 - GitHub - RPegoud/jym: JAX 实现的 RL…

github.com

最后,对于那些希望深入了解的人,这里有一个 有用的资源 列表,帮助我入门 JAX 并撰写这篇文章:

精心策划的 JAX 文章和资源汇总:

[1] Coderized, (函数式编程) 最纯粹的编码风格,几乎不可能出错, YouTube

[2] Aleksa Gordić, 从零到英雄的 JAX YouTube 播放列表 (2022), The AI Epiphany

[3] Nikolaj Goodger, 用 JAX 编写 RL 环境 (2021)

[4] Chris Lu*,* 通过 PureJaxRL 实现 4000 倍加速和元进化发现 (2023), 牛津大学, Foerster 人工智能研究实验室

[5] Nicholas Vadivelu,Awesome-JAX (2020),一个 JAX 库、项目和资源的列表

[6] JAX 官方文档,使用 PyTorch 数据加载训练简单神经网络

使用 Modelbit 通过 Git 进行机器学习模型部署的版本控制

原文:towardsdatascience.com/version-control-your-ml-model-deployment-with-git-using-modelbit-1b3d76411436?source=collection_archive---------8-----------------------#2023-05-10

开发、部署和跟踪!

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 阿维·乔拉

·

关注 发表于 Towards Data Science ·7 min read·May 10, 2023

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

Yancy Min的照片,来源于Unsplash

介绍

版本控制对所有开发过程至关重要,允许开发者随时间跟踪软件变更(代码、配置、数据等)。

此外,它促进团队成员之间的合作,使他们能够在同一代码库上共同工作,而不会干扰彼此的工作。

在数据团队中,版本控制在部署模型时尤为关键。

它使他们能够准确识别发生了什么变化、何时发生变化以及是谁进行了更改——这是在部署过程中诊断和解决出现的问题或模型在部署后表现不佳时至关重要的信息。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

模型版本控制(作者提供的图片)

在这种情况下,基于 git 的功能可以提供快速回滚到先前版本的能力。

因此,在本文中,我将展示如何利用 Git 功能为你的模型部署提供支持。

更具体地,我们将使用 Modelbit 的 git 功能进行部署,并将 GitHub 与 Modelbit 同步以实现协作功能。

让我们开始吧 🚀!

Git 对数据团队的重要性

在深入了解如何操作之前,让我们先建立更多关于基于 git 的版本控制的动机,以及它为何至关重要。

#1)协作

随着数据科学项目越来越大,有效的协作变得越来越重要。

使用版本控制,团队可以在相同的代码库/数据上工作,并改进相同的模型,而不会干扰彼此的工作。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

分支模型(作者提供的图片)

此外,还可以轻松跟踪更改,审查彼此的工作,并解决冲突(如有)。

#2)可重复性

可重复性是构建可靠机器学习的关键方面之一。在一个系统上工作而在另一个系统上不起作用反映了不良的可重复性实践。

你可能会好奇,这为何重要?

它确保了结果可以被他人复制和验证,从而提高了你工作的整体可信度。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

使用版本控制的可重复性(作者提供的图片)

版本控制使你能够跟踪用于生成特定结果的确切代码版本和配置,从而使将来复制结果变得更加容易。

这对于许多人可以使用的开源项目尤其有用。

#3)持续集成和部署(CI/CD)

CI/CD 使团队能够快速有效地构建、测试和部署代码。

在机器学习中,持续集成(CI)可能涉及自动构建和测试对 ML 模型的更改,一旦这些更改被提交到代码库。

在持续部署(CD)中,目标是将模型的最新更改反映出来,一旦它们通过了测试。因此,它应该无缝更新生产中的更改,使最新版本的模型对最终用户可用。

既然我们知道了版本控制从开发和部署的角度为何重要,那么让我们看看如何利用 git 基于的功能在 Modelbit 的部署阶段。

本地仓库与 Modelbit 集成

Modelbit 完全由 git 驱动。因此,每当你推送模型进行部署时,它会将部署内部维护为 git 仓库。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

基于 Git 的部署(作者图片)

由于支持 git,它本地提供了所有版本控制的优点,适用于你的部署、模型和数据集。

更进一步,你可以从本地计算机克隆到远程 git 仓库,并执行所有 git 命令,如 git pullgit push 或进行分支等。

连接到 Modelbit git 仓库

要访问 Modelbit git 仓库,你需要添加一个 SSH 密钥,以将你的本地计算机连接到 Modelbit。

打开终端并运行以下命令:

ssh-keygen -t rsa -b 4096 -C "My SSH key"

这将创建一个 SSH 密钥。要查看它,请运行以下命令:

cat ~/.ssh/id_rsa.pub

上述命令取自官方 GitHub 文档

现在,复制 cat 命令的完整输出,并转到 Modelbit 仪表盘中的 Git 设置。点击“添加密钥”,然后粘贴上面获得的输出。如下所示:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

添加 SSH 密钥(作者图片)

完成了!

现在我们已连接到 Modelbit 的远程 git 仓库。

部署模型

让我们从 Jupyter Notebook 推送一个模型用于部署。我不会详细说明,因为我已经在之前的博客中讲过这个。

## 直接从 Jupyter Notebook 部署机器学习模型

用一行代码部署机器学习模型

towardsdatascience.com

简而言之,你应该训练一个模型,定义一个预测函数,并将这个函数对象推送用于部署,如下所示:

## Train Model
from sklearn.linear_model import LinearRegression
model = LinearRegression().fit(x, y)

## Define Prediction function
def Linear_Model(input_x):

    if isinstance(input_x, (int, float)):    ## check input type
        return model.predict([[input_x]])[0] ## prediction

    else:
        return None

## Deploy it
import modelbit
mb = modelbit.login() ## authenticate the notebook here.
mb.deploy(Linear_Model)

一旦我们部署了一个模型,我们会在 Modelbit 仪表盘中看到以下内容:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

部署仪表盘(作者图片)

克隆 Modelbit 仓库

让我们克隆到这个仓库以查看其内容。在终端中运行以下命令。

modelbit clone my_linear_model

这将克隆到 Modelbit 的 git 仓库,并创建一个名为 my_linear_model 的文件夹。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

克隆部署仓库(作者图片)

一旦你运行命令,复制获得的链接以进行认证。

如上所示,克隆创建了一个新的本地仓库,数据集、部署和端点在 Modelbit 的远程 git 仓库的主分支中

当前的仓库结构如下:

my_linear_models
├── bin
├── datasets
├── endpoints
└── deployments
    └── Linear_Model
        ├── source.py ## source code
        └── data
            └── model.pkl ## model pickle

将更改推送到 Modelbit

现在我们已经克隆到远程仓库,我们可以在本地进行任何更改并推送它们。

让我们将一个虚拟 CSV 文件添加到 Linear_Model 文件夹,提交到本地仓库,并推送到远程仓库。

my_linear_models
├── bin
├── datasets
├── endpoints
└── deployments
    └── Linear_Model
        ├── source.py ## source code
        ├── dummy_data.csv ## added locally
        └── data
            └── model.pkl ## model pickle

让我们将 CSV 文件添加到暂存区:

git add deployments/Linear_Model/git dummy_data.csv

接下来,让我们将其提交到本地仓库:

git commit -m "Add dummy data csv"

最后,让我们推送它:

git push

这样,虚拟 CSV 文件已被提交到远程 Modelbit git 仓库。

注意:我们将 CSV 添加到Linear_Model文件夹而不是datasets文件夹是有原因的。datasets文件夹仅支持通过 SQL 查询的数据集。这些查询的结果在运行时可用于运行部署。目前还不支持其他自定义数据集。

分支

如果您希望在远程 Modelbit 仓库中创建和工作于一个单独的分支,这也是可能的。

按如下方式从仪表盘创建新分支:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

分支远程仓库(图片由作者提供)

接下来,假设我们希望在这个分支上本地改进我们的模型。在你的笔记本中,你可以按如下方式切换到这个新分支:

## notebook.ipynb

mb.switch_branch("another_branch")

现在,从笔记本中进行的所有新部署(以及其他提交,如果有的话)将推送到远程 Modelbit git 仓库的another_branch分支。

同步 GitHub

远程 Modelbit 仓库可以与您的个人 GitHub 仓库自动同步。

这对于在 Modelbit 部署上执行基于 GitHub 的代码审查、CI/CD 和 Pull Request 工作流特别有用。

#1)创建一个新的 GitHub 仓库

下面,我在 GitHub 上创建了一个空仓库。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

新的 GitHub 仓库(图片由作者提供)

接下来,我们应该授予 Modelbit 对这个仓库的写权限。

#2 复制 GitHub 仓库的 SSH URL

CodeSSH下,复制 URL。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

仓库 SSH URL(图片由作者提供)

#3)在 Modelbit 中添加 Git 远程

在仪表盘中,转到Git SettingsAdd Git Remote,粘贴复制的仓库 URL,然后Connect Remote

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

将 Git 远程添加到 Modelbit(图片由作者提供)

#4)授予 Modelbit 写权限

从上述同步面板中,复制部署密钥:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

部署密钥(图片由作者提供)

现在转到 GitHub 仓库的SettingsDeploy keysAdd deploy key。粘贴密钥,给它一个标题,授予写权限,然后点击Add key

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

在 GitHub 仓库中添加部署密钥(图片由作者提供)

完成!GitHub 仓库已自动更新:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

GitHub 中的部署代码(图片由作者提供)

现在,远程 Modelbit git 仓库与您的 GitHub 仓库已同步,您可以将其用于各种协作工作。

结论

至此,我们的博客结束了。

在这篇文章中,我们学习了 Git 功能对数据团队的重要性,以及如何通过 Modelbit 使用 git 来支持模型部署。

接下来,我们查看了如何将 Modelbit 内部创建的远程 git 仓库连接到个人 GitHub 仓库。

话虽如此,Modelbit 仍处于开发初期,目前可能还不是其他服务(如 Heroku)的终极替代品。

然而,根据我在使用 Modelbit 和 Heroku 的经验,我认为 Modelbit 的部署过程更为简化,不论是对经验丰富的用户还是新手都更为友好。

我期待看到他们的后续发展!

感谢阅读!

实践中的版本控制:数据、机器学习模型和代码

原文:towardsdatascience.com/version-controlling-in-practice-data-ml-model-and-code-e13c518067dc

MLOps 中版本管理的逐步指南

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Chayma Zatout

·发表于数据科学前沿 ·13 分钟阅读·2023 年 12 月 2 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

照片由 Christopher Gower 提供,来源于 Unsplash

版本控制是一个至关重要的实践!没有它,你的项目可能会变得杂乱无章,使得回滚到任何期望的点变得困难。你可能会丧失重要的模型配置、权重、来自长期训练的实验结果,甚至整个项目本身。你也可能会在代码出现问题时与队友产生分歧和冲突,从而阻碍有效的合作。在本文中,我们通过一个实际的例子来探讨版本控制的重要性,例子中使用了一些该领域最常见的工具。本文的完整代码库可以在相关仓库中访问。

目录:

· 1. 介绍

· 2. 工具

· 3. 设置你的项目

∘ 3.1. 项目文件夹

∘ 3.2. 项目环境

· 4. 代码版本管理

· 5. 数据版本管理

· 6. 模型版本管理

· 结论

1. 介绍

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

版本控制是记录文件或文件集随时间变化的实践,使用版本控制系统,以便我们可以在以后回忆特定版本。在 MLOps 中,版本控制是我认为在开始机器学习项目时需要考虑的首要原则之一。为了确保我们充分利用所有好处,版本控制应应用于不同的机器学习工作流步骤,包括数据机器学习模型(ML 模型)和代码

为什么要进行版本控制? 使用版本控制来管理代码、数据和模型可以实现可重复性(这是另一个重要的 MLOps 原则),通过允许在任何时间点重新创建项目的特定状态;跟踪监控变化,通过建立系统化的方法来捕捉、记录和管理开发生命周期中的变化;协作,通过跟踪不同贡献者所做的更改,并高效地合并这些更改,以及其他许多重要的好处,如错误恢复可追溯性

版本控制用例? 让我们考虑一个具体的场景,在手写数字分类项目中,我们将用作本文的示例。

  • 代码。 假设我们引入了优化以提高速度。然而,部署后,用户报告了预测中出现了意外的不准确。得益于项目的强大代码版本控制实践,我们可以迅速识别出与错误相关的提交,并在解决错误、修复并重新集成到主项目版本之前,暂时回滚部署。

  • 数据。 假设我们决定扩展数据集以增强模型的泛化能力。然而,在扩展的数据集用于训练后,观察到模型性能出现了意外的变化。因此,我们审查版本控制历史,确定可能导致问题的具体扩展技术,并迅速回滚到数据集的先前版本。然后,我们共同优化数据扩展方法,确保只有经过验证的更改才会重新集成到主项目版本中。

  • 机器学习模型。 假设现在,我们开始优化模型架构以提高准确性。我们实施了一个卷积神经网络(CNN)以改进特征提取,并将其集成到主项目中。然而,在部署过程中,出现了细微的差异,影响了实时预测。因此,我们回滚到之前更稳定的模型版本。接着,我们共同解决问题,进行彻底测试,然后将优化后的模型重新集成到主项目版本中。

尽管这篇文章专注于如何在项目中使用版本控制,但它也是我 MLOps 文章系列的一部分。此外,通过阅读我之前和之后的教程,你将能够从工作流程到模型部署和跟踪,创建自己的端到端 MLOps 项目。

如果你对 MLOps 感兴趣,可以查看我的文章:

2. 工具

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

在进行机器学习项目或任何计算机科学项目时,在开始编程之前,需要选择合适的工具。工具的选择取决于项目需求、团队专业技能、数据量和成本等不同因素。

在本文中,选择了以下工具:

  • Python 作为编程语言,结合了丰富的生态系统、社区支持、学习的便捷性、多功能性、集成能力、广泛的库、数据科学工具、可扩展性以及行业采用,这些因素共同促成了它在机器学习项目领域的突出地位。

  • Git 用于代码版本控制。Git,全称为全球信息跟踪器,是一种开源的分布式版本控制系统(DVCS),广泛用于软件开发中以跟踪项目开发过程中源代码的更改。它是一个关键工具,能够有效管理代码更改、无缝协作并保持可靠的版本历史。它已成为行业标准,被全球开发者用于各种规模的项目。

  • DVC 用于数据版本控制。DVC,全称为数据版本控制,是一种开源版本控制系统,广泛用于数据管理。它旨在管理大规模数据集、使项目可重现以及促进更好的协作。它在 Git 仓库之上运行,具有类似的感觉和流程。DVC 的一个关键特性是数据版本控制:它允许将数据集与代码分开进行版本控制。因此,数据可以被跟踪、共享,并在不同版本之间轻松切换。

  • MLflow 用于模型版本控制。它是一个开源平台,旨在管理端到端的机器学习生命周期并促进 ML 从业者之间的协作。其与流行库的兼容性和强大的社区支持使其成为以统一且可扩展的方式管理完整机器学习生命周期的有吸引力的选择。

3. 设置你的项目

在开始之前,确保你的系统上已安装 Git 和 DVC。如果尚未安装,你可以从官方 Git 网站官方 DVC 网站下载并安装,或者如果你使用的是 Ubuntu,可以直接执行以下命令:

sudo apt install git-all # to install git
pip install dvc # to install DVD (do not install it for now!)

然而,强烈建议在安装 DVD 之前创建一个虚拟环境;因此,我们将在创建虚拟环境后几分钟内进行安装。此外,请注意:

DVC 并不替代或包含 Git。你必须在系统中安装 git,以启用数据版本控制和快速实验等重要功能(推荐)。 [1]

3.1. 项目文件夹

让我们开始设置项目文件夹!为此,有几种方法,包括:

  • 从头创建文件夹: 这是最直接的方法,但需要手动添加标准文件并随后构建项目。我不推荐在中大型项目中使用这种方法。

  • 导入现有模板: 这是进行简单维护、易于协作以及确保良好透明性、可重复性和可重用性的最佳选择。本文中,我们将使用以下项目结构,用于通过这个 Github 模板或这个Cookiecutter MLOps 仓库创建的机器学习项目,但可以随意探索其他模板。如果你渴望深入了解 ML 项目结构,我邀请你阅读我专门讨论这一主题的文章:考虑 MLOps 的机器学习项目结构

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

使用这个 Github 模板或这个Cookiecutter MLOps 仓库的项目结构

  • 克隆/分叉现有项目: 当处理现有项目时,这通常是最佳选择。它支持协作和代码重用。对于本文,可以随意克隆或分叉我的仓库,以便轻松重用提供的代码。要克隆项目,请使用:
# Clone repository:
git clone git@github.com:Chim-SO/hand-written-digits-classification.git

使用 Github 模板或克隆 Github 仓库需要对 Github 有一定的了解。不过,请放心!你仍然可以跟随本教程,我将为你提供必要的命令和解释。

3.2. 项目环境

另一个重要步骤是设置虚拟环境,这是软件开发中的最佳实践,可以增强项目隔离、依赖管理、可重复性、协作和整体项目整洁性。

  • 首先创建一个名为 handwritten-digits-classification-env 的虚拟环境并激活它:
python -m venv venv/handwritten-digits-classification-env
source venv/handwritten-digits-classification-env/bin/activate
  • 在那之后,通常在使用 GPU 时,我们需要更新环境以适应适当的 Cuda 版本(有关详细信息,请参见这篇文章)。然而,为了使本教程简单易懂,并且项目要求较简单,因为数据和模型都不大,所以不需要 GPU。

  • 最后,我们通过执行以下命令来安装需求和 DVC:

pip install -r requirements.txt 
pip install dvc

4. 代码版本控制

在设置好仓库后,我们现在准备开始版本控制!在本教程中,我们采用一种简单的功能分支工作流。这种工作流涉及为每个新功能创建一个专门的分支,而不是直接更改主分支。然后,我们使用变基/合并方法将功能分支无缝集成到主分支中。

  • 我们首先列出所有分支,并检查当前所在的分支,通常用星号(*)标记:
git branch # List local branches
* master

git branch -r # List remote branches
remotes/origin/HEAD -> origin/master
origin/master

git branch -a # List all local and remote branches
* master
remotes/origin/HEAD -> origin/master
remotes/origin/master

这里我只有一个分支,即主分支,并且它是当前分支。

  • 如果你还不在主分支上,请使用以下命令切换到它:
git checkout master # switch to the main directory
git pull origin master # mendatory when working in collaboration but you can skip it now
  • 我们首先创建一个名为feature/data的分支,在其中添加所有与数据处理相关的代码:
git branch feature/data # to create a branch
git checkout feature/data # to switch to the created branch
# or use the combined creation and switch command
git checkout -b feature/data
  • 在添加了所有必要的代码后,我们通过使用merge命令将代码导入主分支,该命令将命名分支中的更改并入当前工作分支:
git checkout master # switch to the main directory
git merge feature/data # apply changes to master
  • 类似地,我们创建了另一个名为feature/model的分支,在其中添加了所有与模型创建、训练和验证相关的代码,并将其合并到主分支中:
# Model branch creation:
git checkout master # switch to the main directory
git checkout -b feature/model

# Development ...

# Merge branch
git checkout master # switch to the main directory
git merge feature/model # apply changes to master

此时,我们可以说我们创建了代码的简单第一个版本!现在是时候标记这个特定点,添加一个标签,如下所示:

git tag -a v1.0 -m "Version 1.0"

整个工作流描述如下:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

每个圆圈代表一个提交,可以使用以下命令显示:

git log --pretty=format:"%h - %an, %ar : %s"

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

回到我们的代码问题示例:

  • 假设在部署后出现了问题,我们决定暂时将部署回滚到之前的版本:
git revert <commit-hash>  # Revert the merge commit

撤销操作通过创建一个新的提交来撤销指定提交所做的修改,但我们可能需要解决在此过程中出现的任何冲突,类似于常规合并中发生的情况。

  • 通过使用提交历史,我们识别出模型分支中的一个特定优化可能导致了这个问题。因此,创建了一个名为hotfix/inference-bug的热修复分支来解决这个错误:
git checkout -b hotfix/inference-bug
  • 然后我们对代码进行必要的修正并提交更改:
git commit -m "Fix bug in digit classification during inference"
  • 错误修复经过彻底测试,并且如果我们真的在团队中工作,还会为代码审查打开一个新的拉取请求,最终热修复被合并到主分支中:
git checkout main
git merge hotfix/inference-bug

修复了错误的代码重新部署到生产环境。

5. 数据版本控制

现在代码准备好了,我们可以将数据集下载到其第一个版本,然后将其转换为csv格式。

  • 首先,我们需要确保存储数据的文件夹未被 git 忽略。这通过检查.gitignore文件并移除/注释掉排除数据文件夹的行来完成。如果你使用的是我提供的模板,请注释掉第 79 行。

  • 现在,我们开始创建一个分支feature/data-csv,在项目文件夹内初始化 DVC 项目,最后将创建的文件添加到 Git 中:

# Branch creation:
git checkout master # switch to the main directory
git pull origin master # mendatory when working in collaboration but you can skip it now
git checkout -b feature/data-dvc

# DVC initialisation: 
dvc init

# Add to Git the created files:
git commit -m "chore: Initialize DVC."
  • 然后,我们下载我们的数据集,将其添加到 DVC 中,并将新的 DVC 文件添加到 git 中:
# Download data
python src/data/ingestion.py -r data/raw

# Add data to dvc
dvc add data/raw/test_images.gz data/raw/test_labels.gz data/raw/train_images.gz data/raw/train_labels.gz

#Add dvc files to git and commit
git add data/raw/.gitignore data/raw/test_images.gz.dvc data/raw/test_labels.gz.dvc data/raw/train_images.gz.dvc data/raw/train_labels.gz.dvc
git commit -m "Add raw data"

将文件添加到 dvc 将生成存储在新文件中的元数据,文件扩展名为.dvc。另外,请注意,即使数据文件夹由 git 跟踪,但一旦我们将其添加到 DVC 中,它会创建.gitignore并将数据路径添加到其中,因此它将被忽略。

  • 我们还将其转换为csv格式,并将生成的文件data/preprocessed/train.csvdata/preprocessed/test.csv以与之前相同的方式添加到 DVC 中:
# transform data:
python src/data/build_features.py -r data/raw/ -p data/processed/

# Add to dvc: 
dvc add data/processed/train.csv data/processed/test.csv

#Add dvc files to git and commit
git add data/processed/.gitignore data/processed/test.csv.dvc data/processed/train.csv.dvc
git commit -m "Add processed data"

目前,数据已下载并创建。下一步是合并到主分支并添加 git 标签:

# Apply changes:
git checkout master # switch to the main directory
git pull origin master # mendatory when working in collaboration, you can skip it now
git merge feature/data-dvc # apply changes to master

# Tag this point:
git tag -a v1.1 -m "Data collected and processed"

回到我们的数据问题示例:

  • 假设我们对处理过的数据进行了离线增强,并将其添加到 DVC 中:
# Add to dvc after update: 
dvc add data/processed/train.csv data/processed/test.csv
git add data/processed/test.csv.dvc data/processed/train.csv.dvc
git commit -m "Data augmentation offline"
  • 然而,训练后,模型表现不佳,因此我们决定重新使用之前的版本,如下所示:
git checkout data/processed/test.csv.dvc data/processed/train.csv.dvc
dvc checkout data/processed/test.csv.dvc data/processed/train.csv.dvc

6. 模型版本控制

如前所述,我们使用 MLflow 来跟踪和管理我们的模型。由于目前我们将在本地工作,我们启动一个本地的 MLflow Tracking Server:

mlflow server --host 127.0.0.1 --port 8080
  • 创建一个分支,我们在其中训练并保存模型:
# Branch creation:
git checkout master # switch to the main directory
git checkout -b feature/model-dvc
  • 现在,我们初始化 MLflow 运行上下文以启动一个运行,训练模型,然后使用 MLflow 保存模型:
# Create model:
model = create_model(x_train[0].shape)

# Log parameters:
loss = 'categorical_crossentropy'
metric = 'accuracy'

# Train:
model.compile(loss=loss, optimizer='adam', metrics=[metric])
history = model.fit(x_train, y_train, epochs=epochs, batch_size=batch_size, verbose=1,
                    validation_data=(x_val, y_val))

# ....

# Set tracking server uri for logging
mlflow.set_tracking_uri(config['mlflow']['tracking_uri'])

# Create an MLflow Experiment
mlflow.set_experiment(config['mlflow']['experiment_name'])

# Start an MLflow run
with mlflow.start_run():
   # Save model:
   signature = infer_signature(x_train, y_train)
   mlflow.tensorflow.log_model(model, output_path, signature=signature)
   #Log other metrics and parametrics:
   # Next tutorial.
  • 合并到主分支并添加 git 标签:
# Apply changes:
git checkout master # switch to the main directory
git merge feature/model-mlflow # apply changes to master

# Tag this point:
git tag -a v1.2 -m "Model versioning mlflow"
  • 使用以下命令训练模型:
python -m src.models.cnn.train -c configs/cnn.yaml

其中configs/cnn.yaml文件包含一些配置参数,如批量大小和训练轮次。

  • 我们可以在 MLflow UI 中查看运行结果,只需在浏览器中导航到之前的 URL。点击实验名称cnn以列出其相关的运行,然后点击为运行生成的随机名称:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

  • 通过点击运行名称,显示 RUN 页面,其中显示了执行的详细信息:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

当你使用 MLflow 保存模型时,它会创建一个包含以下内容的目录结构:

  • data文件夹包含序列化的模型参数文件。

  • MLmodel文件,包含有关模型的元数据,如框架、模型签名及其他属性。

  • conda.yamlpython_env.yamlrequirements.txt文件,这些文件帮助重建相同的环境以便加载模型。

它还提供了对模型架构的深入理解,并展示了如何执行预测,支持 Spark DataFramePandas DataFrame 两种灵活的选择。MLflow 的另一个显著特点是能够保留生成模型的提交 ID。此外,它引入了一个简单的模型注册选项,这是一个将在后续文章中探讨的话题。

结论

我们来到了本文的结尾。在这篇文章中,我们通过一个实际示例学习了如何对机器学习项目中的三个元素:代码、数据和机器学习模型进行版本控制。版本控制是 MLOps 的一个基本原则,它能实现细致的跟踪、无缝的协作以及机器学习工作流程的强大可复现性。本文的完整代码库可在 相关的仓库 中访问。

感谢阅读本文。您可以在我的 GitHub 个人主页 中找到我提供的所有不同教程的示例。如果您喜欢我的教程,请通过关注我和订阅来支持我。这样,您将收到我新文章的通知。如果您有任何问题或建议,请随时留言。

## MLOps 的关键起点:探索其基本组件

初学者友好的 MLOps 介绍

towardsdatascience.com

参考文献

[1] dvc.org/doc/install

图片来源

本文中所有未在说明中提及来源的图片和图表均由作者提供。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值