VIT(Vision Transformer)【超详细 pytorch实现】

  • CNN 的局限性:传统的 CNN 通过局部卷积核提取特征,虽然可以通过堆叠多层卷积扩大感受野,但仍然依赖于局部信息的逐步聚合,难以直接建模全局依赖关系。

  • ViT 的优势:ViT 使用自注意力机制(Self-Attention),能够直接捕捉图像中所有 patch(图像块)之间的全局关系。这种全局建模能力在处理需要长距离依赖的任务(如图像分类、目标检测)时表现更好。

全流程

  1. 图像预处理+分块

    1. 图像尺寸标准化,如(224*224)

    2. 分块操作,将图像分为等大小非重叠的块(Patch)

  2. 块嵌入(Patch embedding)

    1. 展平成一维,然后过一个Linear映射成d_model,或者直接卷积操也能达到相同效果。

    2. 类别标记,[class Toekn] 在块嵌入序列前添加可学习的类别标记,用于最终的分类任务 【这个比如说分成了196个bacth,那么此时的张量形状应该是(batch_size,patch_num,patchH*patchW)】,这个加类别相当于patch_num+1,作为类别标记

  3. 位置编码,将位置嵌入添加到嵌入前边

  4. Transformer Encoder处理

    1. 编码器堆叠,含类别标记的嵌入经过多个transformer Encoder处理,出来一个增强表示

    2. transformer里边的MLP通常拓展4倍维度

  5. 分类输出:仅仅取序列中的类别标记作为高层特征() 输入MLP Head,输出结果

图像分patch方法

认为一个图像可以分成若干个相等大小的块,比如3x224x224格式的图像,如果patch大小为14x14,则分为256个patch,这个时候每个patch展平就是长度为196的vector。

对patch做embedding

相当于对wordVector做embedding,之后一串embedding序列传递给transformer的Encoder就可以得到全局表示

class PatchEmbeddings(nn.Module):
    def __init__(self,d_model,pacth_size,in_channels):
        super().__init__()
        self.conv=nn.Conv2d(in_channels,d_model,pacth_size,stride=pacth_size)

    def forward(self,x):
        x=self.conv(x)
        batch_size,c,h,w=x.shape
        x=x.permute(2,3,0,1)
        x=x.view(h*w,batch_size,c)
        # (h*w,batch_size,d_model)
        return x

位置编码

import torch

class LearnedPositionalEmbeddings(nn.Module):
    def __init__(self,d_model,max_len=5000):
        super().__init__()
        self.positional_encodings=nn.Parameter(torch.zeros(max_len,1,d_model),requires_grad=True);
    def forward(self,x):
        pe=self.positional_encodings[:x.shape[0]]
        return x+pe

分类器

class ClassficationHead(nn.Module):
    def __init__(self,d_model,n_hidden,n_classes):
        super().__init__()
        self.linear1=nn.Linear(d_model,n_hidden)
        self.act=nn.ReLU()
        self.linear2=nn.Linear(n_hidden,n_classes)

    def forward(self,x):
        x=self.act(self.linear1(x))
        x=self.linear2(x)
        return x

VIT总体

class VisionTransformer(nn.Module):
    def __init__(self,transformer_layer,n_layers,pacth_emb,pos_emb,classification):
        super().__init__()
        self.pacth_emb=pacth_emb
        self.pos_emb=pos_emb
        self.classification=classification
        self.transformer_layer = nn.ModuleList([transformer_layer for _ in range(n_layers)])          #n个transformer layer
        self.cls_token_emb=nn.Parameter(torch.randn(1,1,transformer_layer.size),requires_grad=True)
        self.ln=nn.LayerNorm([transformer_layer.size])

    def forward(self,x):
        x=self.pacth_emb(x)
        cls_token_emb=self.cls_token_emb.expand(-1,x.shape[1],-1)
        x=torch.cat([cls_token_emb,x])

        x=self.pos_emb
        for layer in self.transformer_layer:
            x=layer(x=x,mask=None)
        x=x[0]
        x=self.ln(x)
        x=self.classification(x)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值