两种位置嵌入方式——nn.Embedding和使用正/余弦函数

通过nn.Embedding学习的位置嵌入:

class Embedding(nn.Module):  
    def __init__(self):  
        super().__init__()  
        # 初始化词嵌入层,vocab_size是词汇表大小,d_model是嵌入维度  
        self.tok_embed = nn.Embedding(vocab_size, d_model)  
        # 初始化位置嵌入层,maxlen是序列的最大长度,d_model是嵌入维度  
        self.pos_embed = nn.Embedding(maxlen, d_model)   
        self.norm = nn.LayerNorm(d_model)  
    def forward(self, x):  
        seq_len = x.size(1)  # 获取序列长度  
        # 创建一个从0到seq_len-1的等差数列,表示位置信息  
        pos = torch.arange(seq_len, dtype=torch.long)  
        # 将位置信息扩展为与输入数据相同的形状,即[batch_size, seq_len]  
        pos = pos.unsqueeze(0).expand_as(x)  
        # 生成词嵌入和位置嵌入,并将它们相加得到最终的嵌入向量  
        embedding = self.tok_embed(x) + self.pos_embed(pos)  
        # 对嵌入向量进行层归一化,并返回结果  
        return self.norm(embedding)  

优势

  1. 灵活性:这种方法允许模型通过训练来学习最适合任务的位置表示。嵌入层可以根据训练数据中的上下文信息来优化位置嵌入,从而更好地适应特定的任务和数据集。

  2. 任务特定性:由于嵌入是通过训练得到的,因此它们可以捕获与特定任务最相关的位置信息,这可能在某些应用中提供优势。

劣势

  1. 参数数量:每个位置都需要一个独立的嵌入向量,这增加了模型的参数数量。对于非常长的序列,这可能会成为一个问题。

  2. 长度限制:使用nn.Embedding通常意味着需要预先定义一个位置索引的上限,因此可能不适合处理任意长度的序列。如果序列长度超过预定义的上限,就需要进行截断或填充。

  3. 泛化能力:对于未在训练数据中出现过的位置索引,模型可能无法很好地泛化。

使用正弦与余弦函数的位置嵌入:

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)  
        pe = torch.zeros(max_len, d_model)  
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))  
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)
    def forward(self, x):
        x = x + self.pe[:, :x.size(1)]  
        return self.dropout(x)

优势

  1. 无参数:这种方法不需要额外的训练参数,位置嵌入是通过数学公式直接计算得到的。

  2. 任意长度:可以生成任意长度的位置嵌入,非常适合处理长度可变的序列。

  3. 相对位置信息:正弦和余弦函数的周期性使得模型能够捕获序列中的相对位置信息。

劣势

  1. 固定性:与通过训练学习的嵌入相比,这种方法提供的位置嵌入是固定的,可能无法根据特定任务进行调整。

  2. 任务无关性:由于嵌入是预先定义的,它们可能不包含与特定任务密切相关的信息。

  3. 可能的信息冗余:对于较短的序列,使用这种方法可能会引入一些不必要的信息冗余,因为嵌入是基于任意长度的序列设计的。

 使用 nn.Embedding 的方法可能需要更多的训练数据来学习有效的位置表示,但它允许模型自由地学习最适合任务的位置编码;而正弦和余弦位置嵌入则提供了一种无需额外训练参数的、固定的位置编码方式,它对于长度可变的序列特别有用。

在实际应用中,如果序列长度固定且不太长,可以尝试使用 nn.Embedding 。如果处理的是长度可变的序列,或者想要减少模型的数量,那么使用正弦和余弦函数生成位置嵌入可能更为合适。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值