保姆级VIT详解输入维度变化以及对应代码实现

对于输入(1,3,32,32)和使用 patch size 为 8 的设置,维度变化的步骤可以总结为以下过程:

  1. 图像块的提取:

    • 使用 patch size 为 (8, 8) 将输入图像划分为图像块。
    • (32//8)*(32//8)=16个图像块,对应image_sequence_length
    • 维度变为(1,16,3,8,8)。
  2. 展平成向量:

    • 将每个图像块展平为长度为 8 * 8 * 3 的向量。
    • 维度变为(1,16,192)。
  3. 加入 Class Token 和 位置嵌入:

    • 添加 Class Token,将维度变为(1,17,192)。
    • 添加位置嵌入(相加),最终维度仍为(1,17,192)。
  4. 线性投影:

    • 使用线性投影将维度投影到 input_dim=1024(可变)。对应input_dim
    • 维度变为(1,17,1024)。
  5. transformer
    • 多头自注意力机制(Multi-Head Self-Attention):

      • 如果默认的头数是 h(通常为 12),并且头维度(head dim)是 d_h(通常为 64),则每个头产生的输出维度为(1,17,64)。embed_dim对应d_h*h
      • 线性变换得到 Q, K, V:
        • 对输入进行三个线性变换,得到Q,K,V,每个变换的权重矩阵维度为 (d_model​,d_h*h​)即(1024->12x64=768)。h为头的个数
        • Q 的维度为 (1,17,d_h*h​),K 的维度为 (1,17,d_h*h​),V 的维度为 (1,17,d_h*h​)。
        • 这可以用矩阵相乘的方式实现:Q=X⋅Wq​,K=X⋅Wk​,V=X⋅Wv​,其中 X 是输入张量。
      • 拆分为多头:

        • 将得到的 Q,K,V 分别拆分为 h 个头,每个头的维度为 (1,17,d_h)。
        • 这里的拆分是在最后一维上进行的。
      • 注意力计算:

        • 对每个头进行注意力计算,得到注意力分数。
        • 注意力计算包括对 Q 和 K 进行点积,然后进行缩放操作,最后应用 softmax 函数。
        • 将注意力分数乘以 V,得到每个头的注意力输出(1,17,d_h)。
      • 合并多头输出:

        • 将每个头的注意力输出按最后一维度进行连接(concatenate),得到多头注意力的输出。
        • 多头注意力的输出的维度为 (1,17,768)。这里768=h*d_h
        • 映射为(1,17,1024)
      • 加残差,维度不变(1,17,1024)
    • LayerNormalization 和 Feedforward:
      • 对于每个位置,进行 LayerNormalization 维度不变(1,17,1024)。
      • Feedforward 网络:
        • 如果默认的 Feedforward 网络的中间维度(为 2048):
          • 输入维度:(1,17,1024)
          • 变化为:(1,17,2048)
          • 输出维度:(1,17,1024)
      • 加残差,维度不变(1,17,1024)
  6. (1,17,1024)全局池化或者只取cls token(1,1024)mlp输出1000类->(1,1000)

图:

更新补充代码实现

分组件实现

分别对应上图右边模块

Norm

class PreNorm(nn.Module):
    def __init__(self,dim,fn):#fn指norm之后的函数 如上图,可以接mha或者mlp
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
    
    def forward(self,x):
        return self.fn(self.norm(x))

Multi-Head Attention

相对于别人用einsum的写法,我的更加直白适合新手阅读

class MultiheadAttention(nn.Module):
    def __init__(self,input_dim,embed_dim,num_heads=8):
        #embed_dim写的是多头的总dim,即head_dim*num_heads
        #你也可以修改成输入为num_heads和head_dim
        super().__init__()
        #self.mha = nn.MultiheadAttention(hid_dim,num_heads)
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim//num_heads
        self.W_q = nn.Linear(input_dim, embed_dim)
        self.W_k = nn.Linear(input_dim, embed_dim)
        self.W_v = nn.Linear(input_dim, embed_dim)
        self.to_out = nn.Linear(embed_dim, input_dim)

    def forward(self, x):
        # (batch_size, seq_length, embed_dim)
        # 例如32*32的图 取patchsize为8,则image_sequence_length为(32//8)*(32//8)= 16
        # image_sequence_length也就是块的个数
        batch_size, image_sequence_length, _ = x.shape

        # 应用线性变换生成Q, K, V
        Q = self.W_q(x).view(batch_size, image_sequence_length, self.num_heads, self.head_dim).transpose(1, 2)
        #实际为(batch_size, self.num_heads, image_sequence_length, self.head_dim)
        K = self.W_k(x).view(batch_size, image_sequence_length, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.W_v(x).view(batch_size, image_sequence_length, self.num_heads, self.head_dim).transpose(1, 2)

        #得到(b,num_heads,image_sequence_length,image_sequence_length)
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)
        
        # 应用Softmax获取注意力权重
        # 计算块与块之间的注意力
        attn_weights = torch.softmax(attn_scores, dim=-1)

        # 使用注意力权重加权V
        attn_output = torch.matmul(attn_weights, V)
        
        attn_output = attn_output.view(batch_size, image_sequence_length, self.embed_dim)
        return self.to_out(attn_output)

MLP

#ffn也就是mlp
class FeedForward(nn.Module):
    def __init__(self,dim,hidden_dim,drop_out=0.3):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim,hidden_dim),
            nn.GELU(),
            nn.Dropout(drop_out),
            nn.Linear(hidden_dim,dim),
            nn.Dropout())
    
    def forward(self,x):
        return self.net(x)

整合transformer block

class Transformer(nn.Module):
    def __init__(self,input_dim,depth,embed_dim,mlp_dim,num_heads) -> None:
        super().__init__()
        self.layers = nn.ModuleList()
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(input_dim,MultiheadAttention(input_dim,embed_dim,num_heads)),
                PreNorm(input_dim,FeedForward(input_dim,mlp_dim))
            ]))
    def forward(self,x):
        for attn,ff in self.layers:
            #norm->mha->+残差->norm->mlp->加残差
            x = attn(x)+x
            x = ff(x)+x
        return x

整合VIT

class VIT(nn.Module):
    def __init__(self,image_size,patch_size,num_classes,input_dim,depth,embed_dim,mlp_dim,heads,channels=3,pool='cls') -> None:
        super().__init__()
        assert image_size%patch_size==0
        num_patches = (image_size//patch_size)*(image_size//patch_size)
        patch_dim = channels*patch_size*patch_size
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_size, p2=patch_size),
            nn.Linear(patch_dim,input_dim))
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches+1, input_dim))
        self.cls_token = nn.Parameter(torch.randn(1,1,input_dim))

        self.transformer = Transformer(input_dim,depth,embed_dim,mlp_dim,heads)
        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(input_dim),
            nn.Linear(input_dim,num_classes)
        )

    def forward(self,img):
        x = self.to_patch_embedding(img) # b c (h p1) (w p2) -> b (h w) (p1 p2 c) -> b (h w) dim
        b,n,_ = x.shape  #(b,16,dim)
        cls_tokens = repeat(self.cls_token,'() n d -> b n d',b=b) 
        x = torch.cat((cls_tokens,x),dim=1)  # 将cls_token拼接到patch token中去(b,17,dim)
        x +=self.pos_embedding[:,:(n+1)]#利用广播机制直接相加(b, 17, dim)
        x =self.transformer(x)#(b, 17, dim)
        x = x.mean(dim=1) if self.pool =='mean' else x[:,0]# (b, dim)
        #可以看到,如果指定池化方式为'mean'的话,则会对全部token做平均池化,然后全部进行送到mlp中,但是我们可以看到,默认的self.pool='cls',也就是说默认不会进行平均池化,而是按照ViT的设计只使用cls_token,即x[:, 0]只取第一个token(cls_token)。
        x = self.to_latent(x)             # Identity (b, dim)
        #print(x.shape)

        return self.mlp_head(x)

测试

model_vit = VIT(
        image_size = 32,
        patch_size = 8,
        num_classes = 1000,
        input_dim = 1024,
        embed_dim = 768,
        depth = 6,
        heads = 12,
        mlp_dim = 2048,
    )

img = torch.randn(1, 3, 32, 32)

preds = model_vit(img) 

print(preds.shape)  # (1, 1000)

评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值