【扒代码】positional_encoding.py

PositionalEncodingsFixed :固定位置编码

import torch
from torch import nn

class PositionalEncodingsFixed(nn.Module):
    def __init__(self, emb_dim, temperature=10000):
        super(PositionalEncodingsFixed, self).__init__()
        # emb_dim 是嵌入的维度
        self.emb_dim = emb_dim
        # temperature 是一个超参数,用于控制位置编码的平滑度
        self.temperature = temperature

    def _1d_pos_enc(self, mask, dim):
        # 为维度 dim 生成一维位置编码
        temp = torch.arange(self.emb_dim // 2).float().to(mask.device)
        # 温度参数用于调整编码的频率
        temp = self.temperature ** (2 * (temp.div(2, rounding_mode='floor')) / self.emb_dim)

        enc = (~mask).cumsum(dim).float().unsqueeze(-1) / temp
        # 使用 sine 和 cosine 函数生成位置编码
        enc = torch.stack([
            enc[..., 0::2].sin(), enc[..., 1::2].cos()
        ], dim=-1).flatten(-2)

        return enc

    def forward(self, bs, h, w, device):
        # 生成位置编码的前向传播
        mask = torch.zeros(bs, h, w, dtype=torch.bool, requires_grad=False, device=device)
        # 为 x 方向生成一维位置编码
        x = self._1d_pos_enc(mask, dim=2)
        # 为 y 方向生成一维位置编码
        y = self._1d_pos_enc(mask, dim=1)

        # 将 x 和 y 方向的位置编码合并,并重新排列维度
        # torch.cat([y, x], dim=3) 将 y 和 x 方向的编码在第 4 维上进行合并
        # .permute(0, 3, 1, 2) 重新排列维度,得到最终的位置编码张量
        return torch.cat([y, x], dim=3).permute(0, 3, 1, 2)

功能解释

  • PositionalEncodingsFixed 类继承自 nn.Module,是一个可以生成位置编码的 PyTorch 模块。
  • 在初始化方法 __init__ 中,传入嵌入维度 emb_dim 和温度参数 temperature,这些参数控制位置编码的生成。
  • _1d_pos_enc 方法用于生成一维位置编码。它首先计算一个温度调整的频率因子,然后使用 cumsum 函数沿指定维度累加,生成位置编码的初步表示。接着,使用正弦和余弦函数生成最终的位置编码。
  • forward 方法定义了模块的前向传播过程。它首先创建一个全零的掩码张量,然后调用 _1d_pos_enc 方法分别沿 x 方向和 y 方向生成位置编码。最后,将两个方向的位置编码合并,并使用 permute 方法重新排列维度,以匹配模型的输入要求。

整体而言,PositionalEncodingsFixed 类实现了一种固定位置编码的生成方式,这种编码可以为模型提供关于序列或网格中每个元素位置的信息,有助于模型更好地理解数据的空间结构。

temp = self.temperature ** (2 * (temp.div(2, rounding_mode='floor')) / self.emb_dim)

  • self.temperature:是一个超参数,用于控制编码的平滑度或频率。较高的温度值会使正弦和余弦波形更加平滑,而较低的温度值会使波形更加紧凑。
  • temp.div(2, rounding_mode='floor'):这里 temp 是一个从 0 到 self.emb_dim // 2 - 1 的整数序列。div 运算符用于将每个元素除以 2,并使用 'floor' 舍入模式,这意味着结果总是向下舍入到最接近的整数。
  • 2 * (...) / self.emb_dim:这部分计算每个位置的频率比例。由于我们对 temp 进行了除以 2 的操作,所以频率会从 0 开始,线性增加到 self.emb_dim // 2
  • **:表示指数运算,使用温度参数的幂来调整频率比例,生成最终的频率因子 temp

enc = torch.stack([
    enc[..., 0::2].sin(), enc[..., 1::2].cos()
], dim=-1).flatten(-2)

    • ~mask:这是一个按位取反的操作,将掩码张量中的 0 变为 11 变为 0。这样,掩码张量中原本为 1 的位置(即不应该有值的位置)将不会影响到累加结果。
    • .cumsum(dim):沿着指定的 dim 维度对取反掩码张量进行累积求和,得到每个位置之前所有位置的累加和。
    • .float():将累积和的结果转换为浮点数。
    • .unsqueeze(-1):在最后一个维度上增加一个维度,为后续的正弦和余弦函数应用做准备。
    • / temp:将累积和的结果除以之前计算的频率因子,进行归一化。
    • torch.stack([...], dim=-1):将正弦和余弦计算的结果沿着最后一个维度堆叠起来,形成一个联合的编码张量。
    • enc[..., 0::2].sin() 和 enc[..., 1::2].cos():分别计算偶数索引位置的正弦值和奇数索引位置的余弦值。这是位置编码的典型模式,其中正弦用于偶数位置,余弦用于奇数位置。
    • .flatten(-2):在倒数第二个维度上展平张量,将 self.emb_dim 长度的向量展平,使得每个位置的编码成为一个单独的元素。

这种位置编码方式模仿了自然语言处理中的正弦和余弦位置编码,但在图像或序列数据中以二维形式应用。通过这种方式,模型可以学习到序列中每个元素的相对位置信息,从而更好地理解数据的结构。

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值