Introduction to the Attention Module
At its core, attention is a weighted average in which the weights are determined by similarity.
Scaled Dot-Product Attention
Formula first, then code.
$$\mathrm{ScaledDotProductAttention}(Q, K, V) = \mathrm{softmax}\left(\frac{Q \cdot K^T}{\sqrt{d_k}}\right) \cdot V$$
Here is the code:
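```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ScaledDotProductAttention(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        # Linear projections producing Q, K and V from the same input (self-attention)
        self.query = nn.Linear(input_dim, input_dim)
        self.key = nn.Linear(input_dim, input_dim)
        self.value = nn.Linear(input_dim, input_dim)
        # d_k is the query/key dimension, which equals input_dim here
        self.scale = torch.sqrt(torch.tensor(input_dim, dtype=torch.float32))

    def forward(self, x):
        # x shape: (B, N, D)
        query = self.query(x)
        key = self.key(x)
        value = self.value(x)
        # Similarity between every query and every key: (B, N, N)
        attention = torch.matmul(query, key.transpose(-2, -1)) / self.scale
        # Normalize each row into weights that sum to 1
        attention_score = F.softmax(attention, dim=-1)
        # Weighted average of the values: (B, N, D)
        new_x = torch.matmul(attention_score, value)
        return new_x
```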
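As a quick sanity check of the formula, the sketch below computes $\mathrm{softmax}(QK^T/\sqrt{d_k})V$ by hand on random tensors and compares it with PyTorch's built-in `torch.nn.functional.scaled_dot_product_attention`. It assumes PyTorch 2.0 or later (where that function exists); the tensor sizes are arbitrary.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, N, d_k = 2, 2, 5, 16  # arbitrary batch, heads, sequence length, feature size
Q = torch.randn(B, H, N, d_k)
K = torch.randn(B, H, N, d_k)
V = torch.randn(B, H, N, d_k)


def manual_attention(Q, K, V):
    # softmax(Q K^T / sqrt(d_k)) V, written out step by step
    scores = Q @ K.transpose(-2, -1) / Q.shape[-1] ** 0.5  # (B, H, N, N)
    weights = F.softmax(scores, dim=-1)                     # each row sums to 1
    return weights @ V, weights


out, weights = manual_attention(Q, K, V)
builtin = F.scaled_dot_product_attention(Q, K, V)  # default scale is 1 / sqrt(d_k)
print(torch.allclose(out, builtin, atol=1e-5))
```

The built-in function computes the same quantity, so it is a convenient reference when debugging a hand-written attention layer.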
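Q1: Why divide by $\sqrt{d_k}$ before the softmax?
To keep the values fed into the softmax roughly distributed as $N(0, 1)$. If the components of $q$ and $k$ are independent with mean 0 and variance 1, the dot product $q \cdot k = \sum_{m=1}^{d_k} q_m k_m$ has mean 0 and variance $d_k$, so dividing by $\sqrt{d_k}$ brings the variance back to 1. Without this scaling, when one score $q_i \cdot k_j$ is somewhat larger than the rest, its softmax output dominates, the other outputs approach 0, and their gradients approach 0 as well, which hinders convergence.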
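The variance argument can be checked numerically. The sketch below assumes, as in the reasoning above, that $q$ and $k$ have i.i.d. $N(0, 1)$ components; the helper names `score_variance` and `max_weight` are hypothetical. It shows the unscaled dot product has variance close to $d_k$, the scaled one close to 1, and that the unscaled scores give a peakier softmax.

```python
import torch

torch.manual_seed(0)


def score_variance(d_k, n_samples=20_000, scale=False):
    # Assumption: q and k have i.i.d. N(0, 1) components
    q = torch.randn(n_samples, d_k)
    k = torch.randn(n_samples, d_k)
    scores = (q * k).sum(dim=-1)  # n_samples independent dot products q . k
    if scale:
        scores = scores / d_k ** 0.5
    return scores.var().item()


def max_weight(scores):
    # Largest softmax weight; a value close to 1.0 means the softmax has saturated
    return torch.softmax(scores, dim=-1).max().item()


d_k = 64
print("unscaled variance:", score_variance(d_k))              # roughly d_k
print("scaled variance:  ", score_variance(d_k, scale=True))  # roughly 1

# One query against 10 keys: the same scores with and without scaling
q, k = torch.randn(d_k), torch.randn(10, d_k)
raw = k @ q
print("max weight, unscaled:", max_weight(raw))
print("max weight, scaled:  ", max_weight(raw / d_k ** 0.5))
```

Dividing the scores by a constant greater than 1 can only flatten the softmax, which is why the scaled maximum weight is never larger than the unscaled one.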
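Q2: What is the time complexity of attention?
- Scores $Q K^T$: each of the n queries is compared with n keys, and each similarity is a d-dimensional dot product costing $O(d)$, so this step costs $O(d \cdot n^2)$.
- softmax: $O(n^2)$
- new_x (weighted sum of the values): $O(d \cdot n^2)$
Total: $O(d \cdot n^2)$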
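The counts above can be turned into simple arithmetic. The sketch below (the function `attention_cost` is hypothetical) tallies the multiply-adds of the two matrix products plus one unit of work per softmax entry, a simplifying assumption since exp and division do not really cost the same as a multiply-add. The total is $n^2(2d + 1)$, so doubling the sequence length quadruples the cost, and the score matrix itself holds $n^2$ entries.

```python
import torch


def attention_cost(n, d):
    # Q K^T: n * n dot products, each of length d
    scores = n * n * d
    # softmax: roughly constant work per score (simplifying assumption)
    softmax = n * n
    # weights @ V: (n x n) @ (n x d) -> n * n * d multiply-adds
    weighted_sum = n * n * d
    return scores + softmax + weighted_sum


for n in (512, 1024, 2048):
    print(n, attention_cost(n, d=64))

# The score matrix has n * n entries, so memory also grows quadratically
n, d = 128, 64
q, k = torch.randn(n, d), torch.randn(n, d)
print((q @ k.T).shape)  # torch.Size([128, 128])
```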
MultiHeadAttention
Formula first, then code.
$$\mathrm{MultiHeadAttention}(Q, K, V) = \mathrm{concat}(\mathrm{head}_0, \mathrm{head}_1, \mathrm{head}_2, \cdots, \mathrm{head}_n) \cdot W^o$$
where each $\mathrm{head}_i = \mathrm{ScaledDotProductAttention}(Q W_i^Q, K W_i^K, V W_i^V)$ uses its own projection matrices.
Here is the code:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiHeadAttention(nn.Module):
    def __init__(self, input_dim, n_head=8):
        super().__init__()
        assert input_dim % n_head == 0, "input_dim must be divisible by n_head"
        self.query = nn.Linear(input_dim, input_dim)
        self.key = nn.Linear(input_dim, input_dim)
        self.value = nn.Linear(input_dim, input_dim)
        # Output projection W^o applied to the concatenated heads
        self.out = nn.Linear(input_dim, input_dim)
        self.n_head = n_head
        self.input_dim = input_dim
        # Each head works in d_k = input_dim // n_head dimensions
        self.scale = torch.sqrt(torch.tensor(self.input_dim // self.n_head, dtype=torch.float32))

    def forward(self, x):
        # x shape: (B, N, D)
        B, N, D = x.shape
        # Splitting the embedding into n_head heads with reduced dimension D // n_head
        # (B, N, D) -> (B, N, D) -> (B, N, self.n_head, D // self.n_head) -> (B, self.n_head, N, D // self.n_head)
        queries = self.query(x).view(B, N, self.n_head, D // self.n_head).transpose(1, 2)
        keys = self.key(x).view(B, N, self.n_head, D // self.n_head).transpose(2, 1)
        values = self.value(x).view(B, N, self.n_head, D // self.n_head).transpose(2, 1)
        # queries, keys, values shape is (B, self.n_head, N, D // self.n_head)
        attention_scores = F.softmax(torch.matmul(queries, keys.transpose(-2, -1)) / self.scale, dim=-1)
        # (B, n_head, N, D // n_head) -> (B, N, n_head, D // n_head) -> concat heads -> (B, N, D)
        updated_values = torch.matmul(attention_scores, values).transpose(2, 1).contiguous().view(B, N, D)
        # Multiply the concatenated heads by W^o
        return self.out(updated_values)
```
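To see what the `view` + `transpose` in `forward` does to the features, the sketch below splits a (B, N, D) tensor into heads the same way and merges it back. Head h simply receives channels `h * d_k` up to `(h + 1) * d_k - 1` of every token, and the merge step (the "concat" in the formula) recovers the original tensor exactly. The sizes are arbitrary.

```python
import torch

B, N, D, n_head = 2, 4, 12, 3
d_k = D // n_head
x = torch.arange(B * N * D, dtype=torch.float32).view(B, N, D)

# Split: (B, N, D) -> (B, N, n_head, d_k) -> (B, n_head, N, d_k)
heads = x.view(B, N, n_head, d_k).transpose(1, 2)
print(heads.shape)  # torch.Size([2, 3, 4, 4])

# Head h holds a contiguous slice of the channel dimension
for h in range(n_head):
    assert torch.equal(heads[:, h], x[..., h * d_k:(h + 1) * d_k])

# Merge: back to (B, N, D)
merged = heads.transpose(1, 2).contiguous().view(B, N, D)
print(torch.equal(merged, x))  # True
```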
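Question 1: In PyTorch, what is the difference between view and reshape?
Answer 1: Both change a tensor's shape. view requires the tensor's memory layout to be compatible with the new shape (in practice, contiguous) and always returns a view that shares storage with the original; otherwise it raises an error. reshape checks the layout automatically: it returns a view when possible and copies the data when it is not. So view never pays for a hidden copy, while reshape is more forgiving. When you know the tensor is contiguous in memory, prefer view.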
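The difference shows up as soon as a tensor becomes non-contiguous, for example after a transpose. In the sketch below, `view` fails on the transposed tensor, `reshape` silently copies, and on a contiguous tensor `reshape` returns a view that shares memory (checked via `data_ptr()`).

```python
import torch

x = torch.arange(6).view(2, 3)
xt = x.t()                      # transpose returns a non-contiguous view
print(xt.is_contiguous())       # False

try:
    xt.view(6)                  # view cannot flatten this memory layout
except RuntimeError as err:
    print("view failed:", type(err).__name__)

flat = xt.reshape(6)            # reshape falls back to copying the data
print(flat)                     # tensor([0, 3, 1, 4, 2, 5])
print(flat.data_ptr() == x.data_ptr())  # False: new storage

same = x.reshape(6)             # contiguous input: reshape returns a view
print(same.data_ptr() == x.data_ptr())  # True: shared storage
```

Calling `.contiguous()` first and then `.view(...)`, as the multi-head code does, is the explicit version of what `reshape` does for the non-contiguous case.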
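Question 2: Is there any difference between transpose(1, 2) and transpose(2, 1)?
Answer 2: The results are identical. One swaps dimension 1 with dimension 2 and the other swaps dimension 2 with dimension 1; either way the same two dimensions are exchanged.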
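A short check of the answer above on an arbitrary 4D tensor. It also shows that both calls return a non-contiguous view sharing memory with the input, which is why the multi-head code calls `.contiguous()` before `.view()`.

```python
import torch

x = torch.randn(2, 3, 4, 5)
a = x.transpose(1, 2)
b = x.transpose(2, 1)
print(a.shape, torch.equal(a, b))  # torch.Size([2, 4, 3, 5]) True
# Either call returns a non-contiguous view over x's storage
print(a.is_contiguous())           # False
```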
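Question 3: Why do queries, keys and values need transpose(2, 1)?
Answer 3: To move the head dimension in front of the sequence dimension. torch.matmul treats all leading dimensions as batch dimensions, so with shape (B, n_head, N, D // n_head) each head computes its own (N, N) attention matrix independently, and the heads do not affect one another's computation.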
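The claim that heads are computed independently can be verified by comparing the batched computation with a plain Python loop over heads. The sketch below uses random per-head queries, keys and values of arbitrary size.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, n_head, N, d_k = 2, 4, 6, 8
q = torch.randn(B, n_head, N, d_k)
k = torch.randn(B, n_head, N, d_k)
v = torch.randn(B, n_head, N, d_k)

# Batched: matmul treats (B, n_head) as batch dims -> scores (B, n_head, N, N)
scores = q @ k.transpose(-2, -1) / d_k ** 0.5
batched = F.softmax(scores, dim=-1) @ v

# Reference: run each head separately, then stack the results along dim 1
looped = torch.stack(
    [F.softmax(q[:, h] @ k[:, h].transpose(-2, -1) / d_k ** 0.5, dim=-1) @ v[:, h]
     for h in range(n_head)],
    dim=1,
)
print(scores.shape)  # torch.Size([2, 4, 6, 6])
print(torch.allclose(batched, looped, atol=1e-5))
```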
Positional Encoding
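When the attention module updates a feature in forward, it computes similarities with the features of all points to obtain weights and then takes a weighted average. This ignores the positional relationships between points. Positional encoding is therefore introduced to add information about each point's position in the sequence. For convenience, each point's position is represented by a $d_{model}$-dimensional feature, the same size as the point's own feature. There are many ways to construct this positional encoding.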
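The claim that attention ignores position can be checked directly. In the sketch below (random projection weights, arbitrary sizes, a fixed non-identity permutation), shuffling the input tokens only shuffles the output rows in the same way: without positional encoding, self-attention is permutation-equivariant. Adding a position-dependent vector (a random stand-in for a real positional encoding) breaks that symmetry.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, D = 5, 8
# Random projections, scaled so the softmax is not saturated
W_q, W_k, W_v = (torch.randn(D, D) / D ** 0.5 for _ in range(3))


def self_attention(x):
    q, k, v = x @ W_q, x @ W_k, x @ W_v
    return F.softmax(q @ k.T / D ** 0.5, dim=-1) @ v


x = torch.randn(N, D)
perm = torch.tensor([4, 0, 3, 1, 2])  # fixed shuffle of the 5 tokens
out = self_attention(x)

# Permuting the inputs permutes the outputs in exactly the same way
print(torch.allclose(self_attention(x[perm]), out[perm], atol=1e-5))

# With a position-dependent term, the same shuffle now changes the outputs
pos = torch.randn(N, D)  # stand-in for a positional encoding
out_pe = self_attention(x + pos)
print(torch.allclose(self_attention(x[perm] + pos), out_pe[perm], atol=1e-5))
```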
Sine and cosine functions:
$$PE_{(pos, 2i)} = \sin\left(pos / 10000^{2i/d_{model}}\right)$$
$$PE_{(pos, 2i+1)} = \cos\left(pos / 10000^{2i/d_{model}}\right)$$
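Here pos is the position of the point and i indexes a pair of channels: channel 2i uses the sine and channel 2i+1 uses the cosine with the same frequency.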
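A direct transcription of the two formulas into a (max_len, d_model) table for a 1D sequence, as a minimal sketch assuming an even d_model; the function name `sinusoidal_encoding` is hypothetical.

```python
import torch


def sinusoidal_encoding(max_len, d_model):
    # Hypothetical helper: PE[pos, 2i]   = sin(pos / 10000^(2i / d_model))
    #                      PE[pos, 2i+1] = cos(pos / 10000^(2i / d_model))
    assert d_model % 2 == 0, "d_model must be even"
    pos = torch.arange(max_len, dtype=torch.float32)[:, None]  # (max_len, 1)
    two_i = torch.arange(0, d_model, 2, dtype=torch.float32)   # 0, 2, 4, ...
    div = 10000.0 ** (two_i / d_model)                         # (d_model / 2,)
    pe = torch.zeros(max_len, d_model)
    pe[:, 0::2] = torch.sin(pos / div)  # even channels
    pe[:, 1::2] = torch.cos(pos / div)  # odd channels
    return pe


pe = sinusoidal_encoding(max_len=50, d_model=16)
print(pe.shape)   # torch.Size([50, 16])
print(pe[0, :4])  # position 0: sin(0) = 0 on even channels, cos(0) = 1 on odd channels
```

The table is typically added to the token features, which is why it has the same width d_model.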
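Here is the code. It is a 2D variant for feature maps of shape (B, C, H, W): the y and x coordinates are encoded separately with num_pos_feats channels each and then concatenated, so the output has 2 * num_pos_feats channels.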
```python
import math

import torch
import torch.nn as nn


class PositionEmbeddingSine(nn.Module):
    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and not normalize:
            raise ValueError("normalize should be True if scale is passed")
        # When normalizing, coordinates are mapped into roughly (0, scale); default to 2*pi
        self.scale = 2 * math.pi if scale is None else scale

    def forward(self, x):
        B, C, H, W = x.shape
        mask = torch.zeros([B, H, W], dtype=torch.bool, device=x.device)
        non_mask = ~mask
        # Compute positions on the feature map: y_embed and x_embed both have shape [B, H, W]
        # (torch.meshgrid would also work); positions start at 1
        # ------> x
        # |
        # |
        # y
        y_embed = non_mask.cumsum(1, dtype=torch.float32)
        x_embed = non_mask.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            # Divide by the last row (for y) and the last column (for x)
            y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale
        # dim_t has shape [num_pos_feats]: temperature^(2i / num_pos_feats), shared by channels 2i and 2i+1
        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats)
        # Broadcasting: each position yields num_pos_feats values -> [B, H, W, num_pos_feats]
        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        # sin on even channels, cos on odd channels, interleaved back in place (a neat trick)
        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
        # Concatenate and move channels first: [B, 2 * num_pos_feats, H, W]
        return torch.cat((pos_x, pos_y), dim=-1).permute(0, 3, 1, 2)
```
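The `torch.stack(...).flatten(3)` line interleaves the results: sin of the even channels and cos of the odd channels end up back at their original channel positions. The sketch below checks this against an explicit assignment on a random tensor with the same (B, H, W, num_pos_feats) layout; the sizes are arbitrary.

```python
import torch

torch.manual_seed(0)
B, H, W, n_feats = 1, 2, 3, 8  # n_feats plays the role of num_pos_feats (must be even)
pos = torch.randn(B, H, W, n_feats)

# The trick: stack along a new last dim of size 2, then flatten it away
stacked = torch.stack((pos[..., 0::2].sin(), pos[..., 1::2].cos()), dim=4).flatten(3)

# Explicit version: sin on even channels, cos on odd channels
explicit = torch.empty_like(pos)
explicit[..., 0::2] = pos[..., 0::2].sin()
explicit[..., 1::2] = pos[..., 1::2].cos()

print(stacked.shape)                   # torch.Size([1, 2, 3, 8])
print(torch.equal(stacked, explicit))  # True
```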
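Transformer Architecture
The Transformer is a deep learning model built on self-attention. In NLP, traditional sequential models (RNNs) have limitations in capturing long-range dependencies and in parallelizing computation. To address these problems, the Transformer makes extensive use of self-attention, which lets the model weigh the importance of different positions in the input sequence when generating outputs.
Thanks to self-attention and parallel computation, the Transformer handles long-range dependencies better. It consists of two parts, the Transformer Encoder and the Transformer Decoder, which respectively encode the input sequence and decode it to generate the output sequence. Self-Attention and Cross-Attention are its key components for sequence modeling and feature extraction.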
Transformer Encoder
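- Encoding the input sequence: the Encoder receives the input sequence and encodes it through its layers, using attention to capture contextual information and feature representations.
- Modeling context: the Encoder uses self-attention to let every element of the input sequence interact with the others, computing a relevance score between each element and every other element. This lets the Encoder take context into account while focusing the relevant information onto each element, so it better understands the overall semantics and context of the input sequence.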
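In short, the Transformer Encoder takes the input sequence and uses self-attention to let each element interact with every other element, gaining contextual information and richer, more accurate features; this is how it encodes the input sequence.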
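For comparison, PyTorch ships a ready-made encoder with this structure. The sketch below (arbitrary sizes; `batch_first=True` so inputs are (B, N, d_model); dropout disabled) stacks six `nn.TransformerEncoderLayer`s and shows that the encoder keeps the input shape, and that changing one token changes the encodings of the other tokens, which is the context mixing described above.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, nhead, num_layers = 64, 8, 6
layer = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward=256,
                                   dropout=0.0, batch_first=True)
encoder = nn.TransformerEncoder(layer, num_layers=num_layers)  # num_layers copies of `layer`

x = torch.randn(2, 10, d_model)  # (B, N, d_model)
y = encoder(x)
print(y.shape)  # torch.Size([2, 10, 64])

# Perturb only token 0; the outputs at tokens 1..9 change as well
x2 = x.clone()
x2[:, 0] += 1.0
y2 = encoder(x2)
print((y2[:, 1:] - y[:, 1:]).abs().max())
```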
Here is the code:
```python
import copy
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


def _get_activation_fn(activation):
    """Return the activation function matching the given name."""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    raise RuntimeError(f"activation should be relu/gelu, not {activation}.")


class TransformerEncoder(nn.Module):
    def __init__(self, num_layers, d_model, nhead):
        super().__init__()
        self.num_layers = num_layers
        encoder_layer = TransformerEncoderLayer(d_model, nhead)
        # num_layers independent copies of the same layer (weights are not shared)
        self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(num_layers)])

    def forward(self, x, src_mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None, pos: Optional[Tensor] = None):
        # x shape: (N, B, d_model), the sequence-first layout nn.MultiheadAttention uses by default
        for layer in self.layers:
            x = layer(x, src_mask, src_key_padding_mask, pos)
        return x


class TransformerEncoderLayer(nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0, activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Position-wise feed-forward network: d_model -> dim_feedforward -> d_model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        # Positional encoding is added to queries and keys only, not to values
        return tensor if pos is None else tensor + pos

    def forward_post(self, src, src_mask: Optional[Tensor] = None,
                     src_key_padding_mask: Optional[Tensor] = None, pos: Optional[Tensor] = None):
        # Post-norm: sublayer -> residual add -> LayerNorm
        q = k = self.with_pos_embed(src, pos)
        src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src

    def forward_pre(self, src, src_mask: Optional[Tensor] = None,
                    src_key_padding_mask: Optional[Tensor] = None, pos: Optional[Tensor] = None):
        # Pre-norm: LayerNorm -> sublayer -> residual add
        src2 = self.norm1(src)
        q = k = self.with_pos_embed(src2, pos)
        src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src2 = self.norm2(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
        src = src + self.dropout2(src2)
        return src

    def forward(self, src,
                src_mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None):
        if self.normalize_before:
            return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
        return self.forward_post(src, src_mask, src_key_padding_mask, pos)
```
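The `src_key_padding_mask` argument is how padded positions are excluded. The sketch below calls `nn.MultiheadAttention` directly (the module the layer wraps, in its default sequence-first layout) and checks that overwriting the padded tokens does not change the outputs at real tokens. Sizes are arbitrary, and `True` in the mask marks a padded key.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
N, B, d_model = 6, 1, 16
attn = nn.MultiheadAttention(d_model, num_heads=4)  # expects (N, B, d_model)

src = torch.randn(N, B, d_model)
# The last two positions are padding (True = ignore as a key)
pad_mask = torch.tensor([[False, False, False, False, True, True]])  # (B, N)

out1 = attn(src, src, src, key_padding_mask=pad_mask)[0]

src2 = src.clone()
src2[4:] = torch.randn(2, B, d_model)  # overwrite the padded tokens with noise
out2 = attn(src2, src2, src2, key_padding_mask=pad_mask)[0]

# Real tokens (positions 0..3) are unaffected by what sits in the padding
print(torch.allclose(out1[:4], out2[:4], atol=1e-6))  # True
```

The outputs at the padded positions themselves still change, because their own queries changed; in practice those outputs are simply ignored downstream.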
Transformer Decoder
To be continued.
Reference:
- Attention Is All You Need: https://arxiv.org/pdf/1706.03762