通过nn.Embedding
学习的位置嵌入:
class Embedding(nn.Module):
def __init__(self):
super().__init__()
# 初始化词嵌入层,vocab_size是词汇表大小,d_model是嵌入维度
self.tok_embed = nn.Embedding(vocab_size, d_model)
# 初始化位置嵌入层,maxlen是序列的最大长度,d_model是嵌入维度
self.pos_embed = nn.Embedding(maxlen, d_model)
self.norm = nn.LayerNorm(d_model)
def forward(self, x):
seq_len = x.size(1) # 获取序列长度
# 创建一个从0到seq_len-1的等差数列,表示位置信息
pos = torch.arange(seq_len, dtype=torch.long)
# 将位置信息扩展为与输入数据相同的形状,即[batch_size, seq_len]
pos = pos.unsqueeze(0).expand_as(x)
# 生成词嵌入和位置嵌入,并将它们相加得到最终的嵌入向量
embedding = self.tok_embed(x) + self.pos_embed(pos)
# 对嵌入向量进行层归一化,并返回结果
return self.norm(embedding)
优势:
-
灵活性:这种方法允许模型通过训练来学习最适合任务的位置表示。嵌入层可以根据训练数据中的上下文信息来优化位置嵌入,从而更好地适应特定的任务和数据集。
-
任务特定性:由于嵌入是通过训练得到的,因此它们可以捕获与特定任务最相关的位置信息,这可能在某些应用中提供优势。
劣势:
-
参数数量:每个位置都需要一个独立的嵌入向量,这增加了模型的参数数量。对于非常长的序列,这可能会成为一个问题。
-
长度限制:使用
nn.Embedding
通常意味着需要预先定义一个位置索引的上限,因此可能不适合处理任意长度的序列。如果序列长度超过预定义的上限,就需要进行截断或填充。 -
泛化能力:对于未在训练数据中出现过的位置索引,模型可能无法很好地泛化。
使用正弦与余弦函数的位置嵌入:
class PositionalEncoding(nn.Module):
def __init__(self, d_model, dropout, max_len=5000):
super().__init__()
self.dropout = nn.Dropout(p=dropout)
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
self.register_buffer('pe', pe)
def forward(self, x):
x = x + self.pe[:, :x.size(1)]
return self.dropout(x)
优势:
-
无参数:这种方法不需要额外的训练参数,位置嵌入是通过数学公式直接计算得到的。
-
任意长度:可以生成任意长度的位置嵌入,非常适合处理长度可变的序列。
-
相对位置信息:正弦和余弦函数的周期性使得模型能够捕获序列中的相对位置信息。
劣势:
-
固定性:与通过训练学习的嵌入相比,这种方法提供的位置嵌入是固定的,可能无法根据特定任务进行调整。
-
任务无关性:由于嵌入是预先定义的,它们可能不包含与特定任务密切相关的信息。
-
可能的信息冗余:对于较短的序列,使用这种方法可能会引入一些不必要的信息冗余,因为嵌入是基于任意长度的序列设计的。
使用 nn.Embedding
的方法可能需要更多的训练数据来学习有效的位置表示,但它允许模型自由地学习最适合任务的位置编码;而正弦和余弦位置嵌入则提供了一种无需额外训练参数的、固定的位置编码方式,它对于长度可变的序列特别有用。
在实际应用中,如果序列长度固定且不太长,可以尝试使用 nn.Embedding
。如果处理的是长度可变的序列,或者想要减少模型的数量,那么使用正弦和余弦函数生成位置嵌入可能更为合适。