原理参考:
Transformer原理及Pytorch代码实现 - 知乎 (zhihu.com)
class PositionalEncoding(nn.Module):
def __init__(self, dropout, dim, max_len=5000):
pe = torch.zeros(max_len, dim)
position = torch.arange(0, max_len).unsqueeze(1)
div_term = torch.exp((torch.arange(0, dim, 2, dtype=torch.float) *
-(math.log(10000.0) / dim)))
pe[:, 0::2] = torch.sin(position.float() * div_term)
pe[:, 1::2] = torch.cos(position.float() * div_term)
pe = pe.unsqueeze(0)
# 维度[1,max_len,dim],unsqueeze(0)在第一个维度位置添加一个新的维度
# 是为了与emb的维度[batch_size,seq_len,dim]对应,广播机制
super(PositionalEncoding, self).__init__()
self.register_buffer('pe', pe)
self.dropout = nn.Dropout(p=dropout)
self.dim = dim
def forward(self, emb, step=None):
emb = emb * math.sqrt(self.dim)
if (step):
emb = emb + self.pe[:, step][:, None, :]
else:
emb = emb + self.pe[:, :emb.size(1)]
emb = self.dropout(emb)
return emb
def get_emb(self, emb):
return self.pe[:, :emb.size(1)]
解析:
self.register_buffer('pe', pe)
# 在 PyTorch 中,self.register_buffer('pe', pe) 是一个在神经网络模块中注册张量 pe(在这个上下文中指位置编码矩阵)为模块的一部分的语句。
# 这行代码的作用是将 pe 张量注册为当前模块的一个缓冲区,并给它一个名称 'pe'。这样,pe 就可以在模块的 forward 方法或其他方法中被方便地引用和使用,同时保证了它会被正确地序列化和恢复。
# 例如,如果你有一个自定义的模块,并且你想在这个模块中存储一些不会通过梯度下降更新的中间计算结果,使用 register_buffer 是一个很好的选择。
register_buffer
方法是nn.Module
类提供的一个函数,它允许你将一个张量(在这个例子中是pe
)注册为模块的一个缓冲区(buffer)。缓冲区与模块的参数(parameters)类似,但有一些关键的区别:
缓冲区不会由优化器更新:当使用优化器(如
torch.optim.Adam
或torch.optim.SGD
)进行模型训练时,缓冲区不会像模型参数那样通过梯度下降等方法被更新。缓冲区会被序列化:缓冲区会与模型参数一起被保存,当模型被序列化(例如使用
torch.save
)时,缓冲区也会被保存到磁盘上。这意味着,当你加载一个模型时,注册的缓冲区也会被恢复到之前的状态。缓冲区不计入参数总数:虽然缓冲区会被序列化,但它们不会计入模型的参数总数中。这在某些情况下很有用,特别是当你需要存储一些中间计算结果或用于计算的张量,但这些张量的大小远大于模型参数本身时。
emb = emb + self.pe[:, step][:, None, :]
# [:, step],step取的是序列中的第step个字符,pe[:, step]的维度是[1,dim]
# pe[:, step][:, None, :]的维度是[1,1,dim],在第二个维度位置插入一个维度
在 PyTorch 中,
[:, None, :]
这个表达式用于通过None
(在 PyTorch 中通常写作torch.newaxis
)来增加一个新的维度。这实际上是利用了 Python 的切片操作符来达到增加维度的目的。
:
表示选择所有的元素,无论是在第一个维度(通常是批次维度)还是最后一个维度(通常是特征或通道维度)。
None
(或torch.newaxis
)是一个特殊的维度值,它告诉 PyTorch 在这个位置插入一个新的维度,其大小为 1。第三个
:
再次表示选择所有的元素。因此,
[:, None, :]
的效果是在张量的中间插入一个新的维度,其大小为 1。这个操作经常用于调整张量的维度,以便可以与其他张量进行广播(broadcasting)或者满足某些操作的维度要求。
举例:
举个例子,假设你有一个形状为 (batch_size, seq_len, feature_dim)
的张量 x
,并且你想将其转换为 (batch_size, 1, seq_len, feature_dim)
的形状,以便可以与另一个具有相同 batch_size
和 seq_len
但额外有一个维度的张量进行操作。你可以使用以下方式:
import torch
# 假设 x 是形状为 (batch_size, seq_len, feature_dim) 的张量
x = torch.randn(batch_size, seq_len, feature_dim)
# 使用 [:, None, :] 增加一个新的维度
y = x[:, None, :]
# 打印 y 的形状
print(y.shape) # 输出: (batch_size, 1, seq_len, feature_dim)