一、目录
- 公式
- 实现
二、实现
-
公式

-
实现
import torch
import torch.nn as nn
import math
class PositionalEncoding(nn.Module):
def __init__(self, d_model, max_len=5000):
super(PositionalEncoding, self).__init__()
# 计算位置编码并将其存储在pe张量中
pe = torch.zeros(max_len, d_model) # 创建一个max_len x d_model的全零张量
position = torch.arange(0, max_len).unsqueeze(1)
# 计算div_term, 计算公式中10000**(2i/d_model)
div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
# 使用正弦和余弦函数生成位置编码,对于d_model的偶数索引,使用正弦函数;对于奇数索引,使用余弦函数。
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0).requires_grad_(False) # 在第一个维度添加一个维度,以便进行批处理
self.register_buffer('pe', pe)
# 定义前向传播函数
def forward(self, x):
# 将输入x与对应的位置编码相加
x = x + self.pe[:, :x.size(1)]
return x
if __name__ == '__main__':
import matplotlib.pyplot as plt
import numpy as np
plt.figure(figsize=(15, 5))
pe = PositionalEncoding(d_model=20,max_len=100).pe
plt.plot(np.arange(100), pe[0, :100, 4:8].numpy()) #图像未看出物理含义
plt.legend(["dim %d" % p for p in [4, 5, 6, 7]])
plt.title("Positional encoding")
plt.show()
本文介绍了如何在PyTorch中实现和可视化位置编码,一个用于Transformer模型的组件,通过正弦和余弦函数为序列输入添加位置信息。作者展示了PositionalEncoding类的代码实现,并通过图形展示了不同维度的位置编码变化。
2850

被折叠的 条评论
为什么被折叠?



