1 Transformer的基本结构
Transformer可以分为:编码过程和解码过程。NLP中的Transformer对词进行Embedding,Vision Transformer中将图像进行分块(Patch),将每块进行展平(Flatten),对Patch进行Embedding。
Embedding:从低维到高维或者从高维到低维空间的映射。
2 Patch Embedding 简单实现
import torch
import torch.nn as nn
import cv2
class MLP(nn.Module):
    """Two-layer feed-forward block applied to the last dimension.

    The input is expanded to ``int(embed_dim * mlp_ratio)`` features, passed
    through ReLU and dropout, then projected back to ``embed_dim``.

    Args:
        embed_dim: Size of the last dimension of the input and the output.
        mlp_ratio: Hidden width expressed as a multiple of ``embed_dim``.
        drop_out: Dropout probability applied after each linear layer.
    """

    def __init__(self, embed_dim, mlp_ratio=4.0, drop_out=0):
        super().__init__()
        hidden_dim = int(embed_dim * mlp_ratio)
        self.fc1 = nn.Linear(embed_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, embed_dim)
        self.act = nn.ReLU()
        self.dropout = nn.Dropout(drop_out)

    def forward(self, x):
        """Return the transformed tensor; the shape of ``x`` is preserved."""
        x = self.fc1(x)
        x = self.act(x)
        x = self.dropout(x)
        x = self.fc2(x)
        # The output projection is followed by dropout only. Applying the
        # ReLU a second time here would clamp the block's output to
        # non-negative values.
        x = self.dropout(x)
        return x
class patch_embedding(nn.Module):
    """Split an image into non-overlapping patches and embed each patch.

    A convolution whose kernel size and stride both equal ``patch_size``
    maps every patch to one ``embed_dim``-dimensional vector.

    Args:
        patch_size: Side length of each square patch (kernel size and stride).
        in_channels: Number of channels of the input image.
        embed_dim: Length of the embedding produced for each patch.
        drop_out: Dropout probability applied to the patch embeddings.
    """

    def __init__(self, patch_size, in_channels, embed_dim, drop_out=0):
        super().__init__()
        self.patch_embed = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
            bias=False,
        )
        self.dropout = nn.Dropout(drop_out)

    def forward(self, x):
        """Embed ``x`` of shape ``[B, C, H, W]`` as ``[B, num_patches, embed_dim]``."""
        # [B, C, H, W] -> [B, embed_dim, H // patch_size, W // patch_size]
        x = self.patch_embed(x)
        # Merge the two spatial axes: [B, embed_dim, num_patches]
        x = x.flatten(2)
        # Put the patch axis before the feature axis: [B, num_patches, embed_dim]
        x = x.transpose(2, 1)
        x = self.dropout(x)
        return x
# Simulate one 28x28 single-channel image in (H, W, C) layout with random
# values (no image file is actually read).
image = torch.randn((28, 28, 1))
print(image.shape)
# (H, W, C) -> (C, H, W), then add the batch axis: [1, 1, 28, 28]
image = image.permute(2, 0, 1).unsqueeze(0)
# Patch embedding: 28 / 7 = 4 patches per side, i.e. 16 tokens of length 8
patch_embed = patch_embedding(patch_size=7, in_channels=1, embed_dim=8)
out = patch_embed(image)
print(out.size())
# MLP applied to every token; the shape is unchanged
mlp = MLP(8)
out = mlp(out)
print(out.size())
torch.Size([28, 28, 1])
torch.Size([1, 16, 8])
torch.Size([1, 16, 8])
3 注意力机制
在单个序列中使用不同位置的注意力用于实现该序列的表征方法 ------ 《Attention is All You need》
更详细的讲解推荐阅读“详解Transformer”,这里只对一些有疑问的地方进行解释。
Attention的计算方法,整个过程可以分成7步:
1.将输入单词转化成嵌入向量;
2.根据嵌入向量得到q,k,v三个向量;
3.为每个向量计算一个score:$\mathrm{score}=q\cdot k$;
4.为了梯度的稳定,Transformer使用了score归一化,即除以 $\sqrt{d_k}$;
5.对score施以softmax激活函数;
6.softmax点乘Value值v,得到加权的每个输入向量的评分v;
7.相加之后得到最终的输出结果z:$z=\sum v$。
q分别和k相乘,可以得到s,s非常接近Attention,然后进行缩放和softmax就可以得到Attention weight:
$$\mathrm{Attention\ weight}(Q,K)=\mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)$$
为什么要用 $\sqrt{d_k}$ 进行缩放($d_k$ 是q或k的长度)?
Variance(var)表示什么?序列的波动。
序列var越大,那么经过softmax越容易偏向大值。假设序列(feature)Q和K每一位都是iid,并且是random variable(std=1,mean=0)。
那么 $QK^T$ 的每个元素都是 $d_k$ 个相互独立、方差为1的乘积之和,所以 $QK^T$ 的var就是 $d_k$。
所以需要除以 $\sqrt{d_k}$,把var稳定回1.0。
$$\mathrm{Attention}(Q,K,V)=\mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$
Multi-Head self Attention:多个Attention一起进行决策,将多个Attention输出经过W(每个输出可信度\哪个更重要)维度转换进行输出。
4 Attention 简单实现
class Attention(nn.Module):
    """Multi-head self-attention over a sequence of token embeddings.

    Args:
        embed_dim: Size of the last dimension of the input and the output.
        num_heads: Number of attention heads.
        qkv_bias: Whether the joint Q/K/V projection has a bias term.
        qk_scale: Factor applied to the ``Q @ K^T`` scores before the
            softmax; ``head_dim ** -0.5`` is used when it is ``None``.
    """

    def __init__(self, embed_dim, num_heads, qkv_bias=False, qk_scale=None):
        super().__init__()
        self.num_head = num_heads
        # NOTE: every head works on the full ``embed_dim`` width, so the
        # concatenated width of all heads is ``embed_dim * num_heads``.
        self.head_dim = embed_dim
        self.all_head_dim = int(self.head_dim * num_heads)
        # One linear layer produces Q, K and V for all heads at once.
        self.qkv = nn.Linear(embed_dim, self.all_head_dim * 3, bias=qkv_bias)
        self.scale = self.head_dim ** -0.5 if qk_scale is None else qk_scale
        self.softmax = nn.Softmax(-1)
        self.proj = nn.Linear(self.all_head_dim, embed_dim)

    def transpose_multi_head(self, x):
        """Reshape ``[B, N, all_head_dim]`` to ``[B, num_head, N, head_dim]``."""
        batch, tokens = x.shape[0], x.shape[1]
        x = x.view(batch, tokens, self.num_head, self.head_dim)
        return x.transpose(1, 2)

    # Backward-compatible alias for the original (misspelled) method name.
    tanspose_multi_head = transpose_multi_head

    def forward(self, x):
        """Apply self-attention to ``x`` of shape ``[B, N, embed_dim]``.

        Returns:
            A tensor of shape ``[B, N, embed_dim]``.
        """
        B, N, _ = x.shape
        # Build Q, K, V: three tensors of shape [B, N, all_head_dim]
        qkv = self.qkv(x).chunk(3, -1)
        # Each becomes [B, num_head, N, head_dim]
        q, k, v = map(self.transpose_multi_head, qkv)
        # Attention weights: [B, num_head, N, N]
        attn = q @ k.transpose(2, 3)
        attn_weight = self.softmax(attn * self.scale)
        # Weighted sum of the values: [B, num_head, N, head_dim]
        out = attn_weight @ v
        # Concatenate the heads again: [B, N, all_head_dim]
        out = out.transpose(1, 2).contiguous().view(B, N, -1)
        # Fuse the heads back to embed_dim
        out = self.proj(out)
        return out
5 ViT整体的简单实现
class Encoder(nn.Module):
    """One encoder layer: self-attention and an MLP, each in a residual branch.

    In both branches the LayerNorm is applied to the sub-layer output before
    it is added back to the branch input.

    Args:
        embed_dim: Size of the last dimension of the input and the output.
        num_heads: Number of attention heads. Defaults to 2, the value that
            was previously hard-coded.
    """

    def __init__(self, embed_dim, num_heads=2):
        super().__init__()
        self.att = Attention(embed_dim, num_heads=num_heads)
        self.att_norm = nn.LayerNorm(embed_dim)
        self.mlp = MLP(embed_dim)
        self.mlp_norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """Return the encoded tokens; the shape of ``x`` is preserved."""
        # Attention branch: x + LayerNorm(Attention(x))
        h = x
        x = self.att(x)
        x = self.att_norm(x)
        x = h + x

        # MLP branch: x + LayerNorm(MLP(x))
        h = x
        x = self.mlp(x)
        x = self.mlp_norm(x)
        x = h + x
        return x
class ViT(nn.Module):
    """Minimal Vision Transformer classifier.

    The image is split into patches and embedded, passed through a stack of
    encoder layers, averaged over all tokens and classified by a linear head.
    The defaults reproduce the previously hard-coded configuration, so
    ``ViT()`` builds the same model as before.

    Args:
        patch_size: Side length of each square patch.
        in_channels: Number of channels of the input image.
        embed_dim: Length of each token embedding.
        depth: Number of encoder layers.
        num_classes: Number of output logits.
    """

    def __init__(self, patch_size=7, in_channels=1, embed_dim=16, depth=5, num_classes=10):
        super().__init__()
        self.patch_embed = patch_embedding(patch_size, in_channels, embed_dim)
        self.encoder = nn.ModuleList([Encoder(embed_dim) for _ in range(depth)])
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        """Map images ``[B, C, H, W]`` to class logits ``[B, num_classes]``."""
        # [B, C, H, W] -> [B, num_patches, embed_dim]
        x = self.patch_embed(x)
        for encoder in self.encoder:
            x = encoder(x)
        # Average over the token axis: the pooling layer reduces the last
        # axis, so the token axis is moved there first.
        x = x.transpose(2, 1)   # [B, embed_dim, num_patches]
        x = self.avgpool(x)     # [B, embed_dim, 1]
        x = x.flatten(1)        # [B, embed_dim]
        x = self.head(x)
        return x
# Simulate a batch of four 224x224 single-channel images with random values
# (no image file is actually read). The last axis is treated as the image
# index, i.e. the layout is assumed to be (H, W, B).
batch = torch.randn((224, 224, 4))
print(batch.shape)
# Move the image axis to the front and insert the channel axis:
# (H, W, B) -> (B, H, W) -> [4, 1, 224, 224]. permute keeps the pixels of
# each image together, whereas a plain reshape would interleave them.
batch = batch.permute(2, 0, 1).unsqueeze(1)
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
out = ViT().to(device)(batch.to(device))
print(out.size())
torch.Size([224, 224, 4])
torch.Size([4, 10])