Sinusoidal绝对位置编码

本文介绍了如何在PyTorch中实现和可视化位置编码,一个用于Transformer模型的组件,通过正弦和余弦函数为序列输入添加位置信息。作者展示了PositionalEncoding类的代码实现,并通过图形展示了不同维度的位置编码变化。
摘要由CSDN通过智能技术生成

一、目录

  1. 公式
  2. 实现

二、实现

  1. 公式
    在这里插入图片描述

  2. 实现

import torch
import torch.nn as nn
import math

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()

        # 计算位置编码并将其存储在pe张量中
        pe = torch.zeros(max_len, d_model)                # 创建一个max_len x d_model的全零张量
        position = torch.arange(0, max_len).unsqueeze(1)
        # 计算div_term,              计算公式中10000**(2i/d_model)
        div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))

        # 使用正弦和余弦函数生成位置编码,对于d_model的偶数索引,使用正弦函数;对于奇数索引,使用余弦函数。
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).requires_grad_(False)                  # 在第一个维度添加一个维度,以便进行批处理
        self.register_buffer('pe', pe)

    # 定义前向传播函数
    def forward(self, x):
        # 将输入x与对应的位置编码相加
        x = x + self.pe[:, :x.size(1)]
        return x

if __name__ == '__main__':
    import matplotlib.pyplot as plt
    import numpy as np
    plt.figure(figsize=(15, 5))
    pe = PositionalEncoding(d_model=20,max_len=100).pe

    plt.plot(np.arange(100), pe[0, :100, 4:8].numpy())      #图像未看出物理含义

    plt.legend(["dim %d" % p for p in [4, 5, 6, 7]])
    plt.title("Positional encoding")
    plt.show()
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值