对于输入(1,3,32,32)和使用 patch size 为 8 的设置,维度变化的步骤可以总结为以下过程:
-
图像块的提取:
- 使用 patch size 为 (8, 8) 将输入图像划分为图像块。
- (32//8)*(32//8)=16个图像块,对应image_sequence_length
- 维度变为(1,16,3,8,8)。
-
展平成向量:
- 将每个图像块展平为长度为
8 * 8 * 3
的向量。 - 维度变为(1,16,192)。
- 将每个图像块展平为长度为
-
加入 Class Token 和 位置嵌入:
- 添加 Class Token,将维度变为(1,17,192)。
- 添加位置嵌入(相加),最终维度仍为(1,17,192)。
-
线性投影:
- 使用线性投影将维度投影到 input_dim=1024(可变)。对应input_dim
- 维度变为(1,17,1024)。
- transformer
-
多头自注意力机制(Multi-Head Self-Attention):
- 如果默认的头数是 h(通常为 12),并且头维度(head dim)是 d_h(通常为 64),则每个头产生的输出维度为(1,17,64)。embed_dim对应d_h*h
- 线性变换得到 Q, K, V:
- 对输入进行三个线性变换,得到Q,K,V,每个变换的权重矩阵维度为 (d_model,d_h*h)即(1024->12x64=768)。h为头的个数
- Q 的维度为 (1,17,d_h*h),K 的维度为 (1,17,d_h*h),V 的维度为 (1,17,d_h*h)。
- 这可以用矩阵相乘的方式实现:Q=X⋅Wq,K=X⋅Wk,V=X⋅Wv,其中 X 是输入张量。
-
拆分为多头:
- 将得到的 Q,K,V 分别拆分为 h 个头,每个头的维度为 (1,17,d_h)。
- 这里的拆分是在最后一维上进行的。
-
注意力计算:
- 对每个头进行注意力计算,得到注意力分数。
- 注意力计算包括对 Q 和 K 进行点积,然后进行缩放操作,最后应用 softmax 函数。
- 将注意力分数乘以 V,得到每个头的注意力输出(1,17,d_h)。
-
合并多头输出:
- 将每个头的注意力输出按最后一维度进行连接(concatenate),得到多头注意力的输出。
- 多头注意力的输出的维度为 (1,17,768)。这里768=h*d_h
- 映射为(1,17,1024)
- 加残差,维度不变(1,17,1024)
- LayerNormalization 和 Feedforward:
- 对于每个位置,进行 LayerNormalization 维度不变(1,17,1024)。
- Feedforward 网络:
- 如果默认的 Feedforward 网络的中间维度(为 2048):
- 输入维度:(1,17,1024)
- 变化为:(1,17,2048)
- 输出维度:(1,17,1024)
- 如果默认的 Feedforward 网络的中间维度(为 2048):
- 加残差,维度不变(1,17,1024)
-
- (1,17,1024)全局池化或者只取cls token(1,1024)mlp输出1000类->(1,1000)
图:
更新补充代码实现
分组件实现
分别对应上图右边模块
Norm
class PreNorm(nn.Module):
def __init__(self,dim,fn):#fn指norm之后的函数 如上图,可以接mha或者mlp
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self,x):
return self.fn(self.norm(x))
Multi-Head Attention
相对于别人用einsum的写法,我的更加直白适合新手阅读
class MultiheadAttention(nn.Module):
def __init__(self,input_dim,embed_dim,num_heads=8):
#embed_dim写的是多头的总dim,即head_dim*num_heads
#你也可以修改成输入为num_heads和head_dim
super().__init__()
#self.mha = nn.MultiheadAttention(hid_dim,num_heads)
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim//num_heads
self.W_q = nn.Linear(input_dim, embed_dim)
self.W_k = nn.Linear(input_dim, embed_dim)
self.W_v = nn.Linear(input_dim, embed_dim)
self.to_out = nn.Linear(embed_dim, input_dim)
def forward(self, x):
# (batch_size, seq_length, embed_dim)
# 例如32*32的图 取patchsize为8,则image_sequence_length为(32//8)*(32//8)= 16
# image_sequence_length也就是块的个数
batch_size, image_sequence_length, _ = x.shape
# 应用线性变换生成Q, K, V
Q = self.W_q(x).view(batch_size, image_sequence_length, self.num_heads, self.head_dim).transpose(1, 2)
#实际为(batch_size, self.num_heads, image_sequence_length, self.head_dim)
K = self.W_k(x).view(batch_size, image_sequence_length, self.num_heads, self.head_dim).transpose(1, 2)
V = self.W_v(x).view(batch_size, image_sequence_length, self.num_heads, self.head_dim).transpose(1, 2)
#得到(b,num_heads,image_sequence_length,image_sequence_length)
attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)
# 应用Softmax获取注意力权重
# 计算块与块之间的注意力
attn_weights = torch.softmax(attn_scores, dim=-1)
# 使用注意力权重加权V
attn_output = torch.matmul(attn_weights, V)
attn_output = attn_output.view(batch_size, image_sequence_length, self.embed_dim)
return self.to_out(attn_output)
MLP
#ffn也就是mlp
class FeedForward(nn.Module):
def __init__(self,dim,hidden_dim,drop_out=0.3):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim,hidden_dim),
nn.GELU(),
nn.Dropout(drop_out),
nn.Linear(hidden_dim,dim),
nn.Dropout())
def forward(self,x):
return self.net(x)
整合transformer block
class Transformer(nn.Module):
def __init__(self,input_dim,depth,embed_dim,mlp_dim,num_heads) -> None:
super().__init__()
self.layers = nn.ModuleList()
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(input_dim,MultiheadAttention(input_dim,embed_dim,num_heads)),
PreNorm(input_dim,FeedForward(input_dim,mlp_dim))
]))
def forward(self,x):
for attn,ff in self.layers:
#norm->mha->+残差->norm->mlp->加残差
x = attn(x)+x
x = ff(x)+x
return x
整合VIT
class VIT(nn.Module):
def __init__(self,image_size,patch_size,num_classes,input_dim,depth,embed_dim,mlp_dim,heads,channels=3,pool='cls') -> None:
super().__init__()
assert image_size%patch_size==0
num_patches = (image_size//patch_size)*(image_size//patch_size)
patch_dim = channels*patch_size*patch_size
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_size, p2=patch_size),
nn.Linear(patch_dim,input_dim))
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches+1, input_dim))
self.cls_token = nn.Parameter(torch.randn(1,1,input_dim))
self.transformer = Transformer(input_dim,depth,embed_dim,mlp_dim,heads)
self.pool = pool
self.to_latent = nn.Identity()
self.mlp_head = nn.Sequential(
nn.LayerNorm(input_dim),
nn.Linear(input_dim,num_classes)
)
def forward(self,img):
x = self.to_patch_embedding(img) # b c (h p1) (w p2) -> b (h w) (p1 p2 c) -> b (h w) dim
b,n,_ = x.shape #(b,16,dim)
cls_tokens = repeat(self.cls_token,'() n d -> b n d',b=b)
x = torch.cat((cls_tokens,x),dim=1) # 将cls_token拼接到patch token中去(b,17,dim)
x +=self.pos_embedding[:,:(n+1)]#利用广播机制直接相加(b, 17, dim)
x =self.transformer(x)#(b, 17, dim)
x = x.mean(dim=1) if self.pool =='mean' else x[:,0]# (b, dim)
#可以看到,如果指定池化方式为'mean'的话,则会对全部token做平均池化,然后全部进行送到mlp中,但是我们可以看到,默认的self.pool='cls',也就是说默认不会进行平均池化,而是按照ViT的设计只使用cls_token,即x[:, 0]只取第一个token(cls_token)。
x = self.to_latent(x) # Identity (b, dim)
#print(x.shape)
return self.mlp_head(x)
测试
model_vit = VIT(
image_size = 32,
patch_size = 8,
num_classes = 1000,
input_dim = 1024,
embed_dim = 768,
depth = 6,
heads = 12,
mlp_dim = 2048,
)
img = torch.randn(1, 3, 32, 32)
preds = model_vit(img)
print(preds.shape) # (1, 1000)