三种位置编码

概要

 Transformer等模型不像循环神经网络(RNN)或长短时记忆网络(LSTM)那样具有显式的时间步顺序,因此需要一种方法来处理输入序列中的位置信息。本文列出了常见的模型及其位置编码的方法及代码实现。

Transformer中的PositionEmbedding

先来看原文中的公式 P E ( p o s , 2 i ) = s i n ( p o s / 10000 0 2 i / d m o d e l ) PE_{(pos,2i)}=sin(pos/100000^{2i/d_{model}}) PE(pos,2i)=sin(pos/1000002i/dmodel) P E ( p o s , 2 i + 1 ) = c o s ( p o s / 10000 0 2 i / d m o d e l ) PE_{(pos,2i+1)}=cos(pos/100000^{2i/d_{model}}) PE(pos,2i+1)=cos(pos/1000002i/dmodel)pos指的是token在序列中的位置,2i对应偶数维度2i+1对应奇数维度
代码如下

import torch

def create_1d_absolute_sincos_embedding(pos_vec, dim):
    assert dim % 2 == 0, "Wrong dimension! Dimension must be even."
    position_embedding = torch.zeros(pos_vec.numel(), dim, dtype=torch.float)  # Initialize

    
    omega = torch.arange(dim // 2, dtype=torch.float)
    omega /= dim / 2
    omega = 1. / (10000 ** omega)
    
    
    out = pos_vec[:, None] @ omega[None, :]#列向量乘行向量得到矩阵
    sin_emb = torch.sin(out)
    cos_emb = torch.cos(out)
    
    
    position_embedding[:, 0::2] = sin_emb
    position_embedding[:, 1::2] = cos_emb
    
    return position_embedding


pos_vec = torch.arange(10)  
dim = 8  
embedding = create_1d_absolute_sincos_embedding(pos_vec, dim)
print(embedding.shape)  

特点

  • 1维的
  • 绝对的
  • 不可学习

ViT中的PositionEmbedding

可学习的,参与梯度更新

import torch
import torch.nn as nn

def create_1d_absolute_trainable_embedding(pos_vec, dim):
    position_embedding = nn.Embedding(pos_vec.numel(),dim)
	nn.init.constant_(position_embedding.weight,0)
	return position_embedding

特点

  • 1维的
  • 可学习的

SwinTransformer中的PositionEmbedding

2d的,pos由相对位置决定
先来看原文中的公式 A t t e n t i o n ( Q , K , V ) = S o f t M a x ( Q K T / d + B ) V Attention(Q,K,V)=SoftMax(QK^T/\sqrt{d}+B)V Attention(Q,K,V)=SoftMax(QKT/d +B)V

import torch
import torch.nn as nn

class RelativePositionEmbedding2D(nn.Module):
    def __init__(self, image_size, patch_size, dim):
        super().__init__()
        image_height, image_width = image_size
        patch_height, patch_width = patch_size
        
        assert image_height % patch_height == 0 and image_width % patch_width == 0
        num_patches_h = image_height // patch_height
        num_patches_w = image_width // patch_width
        num_patches = num_patches_h * num_patches_w
        
        self.embedding = nn.Embedding(num_patches, dim)
        
        self.row_embeddings = nn.Embedding(num_patches_h, dim)
        self.col_embeddings = nn.Embedding(num_patches_w, dim)
        
    def forward(self, x):
        b, n, _ = x.shape  
        
       
        row_pos = torch.arange(n // num_patches_w, device=x.device)  
        col_pos = torch.arange(n // num_patches_h, device=x.device)  
        
        row_embeddings = self.row_embeddings(row_pos).unsqueeze(1).expand(-1, num_patches_w, -1)
        col_embeddings = self.col_embeddings(col_pos).unsqueeze(0).expand(num_patches_h, -1, -1)
        
        relative_embeddings = row_embeddings + col_embeddings
        
        x = x + relative_embeddings
        
        return x

# Example usage:
image_size = (224, 224)
patch_size = (16, 16)
dim = 256
  • 相对的
  • 2d的
  • 可学习的
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值