近似核方法Random Binning Feature(RBF)词嵌入降维

Random Binning Feature(RBF)

介绍

Random Binning Feature(RBF)RBF 将输入数据映射到固定的特征空间,其中每个维度对应于输入数据的一个固定范围。这个范围由我们自己指定,并且在 RBF 中是固定不变的。

假设我们有一个形状为 (N, d) 的输入张量 x,我们想要将其映射到形状为 (N, D) 的特征空间,其中 D 是我们想要映射到的特征空间的维度。我们可以按如下方式进行 RBF 映射:

将每个输入维度分成 k 个不同的范围,例如 [-1, -0.5], [-0.5, 0], [0, 0.5], [0.5, 1]。这些范围被称为“箱子”(bins)。

对于每个箱子,生成一个随机向量 r,该向量的每个元素都是从标准正态分布中随机抽取的。

将每个输入维度分配到相应的箱子中。对于每个箱子,将输入维度中位于该箱子中的元素加起来,然后乘以该箱子对应的随机向量 r。重复此过程 k 次,并将所有 k 个结果拼接在一起形成一个新的特征向量。

将所有箱子中的特征向量相加,得到最终的特征向量。

代码

代码如下:

import torch

class RBF(torch.nn.Module):
    def __init__(self, d, D, k):
        super(RBF, self).__init__()
        self.d = d
        self.D = D
        self.k = k
        self.bins = self.generate_bins()
        self.r = self.generate_weights()

    def generate_bins(self):
        bins = []
        for i in range(self.d):
            bins.append(torch.linspace(-1, 1, self.k+1)[1:-1])
        return bins

    def generate_weights(self):
        r = torch.randn(self.k, self.d, self.D)
        return r

    def forward(self, x):
        batch_size = x.shape[0]
        x = x.unsqueeze(2).repeat(1, 1, self.k, self.D)
        bins = torch.tensor(self.bins).unsqueeze(2).unsqueeze(0).repeat(batch_size, 1, 1, self.D)
        x = (x > bins).to(torch.float32) * (x < bins + 1/self.k).to(torch.float32)
        x = x.sum(dim=1)
        r = self.r.unsqueeze(0).repeat(batch_size, 1, 1, 1)
        x = x.unsqueeze(3).repeat(1, 1, 1, self.D)
        r = r.repeat(batch_size, 1, self.k, 1)
        x = x * r
        x = x.sum(dim=2)
        return x

在这个代码中,我们定义了一个 RBF 类,它是一个继承自 PyTorchnn.Module 类的子类。我们在 init 函数中初始化了 RBF 的一些参数,包括输入维度 d、输出维度 D 和箱子数量 k。我们还定义了两个帮助函数:generate_binsgenerate_weights

generate_bins 函数生成了箱子列表。对于每个输入维度,我们使用 torch.linspace[-1, 1] 之间生成 k+1 个数,并删除第一个和最后一个数。这些数用作该维度中的箱子边界。

generate_weights 函数生成了随机权重 r。我们使用 PyTorchtorch.randn 函数从标准正态分布中随机生成 kdD 个数,并将其形状重塑为 (k, d, D)

forward 函数中,我们首先将输入张量 x 的形状重塑为 (batch_size, d, 1, 1),并使用 unsqueezerepeat 函数来将其复制为形状为 (batch_size, d, k, D) 的张量。我们还将箱子列表转换为张量形式,并使用 unsqueezerepeat 函数将其复制为形状为 (batch_size, d, k, D) 的张量。

接下来,我们将 x 和 bins 进行比较,生成一个布尔张量。使用 to 函数将布尔张量转换为浮点张量,并乘以 1/self.k 得到每个箱子中的元素数量。我们将所有元素数量相加得到一个形状为 (batch_size, k, D) 的张量。

然后,我们将随机权重 r 和上面的张量进行乘法,得到一个形状为 (batch_size, d, k, D) 的张量。使用 sum 函数沿着 k 维度进行求和,得到一个形状为 (batch_size, d, D) 的张量。

最后,我们将所有输入维度上的特征向量相加,得到最终的特征向量。

下面是一个实例化 RBF 类的例子:

import torch

# 实例化RBF模型
rbf_model = RBF(d=10, D=5, k=20)

# 生成随机数
x = torch.randn(32, 10)

# 输入随机数到RBF模型中
output = rbf_model(x)

print(output.shape) # 打印输出的形状


这里实例化了一个 RBF 类的对象 model,其中 d=10 表示输入特征向量的维度为 10D=64 表示输出特征向量的维度为 64k=5 表示使用 5 个均匀分布的间隔将特征向量划分为若干个小区间。


附录-详细解释

以上代码实现了 Random Binning Feature (RBF) 方法,用于将高维输入数据映射到低维特征空间中。RBF 通过将输入空间分成多个小区间,并使用随机权重将每个小区间映射到低维特征空间中,从而实现降维的目的。

该代码实现了一个名为 RBFPyTorch 模块,其构造函数接受三个参数:d,表示输入数据的维度;D,表示映射到的低维特征空间的维度;k,表示每个输入维度被划分成的小区间数量。

在构造函数中,首先调用了父类的构造函数,然后将输入参数保存在类的属性中。接着,生成了一组小区间 bins,其中每个小区间都是一个从 -1 到 1 的等间隔序列,序列长度为 k。这些小区间用于将输入数据映射到低维特征空间中。

接下来,调用 generate_weights() 方法生成一个随机权重张量 r,该张量的形状为 (k, d, D)。该张量用于将每个小区间映射到低维特征空间中。

forward() 方法中,首先获取输入数据的 batch_size,然后将输入数据的形状从 (batch_size, d) 变为 (batch_size, d, 1, 1),并将其复制 k 次得到形状为 (batch_size, d, k, 1) 的张量 x

将小区间 bins 变形为形状为 (1, d, k, 1) 的张量,并将其复制 batch_size 次得到形状为 (batch_size, d, k, 1) 的张量 bins

使用 xbins 进行比较运算,得到一个布尔型张量,表示每个输入数据点位于哪个小区间中。将这个张量转换为浮点型张量,即可得到相应的二进制特征。这里使用了 PyTorch 的广播机制,将 xbins 的形状进行了匹配。

使用 sum 函数对第二个维度进行求和,得到形状为 (batch_size, k, D) 的张量 x

使用 unsqueeze 函数将 x 的最后一个维度变为 (batch_size, k, D, 1),然后将 r 复制 batch_size 次得到形状为 (batch_size, k, d, D) 的张量 r

使用 repeat 函数将 xr 的维度匹配,然后对它们进行点乘操作,得到形状为 (batch_size, d, D) 的张量 x

最后返回 x,即为输入数据 xRBF 映射下得到的低维特征表示。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

高山莫衣

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值