Transforme中的位置嵌入模块

原理参考:

Transformer原理及Pytorch代码实现 - 知乎 (zhihu.com) 

class PositionalEncoding(nn.Module):

    def __init__(self, dropout, dim, max_len=5000):
        pe = torch.zeros(max_len, dim)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp((torch.arange(0, dim, 2, dtype=torch.float) *
                              -(math.log(10000.0) / dim)))
        pe[:, 0::2] = torch.sin(position.float() * div_term)
        pe[:, 1::2] = torch.cos(position.float() * div_term)
        pe = pe.unsqueeze(0) 
        # 维度[1,max_len,dim],unsqueeze(0)在第一个维度位置添加一个新的维度
        # 是为了与emb的维度[batch_size,seq_len,dim]对应,广播机制
        super(PositionalEncoding, self).__init__()
        self.register_buffer('pe', pe)
        self.dropout = nn.Dropout(p=dropout)
        self.dim = dim

    def forward(self, emb, step=None):
        emb = emb * math.sqrt(self.dim)
        if (step):
            emb = emb + self.pe[:, step][:, None, :]

        else:
            emb = emb + self.pe[:, :emb.size(1)]
        emb = self.dropout(emb)
        return emb

    def get_emb(self, emb):
        return self.pe[:, :emb.size(1)]

解析: 

self.register_buffer('pe', pe)

# 在 PyTorch 中,self.register_buffer('pe', pe) 是一个在神经网络模块中注册张量 pe(在这个上下文中指位置编码矩阵)为模块的一部分的语句。

# 这行代码的作用是将 pe 张量注册为当前模块的一个缓冲区,并给它一个名称 'pe'。这样,pe 就可以在模块的 forward 方法或其他方法中被方便地引用和使用,同时保证了它会被正确地序列化和恢复。

# 例如,如果你有一个自定义的模块,并且你想在这个模块中存储一些不会通过梯度下降更新的中间计算结果,使用 register_buffer 是一个很好的选择。

register_buffer 方法是 nn.Module 类提供的一个函数,它允许你将一个张量(在这个例子中是 pe)注册为模块的一个缓冲区(buffer)。

缓冲区与模块的参数(parameters)类似,但有一些关键的区别:

  1. 缓冲区不会由优化器更新:当使用优化器(如 torch.optim.Adamtorch.optim.SGD)进行模型训练时,缓冲区不会像模型参数那样通过梯度下降等方法被更新。

  2. 缓冲区会被序列化:缓冲区会与模型参数一起被保存,当模型被序列化(例如使用 torch.save)时,缓冲区也会被保存到磁盘上。这意味着,当你加载一个模型时,注册的缓冲区也会被恢复到之前的状态。

  3. 缓冲区不计入参数总数:虽然缓冲区会被序列化,但它们不会计入模型的参数总数中。这在某些情况下很有用,特别是当你需要存储一些中间计算结果或用于计算的张量,但这些张量的大小远大于模型参数本身时。

emb = emb + self.pe[:, step][:, None, :]

# [:, step],step取的是序列中的第step个字符,pe[:, step]的维度是[1,dim]
# pe[:, step][:, None, :]的维度是[1,1,dim],在第二个维度位置插入一个维度

在 PyTorch 中,[:, None, :] 这个表达式用于通过 None(在 PyTorch 中通常写作 torch.newaxis)来增加一个新的维度。这实际上是利用了 Python 的切片操作符来达到增加维度的目的。

  1. : 表示选择所有的元素,无论是在第一个维度(通常是批次维度)还是最后一个维度(通常是特征或通道维度)。

  2. None(或 torch.newaxis)是一个特殊的维度值,它告诉 PyTorch 在这个位置插入一个新的维度,其大小为 1。

  3. 第三个 : 再次表示选择所有的元素。

因此,[:, None, :] 的效果是在张量的中间插入一个新的维度,其大小为 1。这个操作经常用于调整张量的维度,以便可以与其他张量进行广播(broadcasting)或者满足某些操作的维度要求。

 举例:

举个例子,假设你有一个形状为 (batch_size, seq_len, feature_dim) 的张量 x,并且你想将其转换为 (batch_size, 1, seq_len, feature_dim) 的形状,以便可以与另一个具有相同 batch_sizeseq_len 但额外有一个维度的张量进行操作。你可以使用以下方式:

import torch

# 假设 x 是形状为 (batch_size, seq_len, feature_dim) 的张量
x = torch.randn(batch_size, seq_len, feature_dim)

# 使用 [:, None, :] 增加一个新的维度
y = x[:, None, :]

# 打印 y 的形状
print(y.shape)  # 输出: (batch_size, 1, seq_len, feature_dim)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值