Vision Transformer (ViT) Notes

Paper: "An Image Is Worth 16x16 Words: Transformers for Image Recognition at Scale"

Official code: https://github.com/google-research/vision_transformer

PyTorch implementation: pytorch_classification/vision_transformer

Contents

1. Background

2. Differences between ViT and CNNs

3. Network Architecture

(1) Patch Embedding (Linear Projection of Flattened Patches)

(2) Class Token

(3) Position Embedding

(4) Transformer Encoder

(5) MLP Head


1. Background

        The Transformer has been hugely successful in NLP, but how to apply self-attention to computer vision was still an open problem. Existing approaches either combined self-attention with convolutional networks, or replaced certain convolutions with self-attention while keeping the overall CNN architecture unchanged.

        ViT shows that the original Transformer architecture can be kept almost unchanged and applied directly to vision: when pre-trained on sufficiently large datasets, it achieves state-of-the-art results on image classification benchmarks such as ImageNet.

2. Differences between ViT and CNNs

(1) Local receptive field: a convolution only interacts with nearby positions; for distant locations to interact, the network has to be made deeper. A Transformer needs only a single self-attention layer for every position to interact with every other position.

(2) A CNN relies on pooling to progressively reduce the spatial resolution of its feature maps, whereas a Transformer does not need pooling.

(3) Transformers contain more parameters and require substantially more compute.

(4) The biggest advantage of the Transformer is that it unifies CV and NLP, so a single family of models can handle multi-modal tasks. (At the time, purely CV tasks still mostly used CNNs, which are faster and need fewer computing resources.)

(5) ViT lacks the inductive biases (a form of prior knowledge) that are built into CNNs:

  • Locality: a convolution slides a window across the image, assuming that neighboring regions share similar features;
  • Translation equivariance (often loosely called translation invariance): the same input patch convolved with the same kernel gives the same output wherever it appears (see the small demonstration after this list).
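
As a small illustration of the translation-equivariance bias, the sketch below (a toy PyTorch example, not from the original repository) shows that shifting the input shifts the convolution output by exactly the same amount:

import torch
import torch.nn as nn

# Toy demonstration: any 3x3 kernel works; the weights here are random.
torch.manual_seed(0)
conv = nn.Conv2d(1, 1, kernel_size=3, stride=1, padding=0, bias=False)

x = torch.randn(1, 1, 10, 10)
x_shift = torch.zeros_like(x)
x_shift[:, :, 1:, :] = x[:, :, :-1, :]   # shift the image down by one pixel

y = conv(x)               # (1, 1, 8, 8)
y_shift = conv(x_shift)   # (1, 1, 8, 8)

# Away from the border, the output is simply shifted by the same one pixel.
print(torch.allclose(y_shift[:, :, 1:, :], y[:, :, :-1, :], atol=1e-6))   # True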

3. Network Architecture

        The image is split into patches, and the patches are fed to the model as a sequence. Each patch passes through a learnable linear projection layer (Linear Projection) to produce a feature vector (token), the patch embedding; a position encoding (positional encoding) is added, and the result is fed into the Transformer Encoder. In addition, a learnable class embedding (cls token) is prepended to the sequence. Because every token attends to every other token, the cls token can gather useful information from all of the patch embeddings, so its output alone is used for the final prediction, e.g. fed into a classification head; the model is then trained with a cross-entropy loss.
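
This pipeline can be sketched in a few lines. The schematic below is self-contained but uses stock PyTorch modules (nn.TransformerEncoder with pre-norm as a stand-in for the encoder); it is only meant to make the shapes concrete, while the repository's implementation is analyzed piece by piece in the subsections that follow.

import torch
import torch.nn as nn
import torch.nn.functional as F

B, num_patches, dim, num_classes = 2, 196, 768, 1000
patch_embed = nn.Conv2d(3, dim, kernel_size=16, stride=16)      # 16x16 patches -> 768-d tokens
cls_token = nn.Parameter(torch.zeros(1, 1, dim))                # learnable class embedding
pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, dim))  # learnable 1-D position embedding
encoder = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(dim, nhead=12, dim_feedforward=3072,
                               activation="gelu", batch_first=True, norm_first=True),
    num_layers=12)
head = nn.Linear(dim, num_classes)                              # classification head

img = torch.randn(B, 3, 224, 224)
x = patch_embed(img).flatten(2).transpose(1, 2)                 # (B, 196, 768)
x = torch.cat((cls_token.expand(B, -1, -1), x), dim=1)          # prepend cls -> (B, 197, 768)
x = x + pos_embed                                               # add position embedding
x = encoder(x)                                                  # (B, 197, 768)
logits = head(x[:, 0])                                          # classify from the cls token output
loss = F.cross_entropy(logits, torch.randint(0, num_classes, (B,)))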

(1) Patch Embedding (Linear Projection of Flattened Patches)

        This layer maps the patches to a sequence of vector (token) representations.

  • Input image size: 224 × 224 × 3
  • Each patch: 16 × 16 × 3 = 768 values → one token
  • Number of patches: (224 / 16) × (224 / 16) = 14 × 14 = 196
  • Patch input matrix: X = 196 × 768

        Here, the Linear Projection is a fully connected layer (equivalently a strided convolution) with a weight matrix E of size 768 × 768 that encodes each patch. Its output X·E has size 196 × 768 and is the patch embedding: 196 tokens, each of length 768.

Code analysis:

# Imports used by the code snippets in this note
from collections import OrderedDict
from functools import partial

import torch
import torch.nn as nn


class PatchEmbed(nn.Module):

    def __init__(self, img_size=224, patch_size=16, in_c=3, embed_dim=768, norm_layer=None):
        super().__init__()
        img_size = (img_size, img_size)   # (224,224)
        patch_size = (patch_size, patch_size)   # (16,16)
        self.img_size = img_size   # (224,224)
        self.patch_size = patch_size   # (16,16)
        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])   # (14,14)
        self.num_patches = self.grid_size[0] * self.grid_size[1]   # 196

        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        B, C, H, W = x.shape   # (1, 3, 224, 224)
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."

        # flatten: [B, C, H, W] -> [B, C, HW]
        # transpose: [B, C, HW] -> [B, HW, C]
        x = self.proj(x)   # (1, 768, 14, 14)
        x = x.flatten(2)   # (1, 768, 196)
        x = x.transpose(1, 2)   # (1, 196, 768)
        x = self.norm(x)   # (1, 196, 768)
        return x
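
As a quick check of the claim that this convolution is just a fully connected layer applied to each flattened patch, the following sketch (not from the repository) compares the Conv2d path with an explicit unfold-plus-matrix-multiply path:

import torch
import torch.nn as nn
import torch.nn.functional as F

# The strided Conv2d in PatchEmbed is exactly "flatten each 16x16x3 patch, then
# multiply by a shared 768x768 projection matrix E".
torch.manual_seed(0)
x = torch.randn(1, 3, 224, 224)
proj = nn.Conv2d(3, 768, kernel_size=16, stride=16)

out_conv = proj(x).flatten(2).transpose(1, 2)          # (1, 196, 768)

patches = F.unfold(x, kernel_size=16, stride=16)       # (1, 768, 196): flattened patches
patches = patches.transpose(1, 2)                      # (1, 196, 768)
E = proj.weight.reshape(768, -1).t()                   # (768, 768) projection matrix
out_linear = patches @ E + proj.bias                   # (1, 196, 768)

print(torch.allclose(out_conv, out_linear, atol=1e-4))  # True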

(2) Class Token

        The class token is borrowed directly from NLP. In vision models, the usual practice is to apply global average pooling to the final feature map and feed the result to a linear classifier. To stay as close as possible to the original Transformer, ViT instead borrows BERT's class token: a token that learns useful features from all the other tokens and serves as the representation of the whole image.

        The cls token is a (1, 768) vector that is simply concatenated with the other tokens: cat[(1, 768), (196, 768)] = (197, 768).

Code analysis:

# Excerpt from VisionTransformer.__init__:
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))   # (1, 1, 768), initialized to zeros
nn.init.trunc_normal_(self.cls_token, std=0.02)   # truncated normal init, std=0.02

# Excerpt from VisionTransformer.forward_features:
def forward_features(self, x):
    # [B, C, H, W] -> [B, num_patches, embed_dim]
    x = self.patch_embed(x)  # (1, 196, 768)
    # [1, 1, 768] -> [B, 1, 768]
    cls_token = self.cls_token.expand(x.shape[0], -1, -1)   # (1, 1, 768)
    x = torch.cat((cls_token, x), dim=1)  # [B, 197, 768]

(3) Position Embedding

        The position embedding preserves the spatial relationships between the input image patches. Because the Transformer has none of the CNN's spatial priors, positional information has to be injected explicitly into the patch tokens. This is necessary because self-attention is permutation-equivariant: shuffling the order of the tokens merely shuffles the outputs in the same way, so without position embeddings the model cannot tell where each patch came from.
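
The following toy sketch (using PyTorch's built-in nn.MultiheadAttention rather than the repo's Attention class) shows this permutation behaviour: shuffling the input tokens just shuffles the outputs in the same way.

import torch
import torch.nn as nn

torch.manual_seed(0)
attn = nn.MultiheadAttention(embed_dim=768, num_heads=12, batch_first=True)
x = torch.randn(1, 196, 768)             # 196 patch tokens, no position information

out, _ = attn(x, x, x)                   # self-attention on the original order
perm = torch.randperm(196)
out_perm, _ = attn(x[:, perm], x[:, perm], x[:, perm])   # same tokens, shuffled

# The outputs are the same up to the same shuffle: attention alone cannot tell
# where a patch came from, hence the need for position embeddings.
print(torch.allclose(out[:, perm], out_perm, atol=1e-5))   # True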

        Concretely, a table is created with one row per position (1, 2, 3, …); each row is a learnable vector of length 768 (197 × 768 in total), which is added (element-wise sum) to the corresponding token. ViT uses this standard learnable 1-D position embedding.

  • No position embedding: the model is insensitive to position;
  • 1-D position embedding: treat the 2-D grid of patches as a 1-D sequence;
  • 2-D position embedding: encode the 2-D coordinates (x, y) of each patch;
  • Relative position embedding: encode the relative positions between patches.

Code analysis:

# Excerpt from VisionTransformer.__init__:
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))   # (1, 197, 768), initialized to zeros
nn.init.trunc_normal_(self.pos_embed, std=0.02)   # truncated normal init, std=0.02
self.pos_drop = nn.Dropout(p=drop_ratio)

# Excerpt from VisionTransformer.forward_features:
def forward_features(self, x):
    ...
    # x: embedded patches + cls token, (1, 197, 768)
    x = self.pos_drop(x + self.pos_embed)   # (1, 197, 768)

(4) Transformer Encoder

        The Transformer Encoder stacks L Transformer blocks. Each block contains two operations, Multi-Head Attention and an MLP; each operation is preceded by a LayerNorm and followed by a residual connection.

  • Input feature size: 197 × 768
  • Multi-Head Attention: with N heads, each head's Q, K, V have size 197 × (768 / N); after concatenating the heads the output is again 197 × 768
  • MLP: 197 × 768 → 197 × 3072 → 197 × 768
  • Output after L layers: 197 × 768

Code analysis:

1. Transformer Encoder initialization

class VisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_c=3, num_classes=1000,
                 embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True,
                 qk_scale=None, representation_size=None, distilled=False, drop_ratio=0.,
                 attn_drop_ratio=0., drop_path_ratio=0., embed_layer=PatchEmbed, norm_layer=None,
                 act_layer=None):

        self.blocks = nn.Sequential(*[
            Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                  drop_ratio=drop_ratio, attn_drop_ratio=attn_drop_ratio, drop_path_ratio=dpr[i],
                  norm_layer=norm_layer, act_layer=act_layer)
            for i in range(depth)
        ])
        self.norm = norm_layer(embed_dim)

2. Transformer block

class Block(nn.Module):
    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop_ratio=0.,
                 attn_drop_ratio=0.,
                 drop_path_ratio=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm):
        super(Block, self).__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
                              attn_drop_ratio=attn_drop_ratio, proj_drop_ratio=drop_ratio)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path_ratio) if drop_path_ratio > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop_ratio)

    def forward(self, x):
        input = x   # (1, 197, 768)
        x = self.norm1(x)   # (1, 197, 768)
        x = self.attn(x)   # (1, 197, 768)
        x = self.drop_path(x)   # (1, 197, 768)
        x = x + input   # (1, 197, 768)

        input = x   # (1, 197, 768)
        x = self.norm2(x)   # (1, 197, 768)
        x = self.mlp(x)   # (1, 197, 768)
        x = self.drop_path(x)   # (1, 197, 768)
        x = x + input   # (1, 197, 768)
        return x
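
A quick shape check of one block (a sketch, assuming the Attention and Mlp classes listed below and the repo's DropPath helper are in scope):

import torch

# Each Block maps (B, 197, 768) back to (B, 197, 768), so blocks can be stacked freely.
block = Block(dim=768, num_heads=12, mlp_ratio=4.0)
x = torch.randn(1, 197, 768)
print(block(x).shape)   # torch.Size([1, 197, 768])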

3. Multi-Head Attention

class Attention(nn.Module):
    def __init__(self,
                 dim,  # 768
                 num_heads=8,   # default 8 heads; 12 when instantiated for ViT-B/16
                 qkv_bias=False,
                 qk_scale=None,   # scaling factor 1/sqrt(d_k); here 64 ** -0.5 = 0.125
                 attn_drop_ratio=0.,
                 proj_drop_ratio=0.):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        # dimension per head: embed_dim is split evenly across the heads (768 / 12 = 64)
        head_dim = dim // num_heads
        # scaling denominator: q·k^T is divided by sqrt(d_k)
        self.scale = qk_scale or head_dim ** -0.5
        # a single fully connected layer produces q, k and v at once
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop_ratio)
        # after concatenating the heads, project back with W (also a fully connected layer)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop_ratio)

    def forward(self, x):
        # [batch_size, num_patches + 1, total_embed_dim]
        B, N, C = x.shape   # (1, 197, 768)

        # qkv(): -> [batch_size, num_patches + 1, 3 * total_embed_dim]
        # reshape: -> [batch_size, num_patches + 1, 3, num_heads, embed_dim_per_head]
        # permute: -> [3, batch_size, num_heads, num_patches + 1, embed_dim_per_head]
        qkv = self.qkv(x)   # (1, 197, 768*3)
        qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads)   # (1, 197, 3, 12, 64)
        qkv = qkv.permute(2, 0, 3, 1, 4)   # (3, 1, 12, 197, 64)
        # [batch_size, num_heads, num_patches + 1, embed_dim_per_head]
        q, k, v = qkv[0], qkv[1], qkv[2]   # (1, 12, 197, 64)

        # transpose: -> [batch_size, num_heads, embed_dim_per_head, num_patches + 1]
        # @: multiply -> [batch_size, num_heads, num_patches + 1, num_patches + 1]
        attn = (q @ k.transpose(-2, -1))   # matmul with transposed k, (1, 12, 197, 197)
        attn = attn * self.scale   # divide by sqrt(d_k), (1, 12, 197, 197)
        attn = attn.softmax(dim=-1)   # (1, 12, 197, 197)
        attn = self.attn_drop(attn)   # (1, 12, 197, 197)

        # @: multiply -> [batch_size, num_heads, num_patches + 1, embed_dim_per_head]
        # transpose: -> [batch_size, num_patches + 1, num_heads, embed_dim_per_head]
        # reshape: -> [batch_size, num_patches + 1, total_embed_dim]
        x = attn @ v   # (1, 12, 197, 64)
        x = x.transpose(1, 2)   # (1, 197, 12, 64)
        x = x.reshape(B, N, C)   # concatenate the heads, (1, 197, 768)
        x = self.proj(x)   # (1, 197, 768)
        x = self.proj_drop(x)   # (1, 197, 768)
        return x
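
As a sanity check of the attention arithmetic above, the sketch below (assuming PyTorch 2.0+ for F.scaled_dot_product_attention) verifies that the manual softmax(q·kᵀ/√d_k)·v matches PyTorch's fused implementation for the same per-head q, k, v:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
q = torch.randn(1, 12, 197, 64)   # [batch, heads, tokens, head_dim]
k = torch.randn(1, 12, 197, 64)
v = torch.randn(1, 12, 197, 64)

manual = ((q @ k.transpose(-2, -1)) * 64 ** -0.5).softmax(dim=-1) @ v
fused = F.scaled_dot_product_attention(q, k, v)   # uses the same 1/sqrt(64) scaling
print(torch.allclose(manual, fused, atol=1e-5))   # True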

4. MLP

class Mlp(nn.Module):

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features   # 4*768
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):   # (1, 197, 768)
        x = self.fc1(x)   # (1, 197, 3072)
        x = self.act(x)   # (1, 197, 3072)
        x = self.drop(x)   # (1, 197, 3072)
        x = self.fc2(x)   # (1, 197, 768)
        x = self.drop(x)   # (1, 197, 768)
        return x

(5) MLP Head

        The result is read from the element at position 0 of the final (L-th) Transformer layer's output, i.e. the cls token of size (1, 768) (pooling over the tokens, e.g. MAP, also works). According to the paper, the MLP Head used when training on ImageNet-21k consists of Linear + tanh activation + Linear; when transferring to ImageNet-1k or to your own dataset, a single Linear layer is sufficient.
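
A minimal sketch of the two head variants described above (the layer sizes are the ViT-B defaults; num_classes is illustrative):

import torch.nn as nn
from collections import OrderedDict

embed_dim, representation_size, num_classes = 768, 768, 1000

# Pre-training-style head: "pre-logits" (Linear + tanh) followed by the classifier.
pretrain_head = nn.Sequential(OrderedDict([
    ("fc", nn.Linear(embed_dim, representation_size)),
    ("act", nn.Tanh()),
    ("classifier", nn.Linear(representation_size, num_classes)),
]))

# Fine-tuning head: the pre-logits layer is dropped and a single Linear remains.
finetune_head = nn.Linear(embed_dim, num_classes)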

Code analysis:

class VisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_c=3, num_classes=1000,
                 embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True,
                 qk_scale=None, representation_size=None, distilled=False, drop_ratio=0.,
                 attn_drop_ratio=0., drop_path_ratio=0., embed_layer=PatchEmbed, norm_layer=None,
                 act_layer=None):

        super(VisionTransformer, self).__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # 768
        self.num_tokens = 2 if distilled else 1   # 1
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU

        self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_c=in_c, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))   # (1, 1, 768), initialized to zeros
        self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))   # (1, 197, 768), initialized to zeros
        self.pos_drop = nn.Dropout(p=drop_ratio)

        dpr = [x.item() for x in torch.linspace(0, drop_path_ratio, depth)]  # stochastic depth decay rule
        self.blocks = nn.Sequential(*[
            Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                  drop_ratio=drop_ratio, attn_drop_ratio=attn_drop_ratio, drop_path_ratio=dpr[i],
                  norm_layer=norm_layer, act_layer=act_layer)
            for i in range(depth)
        ])
        self.norm = norm_layer(embed_dim)

        # Representation layer
        if representation_size and not distilled:
            self.has_logits = True
            self.num_features = representation_size
            self.pre_logits = nn.Sequential(OrderedDict([
                ("fc", nn.Linear(embed_dim, representation_size)),
                ("act", nn.Tanh())
            ]))
        else:
            self.has_logits = False
            self.pre_logits = nn.Identity()

        # Classifier head(s)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
        self.head_dist = None
        if distilled:
            self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()

        # Weight init
        nn.init.trunc_normal_(self.pos_embed, std=0.02)   # truncated normal init, std=0.02, (1, 197, 768)
        if self.dist_token is not None:
            nn.init.trunc_normal_(self.dist_token, std=0.02)

        nn.init.trunc_normal_(self.cls_token, std=0.02)   # truncated normal init, std=0.02, (1, 1, 768)
        self.apply(_init_vit_weights)

    def forward_features(self, x):
        # [B, C, H, W] -> [B, num_patches, embed_dim]
        x = self.patch_embed(x)  # (1, 196, 768)
        # [1, 1, 768] -> [B, 1, 768]
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)   # (1, 1, 768)
        if self.dist_token is None:
            x = torch.cat((cls_token, x), dim=1)  # [B, 197, 768], (1, 197, 768)
        else:
            x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)

        x = self.pos_drop(x + self.pos_embed)   # (1, 197, 768)
        x = self.blocks(x)   # (1, 197, 768)
        x = self.norm(x)   # (1, 197, 768)
        # extract the output corresponding to the cls token, i.e. the class representation
        if self.dist_token is None:
            # take index 0 along the token dimension for every sample in the batch (the cls token)
            return self.pre_logits(x[:, 0])
        else:
            return x[:, 0], x[:, 1]

    def forward(self, x):   # (1, 3, 224, 224)
        x = self.forward_features(x)   # (1, 768)
        if self.head_dist is not None:
            x, x_dist = self.head(x[0]), self.head_dist(x[1])
            if self.training and not torch.jit.is_scripting():
                # during training, return both classifier predictions
                return x, x_dist
            else:
                # during inference, return the average of both classifier predictions
                return (x + x_dist) / 2
        else:
            x = self.head(x)   # (1, 1000)

        return x
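
Putting it all together, a quick end-to-end check (a sketch, assuming the classes above are collected in a single module so that VisionTransformer is importable):

import torch

# Build ViT-B/16 with the default arguments shown above and run a dummy image.
model = VisionTransformer(img_size=224, patch_size=16, embed_dim=768,
                          depth=12, num_heads=12, num_classes=1000)
model.eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)   # torch.Size([1, 1000])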

