First, the image is passed through a stack of conv layers to produce a feature map of shape (N, C, H, W), which is then reshaped and transposed into (N, H*W, C). Next, the Q, K, V matrices are built; each is simply an nn.Linear layer, nn.Linear(C, C). These fully connected layers mix information across channels and give three matrices:
q(x): (N, H*W, C), k(x): (N, H*W, C), v(x): (N, H*W, C)
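A minimal sketch of this step, assuming a feature map from some conv backbone (the names q_proj, k_proj, v_proj and the concrete sizes are illustrative, not from any particular library):

import torch
import torch.nn as nn

N, C, H, W = 2, 64, 14, 14
feat = torch.randn(N, C, H, W)                   # feature map from the conv stack

# (N, C, H, W) -> (N, H*W, C)
x = feat.reshape(N, C, H * W).transpose(1, 2)

# Q, K, V are plain fully connected layers acting on the channel dimension
q_proj, k_proj, v_proj = nn.Linear(C, C), nn.Linear(C, C), nn.Linear(C, C)
q, k, v = q_proj(x), k_proj(x), v_proj(x)        # each: (N, H*W, C)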
For multi-head attention, q(x), k(x), and v(x) are each reshaped:
q'(x): (N, H*W, num_heads, size_per_head)
k'(x): (N, H*W, num_heads, size_per_head)
v'(x): (N, H*W, num_heads, size_per_head)
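Continuing the sketch above (here size_per_head = C // num_heads, which the timm code later also assumes):

num_heads = 8
size_per_head = C // num_heads

# (N, H*W, C) -> (N, H*W, num_heads, size_per_head)
q = q.reshape(N, H * W, num_heads, size_per_head)
k = k.reshape(N, H * W, num_heads, size_per_head)
v = v.reshape(N, H * W, num_heads, size_per_head)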
Then apply a (0, 2, 1, 3) transpose:
q''(x): (N, num_heads, H*W, size_per_head)
k''(x): (N, num_heads, H*W, size_per_head)
v''(x): (N, num_heads, H*W, size_per_head)
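In PyTorch the (0, 2, 1, 3) transpose is a permute of the same indices:

# (N, H*W, num_heads, size_per_head) -> (N, num_heads, H*W, size_per_head)
q = q.permute(0, 2, 1, 3)
k = k.permute(0, 2, 1, 3)
v = v.permute(0, 2, 1, 3)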
Then, on each head, compute q(x) @ k(x).transpose(-1, -2) @ v(x). The score matrix q(x) @ k(x).transpose(-1, -2) has shape (N, num_heads, H*W, H*W).
The attention score on each feature map therefore has size (H*W, H*W); it is divided by sqrt(size_per_head).
After a softmax over the last dimension, the attention score (N, num_heads, H*W, H*W) is multiplied by v(x) (N, num_heads, H*W, size_per_head), giving a result of shape (N, num_heads, H*W, size_per_head). This amounts to performing attention over the H*W positions.
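Putting this step into the running sketch (the softmax placement and the 1/sqrt(size_per_head) scale follow the timm code shown later):

scale = size_per_head ** -0.5

attn = (q @ k.transpose(-2, -1)) * scale     # (N, num_heads, H*W, H*W)
attn = attn.softmax(dim=-1)                  # normalize scores over the key positions
out = attn @ v                               # (N, num_heads, H*W, size_per_head)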
Finally, restore the layout: transpose back to (N, H*W, num_heads, size_per_head) and reshape to (N, H*W, C). Once attention is complete, a further transpose and reshape recover (N, C, H, W).
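Continuing the sketch, the restoration looks like this:

# (N, num_heads, H*W, size_per_head) -> (N, H*W, num_heads, size_per_head) -> (N, H*W, C)
out = out.transpose(1, 2).reshape(N, H * W, C)

# back to the feature-map layout: (N, H*W, C) -> (N, C, H, W)
out = out.transpose(1, 2).reshape(N, C, H, W)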
When implementing multi-head attention, the heads are not computed one by one in a loop; instead, transposes and reshapes let everything be done with batched matrix multiplication rather than actual separate per-head tensors. Here hidden_size (d) = num_attention_heads (m) * attention_head_size (a), i.e. d = m*a.
The num_attention_heads dimension is transposed to the front, so that Q and K both have shape (m, n, a). The dot product can then be viewed as multiplying two tensors of sizes (m, n, a) and (m, a, n) to get an (m, n, n) matrix, which is equivalent to multiplying an (n, a) matrix by an (a, n) matrix m times.
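A tiny numerical check of that equivalence (the sizes m, n, a below are arbitrary):

import torch

m, n, a = 8, 196, 64
Q = torch.randn(m, n, a)
K = torch.randn(m, n, a)

batched = Q @ K.transpose(-2, -1)                          # one batched matmul: (m, n, n)
looped = torch.stack([Q[i] @ K[i].T for i in range(m)])    # m separate (n, a) x (a, n) matmuls

assert torch.allclose(batched, looped, atol=1e-4)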
The timm implementation is as follows:
class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
Here the sequence length is H*W, and hidden_size is the number of channels.
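A quick shape check of the layer above (dim=64 and the 14x14 grid are illustrative values):

x = torch.randn(2, 14 * 14, 64)             # (N, H*W, C)
attn_layer = Attention(dim=64, num_heads=8)
y = attn_layer(x)                           # (2, 196, 64): same shape as the input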
The image (patch) embedding implementation:
class PatchEmbed(nn.Module):
    """ 2D Image to Patch Embedding
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.flatten = flatten

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC
        x = self.norm(x)
        return x
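With the defaults above, a 224x224 RGB image becomes 14*14 = 196 patch tokens of dimension 768 (to_2tuple comes from timm's layer helpers; the snippet below is only a shape check):

img = torch.randn(1, 3, 224, 224)
patch_embed = PatchEmbed()              # patch_size=16 -> 14x14 = 196 patches
tokens = patch_embed(img)               # (1, 196, 768)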