Overview
Unlike recurrent neural networks (RNNs) or long short-term memory networks (LSTMs), models such as the Transformer have no explicit time-step ordering, so they need some other way to represent the positions of tokens in the input sequence. This article lists common models and their position-encoding methods, with code implementations.
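To make the problem concrete, here is a minimal sketch (my own illustration, not from the original post) showing that scaled dot-product self-attention without any position encoding is permutation-equivariant: shuffling the input tokens merely shuffles the outputs, so token order carries no signal.

import torch
import torch.nn.functional as F

def self_attention(t):
    # Plain scaled dot-product attention, no position information
    scores = t @ t.T / t.shape[-1] ** 0.5
    return F.softmax(scores, dim=-1) @ t

torch.manual_seed(0)
x = torch.randn(5, 16)      # 5 tokens, 16-dim features (illustrative sizes)
perm = torch.randperm(5)
out1 = self_attention(x)
out2 = self_attention(x[perm])
print(torch.allclose(out1[perm], out2, atol=1e-5))  # True: permuted input, permuted output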
Position Embedding in the Transformer
Let's first look at the formulas from the original paper:
PE_{(pos,2i)}=\sin(pos/10000^{2i/d_{model}})
PE_{(pos,2i+1)}=\cos(pos/10000^{2i/d_{model}})
Here pos is the token's position in the sequence, 2i indexes the even embedding dimensions, and 2i+1 indexes the odd embedding dimensions.
The code is as follows:
import torch

def create_1d_absolute_sincos_embedding(pos_vec, dim):
    assert dim % 2 == 0, "Wrong dimension! Dimension must be even."
    position_embedding = torch.zeros(pos_vec.numel(), dim, dtype=torch.float)  # Initialize
    # Frequencies 1 / 10000^(2i/d_model) for i = 0 .. dim/2 - 1
    omega = torch.arange(dim // 2, dtype=torch.float)
    omega /= dim / 2
    omega = 1. / (10000 ** omega)
    # Column vector times row vector: the [num_pos, dim/2] matrix of angles pos * omega_i
    # (cast positions to float so the matmul dtypes match)
    out = pos_vec[:, None].float() @ omega[None, :]
    sin_emb = torch.sin(out)
    cos_emb = torch.cos(out)
    # sin on even dimensions, cos on odd dimensions
    position_embedding[:, 0::2] = sin_emb
    position_embedding[:, 1::2] = cos_emb
    return position_embedding

pos_vec = torch.arange(10)
dim = 8
embedding = create_1d_absolute_sincos_embedding(pos_vec, dim)
print(embedding.shape)  # torch.Size([10, 8])
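As a quick sanity check (my own addition, not from the original post): for pos = 1 and i = 0 the frequency is 1/10000^0 = 1, so the first two entries of that row should equal sin(1) and cos(1).

import math
print(embedding[1, 0].item(), math.sin(1.0))  # both ≈ 0.8415
print(embedding[1, 1].item(), math.cos(1.0))  # both ≈ 0.5403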
Characteristics
- 1-dimensional
- Absolute (though relative offsets stay linearly recoverable; see the identity below)
- Not learnable
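Why the sinusoidal form works: the Transformer paper notes that for any fixed offset k, PE_{pos+k} is a linear function of PE_{pos}. Writing \omega_i = 1/10000^{2i/d_{model}}, this is just the angle-addition identities:

\sin(\omega_i(pos+k))=\sin(\omega_i\,pos)\cos(\omega_i k)+\cos(\omega_i\,pos)\sin(\omega_i k)
\cos(\omega_i(pos+k))=\cos(\omega_i\,pos)\cos(\omega_i k)-\sin(\omega_i\,pos)\sin(\omega_i k)

For fixed k the terms \cos(\omega_i k) and \sin(\omega_i k) are constants, so the model can in principle attend to relative positions even though the encoding itself is absolute.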
Position Embedding in ViT
Learnable; it participates in gradient updates.
import torch
import torch.nn as nn

def create_1d_absolute_trainable_embedding(pos_vec, dim):
    # One learnable vector per position, updated by backpropagation like any other parameter
    position_embedding = nn.Embedding(pos_vec.numel(), dim)
    nn.init.constant_(position_embedding.weight, 0.)  # start from zero; training differentiates the rows
    return position_embedding
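A minimal usage sketch (my own illustration; the token count 197 = 1 CLS token + 14×14 patches assumes a ViT-Base/16 setup on 224×224 inputs): the table is indexed by position and added to the token sequence.

pos_vec = torch.arange(197)  # 1 CLS token + 14*14 patches (assumed ViT-Base/16 setup)
dim = 768
pos_emb = create_1d_absolute_trainable_embedding(pos_vec, dim)

tokens = torch.randn(1, 197, 768)         # dummy batch of patch tokens
tokens = tokens + pos_emb(pos_vec)[None]  # broadcast over the batch dimension
print(tokens.shape)  # torch.Size([1, 197, 768])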
Characteristics
- 1-dimensional
- Learnable
Position Embedding in Swin Transformer
Here the embedding is 2D, and the bias is determined by the relative positions between patches. Let's first look at the formula from the original paper:
Attention(Q,K,V)=\mathrm{SoftMax}(QK^T/\sqrt{d}+B)V
B is the learnable relative position bias: a table with one entry per possible (row, col) offset between two patches in a window, looked up for every query-key pair and added to the attention logits. The code is as follows:
import torch
import torch.nn as nn

def create_2d_relative_bias_trainable_embedding(num_heads, height, width):
    # Relative offsets along each axis fall in [-(h-1), h-1], i.e. 2h-1 distinct values,
    # so the learnable table has (2*height-1)*(2*width-1) entries, one bias per head.
    bias_table = nn.Embedding((2 * height - 1) * (2 * width - 1), num_heads)
    nn.init.constant_(bias_table.weight, 0.)

    # Absolute (row, col) coordinates of every patch in the window
    coords = torch.stack(torch.meshgrid(torch.arange(height), torch.arange(width), indexing="ij"))
    coords_flat = coords.flatten(1)  # [2, h*w]

    # Pairwise differences give the relative coordinates of every (query, key) pair
    relative_coords = coords_flat[:, :, None] - coords_flat[:, None, :]  # [2, h*w, h*w]
    relative_coords = relative_coords.permute(1, 2, 0).contiguous()      # [h*w, h*w, 2]

    # Shift to non-negative values and flatten the 2D offset into a single table index
    relative_coords[:, :, 0] += height - 1
    relative_coords[:, :, 1] += width - 1
    relative_coords[:, :, 0] *= 2 * width - 1
    relative_position_index = relative_coords.sum(-1)  # [h*w, h*w]

    # Look up the per-head bias B for every (query, key) pair
    bias = bias_table(relative_position_index.flatten())  # [h*w*h*w, num_heads]
    return bias.view(height * width, height * width, num_heads).permute(2, 0, 1)  # [num_heads, h*w, h*w]
# Example usage: one 7x7 window of patches, 3 attention heads
bias = create_2d_relative_bias_trainable_embedding(num_heads=3, height=7, width=7)
print(bias.shape)  # torch.Size([3, 49, 49])
Characteristics
- Relative
- 2-dimensional
- Learnable
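To connect the code back to the formula, here is a minimal sketch (my own illustration; the shapes are arbitrary) of where B enters the attention computation:

import torch
import torch.nn.functional as F

num_heads, n, head_dim = 3, 49, 32  # 49 tokens = one 7x7 window (illustrative sizes)
q = torch.randn(num_heads, n, head_dim)
k = torch.randn(num_heads, n, head_dim)
v = torch.randn(num_heads, n, head_dim)
bias = create_2d_relative_bias_trainable_embedding(num_heads=3, height=7, width=7)

# Attention(Q,K,V) = SoftMax(QK^T / sqrt(d) + B) V
scores = q @ k.transpose(-2, -1) / head_dim ** 0.5 + bias
out = F.softmax(scores, dim=-1) @ v
print(out.shape)  # torch.Size([3, 49, 32])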