Vision Transformer: Principles and Model (Study Notes)

The Vision Transformer is a sequence-to-sequence model that incorporates self-attention. It consists of two parts: feature extraction and classification.

In the feature extraction part, ViT extracts features from the input image. In the structure diagram, this part corresponds to the Patch + Position Embedding and Transformer Encoder blocks.

Patch + Position Embedding splits the incoming image into blocks: the image is divided into patches of a fixed size, and the patches are arranged into a sequence. Once this sequence is obtained, it is passed into the Transformer Encoder for feature extraction. The encoder is built around the Transformer's characteristic multi-head self-attention structure, which learns, through the self-attention mechanism, how important each image patch is.

In the classification part, ViT classifies the extracted features. During feature extraction, a class token is added to the image sequence; it is treated as one more element of the sequence and goes through feature extraction together with the patches. Along the way, the class token interacts with the other features and fuses information from the rest of the image sequence. Finally, the class token obtained after the multi-head self-attention stack is fed into a fully connected layer for classification.
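As a minimal sketch of this classification step (the tensor shapes and the 10-class head below are illustrative assumptions, not taken from the ViT source):

import torch
import torch.nn as nn

tokens = torch.randn(8, 901, 768)    # hypothetical encoder output: batch of 8, 900 patches + 1 class token, 768 features
cls_feature = tokens[:, 0]           # take the class token, shape (8, 768)
head = nn.Linear(768, 10)            # fully connected classification head, 10 classes assumed
logits = head(cls_feature)           # shape (8, 10)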

Vision Transformer source code link: ViT source code

The structure of ViT

The structure of the Vision Transformer is shown in the figure below.

 

Data processing

Let the input image have shape [H, W, C] and let the patch size be P. The number of patches is then (H \times W)/(P \times P), and each patch is flattened into a one-dimensional vector of length P \times P \times C.

For example, with an input image of shape (480, 480, 3) and a patch size of 16, the number of patches is (480 \times 480)/(16 \times 16) = 900, and each patch is flattened into a vector of length 16 \times 16 \times 3 = 768.

After serialization we therefore obtain a (900, 768) feature tensor, which becomes (901, 768) once the class token is added. Positional information is then added to every feature so that the network can distinguish different regions. The positional embedding pos_Embedding generated here also has shape (901, 768), one position vector per feature. Adding positional information simply means creating a learnable (901, 768) parameter matrix and adding it to the (901, 768) feature tensor.

The 0* in the figure is the class token.

self.pos_embed      = nn.Parameter(torch.zeros(1, num_patches, num_features))   # learnable positional embedding (the reference ViT uses num_patches + 1 entries so the class token is covered as well)
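The whole patching, class-token and position-embedding pipeline can be sanity-checked in a few lines. This is an illustrative sketch assuming the 480x480 input and patch size 16 used above; the Conv2d-based patching mirrors the PatchEmbed module shown later:

import torch
import torch.nn as nn

img = torch.randn(1, 3, 480, 480)                                             # (B, C, H, W)
proj = nn.Conv2d(3, 768, kernel_size=16, stride=16)                           # patching as a strided convolution
patches = proj(img).flatten(2).transpose(1, 2)                                # (1, 900, 768)

cls_token = nn.Parameter(torch.zeros(1, 1, 768))                              # class token
tokens = torch.cat([cls_token.expand(img.shape[0], -1, -1), patches], dim=1)  # (1, 901, 768)

pos_embed = nn.Parameter(torch.zeros(1, 901, 768))                            # learnable position embedding
tokens = tokens + pos_embed                                                   # still (1, 901, 768)
print(tokens.shape)                                                           # torch.Size([1, 901, 768])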

After the previous step we have a sequence of shape (901, 768). This sequence is passed into the Transformer Encoder for feature extraction, which is built around the Transformer's characteristic multi-head self-attention structure; through self-attention it learns how important each image patch is.

The most common tool for processing a sequence is the RNN: it takes a sequence of vectors as input and produces another sequence of vectors as output. In a single-directional RNN, by the time b_4 is produced, a_1, a_2, a_3 and a_4 have all been seen. In a bi-directional RNN, by the time any b_n is produced, a_1, a_2, a_3 and a_4 have all been seen. But RNNs have one problem: they are hard to parallelize. For example, in a single-directional RNN, computing b_4 requires processing a_1 first, then a_2, then a_3 and finally a_4, and this chain is very hard to parallelize.

Self-attention has the same input/output interface as an RNN: it takes one sequence and outputs another. Like a bi-directional RNN, every output b_1 to b_4 has seen the entire input sequence. Unlike an RNN, however, each of the outputs b_1 to b_4 can be computed in parallel.

self-attention

As shown in the figure above, let the input be the sequence x^1 to x^4. Each x^i is multiplied by a matrix W to obtain its embedding a^1 to a^4. These embeddings enter the self-attention layer, where each of a^1 to a^4 is multiplied by three different transformation matrices W^q, W^k and W^v; for example, a^1 yields q^1, k^1 and v^1.

Next, every query q performs attention with every key k; attention measures how closely the two vectors match. For example, performing attention between q^1 and k^1 means taking the scaled inner product of the two vectors, which gives a_{1,1}; q^1 with k^2 gives a_{1,2}, q^1 with k^3 gives a_{1,3}, and q^1 with k^4 gives a_{1,4}.

The scaled inner product is computed as

a_{1,i} = q^{1} \cdot k^{i}/\sqrt{d}

where d is the dimension of q and k. The value of q \cdot k grows with the dimension, so dividing by \sqrt{d} acts as a normalization.

Next, a softmax is applied over all of the computed a_{1,i}, giving \widehat{a_{1,i}}.

Each \widehat{a_{1,i}} is then multiplied by the corresponding v: \widehat{a_{1,1}} with v^1, \widehat{a_{1,2}} with v^2, \widehat{a_{1,3}} with v^3 and \widehat{a_{1,4}} with v^4, and the results are summed to obtain b_1. Producing b_1 therefore uses information from the whole sequence. To focus on local information, the network only has to learn the corresponding \widehat{a_{1,i}} = 0, so that b_1 carries no information from that branch; to use global information, it learns \widehat{a_{1,i}} \neq 0, so that b_1 carries information from every branch.
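The computation of b_1 described above fits in a few lines. The following is a minimal sketch with random vectors, where the dimension d = 64 and the sequence length 4 are just illustrative choices:

import torch

d = 64
q1 = torch.randn(d)                 # query of position 1
k  = torch.randn(4, d)              # keys k^1 .. k^4
v  = torch.randn(4, d)              # values v^1 .. v^4

a1     = (k @ q1) / d ** 0.5        # scaled inner products a_{1,1} .. a_{1,4}
a1_hat = a1.softmax(dim=0)          # softmax over the four scores
b1     = a1_hat @ v                 # weighted sum of the values -> b_1, shape (64,)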

b_2, b_3 and b_4 are computed in exactly the same way.

Through this chain of computation, the self-attention layer does the same job as an RNN, but it can be parallelized.

 

First, the input embeddings are collected into I = [a^1, a^2, a^3, a^4]. Multiplying I by W^q gives Q = [q^1, q^2, q^3, q^4], each column of which is a query vector q. Multiplying I by W^k gives K = [k^1, k^2, k^3, k^4], each column of which is a key vector k. Multiplying I by W^v gives V = [v^1, v^2, v^3, v^4], each column of which is a value vector v.

The input matrix I \in R^{d \times N} is multiplied by the three different matrices W^q, W^k and W^v to obtain three intermediate matrices Q, K, V \in R^{d \times N}. Transposing K and multiplying it with Q gives the attention matrix A \in R^{N \times N}, which holds the pairwise attention between every two positions. Applying softmax to it yields \widehat{A} \in R^{N \times N}, and finally multiplying by the matrix V gives the output matrix O \in R^{d \times N}.
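In matrix form the whole layer is just a handful of matrix multiplications. The sketch below follows the column-vector convention above (I, Q, K, V of shape (d, N)) and includes the \sqrt{d} scaling from the earlier formula; all matrices are random and purely illustrative:

import torch

d, N = 64, 4
I  = torch.randn(d, N)                            # input embeddings, one column per position
Wq, Wk, Wv = (torch.randn(d, d) for _ in range(3))

Q, K, V = Wq @ I, Wk @ I, Wv @ I                  # (d, N) each
A      = (K.t() @ Q) / d ** 0.5                   # (N, N) attention matrix
A_hat  = A.softmax(dim=0)                         # softmax over each column
O      = V @ A_hat                                # (d, N) output, one column per position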

 

 Multi-head Self-attention

Take 2 heads as an example. The q^i generated from a^i is further multiplied by two transformation matrices to obtain q^{i,1} and q^{i,2}; likewise, k^i is multiplied by two transformation matrices to obtain k^{i,1} and k^{i,2}, and v^i to obtain v^{i,1} and v^{i,2}. Then q^{i,1} performs attention with the first-head keys k^{i,1}, producing the weights \alpha of a weighted sum, which is taken over the first-head values v^{i,1} to obtain b^{i,1} (i = 1, 2, ..., N); b^{i,2} (i = 1, 2, ..., N) is obtained in the same way. We now have b^{i,1} (i = 1, 2, ..., N) \in R^{d \times 1} and b^{i,2} (i = 1, 2, ..., N) \in R^{d \times 1}; concatenating them and passing the result through a transformation matrix adjusts the dimension so that it matches b^i (i = 1, 2, ..., N) \in R^{d \times 1}.
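A minimal sketch of the two-head case, again with random matrices and illustrative dimensions; here each head works in a halved dimension d_head, and Wo plays the role of the final transformation matrix that restores the original dimension:

import torch

d, N   = 64, 4
d_head = d // 2
a = torch.randn(d, N)                                         # embeddings a^1 .. a^N, one per column

def head(Wq, Wk, Wv):
    # one self-attention head, same computation as before but in d_head dimensions
    Q, K, V = Wq @ a, Wk @ a, Wv @ a                          # (d_head, N) each
    A_hat = ((K.t() @ Q) / d_head ** 0.5).softmax(dim=0)
    return V @ A_hat                                          # (d_head, N)

b1 = head(*(torch.randn(d_head, d) for _ in range(3)))        # head 1 outputs b^{i,1}
b2 = head(*(torch.randn(d_head, d) for _ in range(3)))        # head 2 outputs b^{i,2}

Wo = torch.randn(d, 2 * d_head)                               # output transformation matrix
b  = Wo @ torch.cat([b1, b2], dim=0)                          # (d, N), back to the original dimension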

 

Multi-Head Attention contains multiple Self-Attention layers. The input X is first passed into 2 different Self-Attention modules, producing 2 output matrices. Multi-Head Attention then concatenates them (Concat) and feeds the result through a Linear layer to obtain the final output Z. Note that the output matrix Z of Multi-Head Attention has the same dimensions as the input matrix X.

For example, after the patching step we obtained a (901, 768) feature tensor. To apply multiple heads, we simply split its last dimension: with 12 heads the shape becomes (901, 12, 64). We then transpose (901, 12, 64) to move the 12 to the front, giving a (12, 901, 64) tensor. From here on the 12 is ignored and treated exactly like the batch dimension, and the attention computation only operates on the (901, 64) part. That is the whole attention procedure.
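A quick shape walk-through of this head splitting (901 tokens, 12 heads, 64 dimensions per head, batch dimension omitted as in the text):

import torch

x = torch.randn(901, 768)                    # the (901, 768) sequence from the patching step
num_heads, head_dim = 12, 768 // 12          # 12 heads of 64 dimensions each

x = x.reshape(901, num_heads, head_dim)      # (901, 12, 64): split the last dimension
x = x.permute(1, 0, 2)                       # (12, 901, 64): move the head dimension to the front
print(x.shape)                               # torch.Size([12, 901, 64])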

The code of the attention mechanism is as follows:

class Attention(nn.Module):
    def __init__(self, dim, num_heads=12, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads  = num_heads
        self.scale      = (dim // num_heads) ** -0.5

        self.qkv        = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop  = nn.Dropout(attn_drop)
        self.proj       = nn.Linear(dim, dim)
        self.proj_drop  = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C     = x.shape
        # (B, N, C) -> (B, N, 3, num_heads, C // num_heads) -> (3, B, num_heads, N, C // num_heads)
        qkv         = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v     = qkv[0], qkv[1], qkv[2]

        # attention scores between every pair of positions: (B, num_heads, N, N)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # weighted sum of the values, then merge the heads back: (B, N, C)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
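A quick sanity check of the Attention module above (assuming torch and torch.nn are imported as in the full listing below); the output keeps the input shape:

attn = Attention(dim=768, num_heads=12)
x = torch.randn(2, 901, 768)          # (batch, sequence length, channels)
print(attn(x).shape)                  # torch.Size([2, 901, 768])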

Once the multi-head self-attention is built, we add two fully connected layers after it; together these form the complete Transformer Block.

class Mlp(nn.Module):
    """ MLP as used in Vision Transformer, MLP-Mixer and related networks
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=GELU, drop=0.):
        super().__init__()
        out_features    = out_features or in_features
        hidden_features = hidden_features or in_features
        drop_probs      = (drop, drop)

        self.fc1    = nn.Linear(in_features, hidden_features)
        self.act    = act_layer()
        self.drop1  = nn.Dropout(drop_probs[0])
        self.fc2    = nn.Linear(hidden_features, out_features)
        self.drop2  = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x
class Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1      = norm_layer(dim)
        self.attn       = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
        self.norm2      = norm_layer(dim)
        self.mlp        = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
        self.drop_path  = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        
    def forward(self, x):
        #-----------------------------------------------#
        #   Layer normalization, multi-head attention, residual connection
        #-----------------------------------------------#
        x = x + self.drop_path(self.attn(self.norm1(x)))
        #-----------------------------------------------#
        #   Layer normalization, two fully connected layers, residual connection
        #-----------------------------------------------#
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
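Thanks to the two residual connections, a single Block also preserves the sequence shape. A quick check (this relies on the GELU and DropPath helpers defined in the full listing below):

block = Block(dim=768, num_heads=12)
x = torch.randn(2, 901, 768)
print(block(x).shape)                 # torch.Size([2, 901, 768])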

The full ViT model consists of one Patch + Position Embedding stage followed by a stack of Transformer Blocks; a typical configuration uses 12 Transformer Blocks.

        self.blocks = nn.Sequential(
            *[
                Block(
                    dim         = num_features, 
                    num_heads   = num_heads, 
                    mlp_ratio   = mlp_ratio, 
                    qkv_bias    = qkv_bias, 
                    drop        = drop_rate,
                    attn_drop   = attn_drop_rate, 
                    drop_path   = dpr[i], 
                    norm_layer  = norm_layer, 
                    act_layer   = act_layer
                ) for i in range(depth)  # depth == 12
            ]
        )

Full code for building the Vision Transformer

import math
from collections import OrderedDict
from functools import partial

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

#--------------------------------------#
#   Implementation of the GELU activation
#   function using an approximate formula
#--------------------------------------#
class GELU(nn.Module):
    def __init__(self):
        super(GELU, self).__init__()

    def forward(self, x):
        return 0.5 * x * (1 + torch.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * torch.pow(x,3))))

def drop_path(x, drop_prob: float = 0., training: bool = False):
    if drop_prob == 0. or not training:
        return x
    keep_prob       = 1 - drop_prob
    shape           = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor   = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_() 
    output          = x.div(keep_prob) * random_tensor
    return output

class DropPath(nn.Module):
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)

class PatchEmbed(nn.Module):
    def __init__(self, input_shape=[480, 480], patch_size=16, in_chans=3, num_features=768, norm_layer=None, flatten=True):
        super().__init__()
        self.num_patches    = (input_shape[0] // patch_size) * (input_shape[1] // patch_size)
        self.flatten        = flatten

        self.proj = nn.Conv2d(in_chans, num_features, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(num_features) if norm_layer else nn.Identity()

    def forward(self, x):
        x = self.proj(x)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC
        x = self.norm(x)
        return x

#--------------------------------------------------------------------------------------------------------------------#
#   Attention mechanism
#   The input features are projected and split into query, key and value: query is the query vector, key is the key
#   vector and value is the value vector.
#   The query vectors are then multiplied with the transposed key vectors; intuitively, the queries look up the
#   sequence features and obtain an importance score for every part of the sequence.
#   The scores are then multiplied with the values; intuitively, this re-applies the importance of every part of the
#   sequence back onto the values.
#--------------------------------------------------------------------------------------------------------------------#
class Attention(nn.Module):
    def __init__(self, dim, num_heads=12, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads  = num_heads
        self.scale      = (dim // num_heads) ** -0.5

        self.qkv        = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop  = nn.Dropout(attn_drop)
        self.proj       = nn.Linear(dim, dim)
        self.proj_drop  = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C     = x.shape
        qkv         = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v     = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

class Mlp(nn.Module):
    """ MLP as used in Vision Transformer, MLP-Mixer and related networks
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=GELU, drop=0.):
        super().__init__()
        out_features    = out_features or in_features
        hidden_features = hidden_features or in_features
        drop_probs      = (drop, drop)

        self.fc1    = nn.Linear(in_features, hidden_features)
        self.act    = act_layer()
        self.drop1  = nn.Dropout(drop_probs[0])
        self.fc2    = nn.Linear(hidden_features, out_features)
        self.drop2  = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x

class Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1      = norm_layer(dim)
        self.attn       = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
        self.norm2      = norm_layer(dim)
        self.mlp        = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
        self.drop_path  = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        
    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
class Conv2dReLU(nn.Sequential):
    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            padding=0,
            stride=1,
            use_batchnorm=True,
    ):
        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            bias=not (use_batchnorm),
        )
        relu = nn.ReLU(inplace=True)

        bn = nn.BatchNorm2d(out_channels)

        super(Conv2dReLU, self).__init__(conv, bn, relu)

class VisionTransformer(nn.Module):
    def __init__(
            self, input_shape=[480, 480], patch_size=16, in_chans=3, num_classes=2, num_features=768,
            depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0.1, attn_drop_rate=0.1, drop_path_rate=0.1,
            norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=GELU
        ):
        super().__init__()
        #-----------------------------------------------#
        #   480, 480, 3 -> 900, 768
        #-----------------------------------------------#
        self.patch_embed    = PatchEmbed(input_shape=input_shape, patch_size=patch_size, in_chans=in_chans, num_features=num_features)
        num_patches         = (480 // patch_size) * (480 // patch_size)
        self.num_features   = num_features
        self.new_feature_shape = [int(input_shape[0] // patch_size), int(input_shape[1] // patch_size)]
        self.old_feature_shape = [int(480 // patch_size), int(480 // patch_size)]

        #--------------------------------------------------------------------------------------------------------------------#
        #   The class token is the transformer's classification feature. It is stacked onto the serialized image
        #   features and goes through feature extraction as one more element of the sequence.
        #
        #   After a convolution with stride 16x16 divides the input image into a 30x30 grid, the 30x30 features are
        #   flattened, so one image yields a sequence of length 900.
        #   A class token is then created and stacked onto this length-900 sequence, giving a sequence of length 901.
        #   During feature extraction the class token interacts with the image features. For the final classification,
        #   the class token's feature is taken out and classified with a fully connected layer.
        #--------------------------------------------------------------------------------------------------------------------#
        #   900, 768 -> 901, 768
        self.cls_token      = nn.Parameter(torch.zeros(1, 1, num_features))
        #--------------------------------------------------------------------------------------------------------------------#
        #   Add positional information to the features extracted by the network.
        #   With a 480, 480, 3 input image, the serialized image features are 900, 768; adding the class token would
        #   give 901, 768, and the positional embedding would then also have shape 901, 768, one position vector per
        #   feature. Note that in this implementation pos_embed is created with num_patches (i.e. 900) entries, since
        #   the class token is not concatenated in forward_features below.
        #--------------------------------------------------------------------------------------------------------------------#
        #   900, 768 -> 900, 768
        self.pos_embed      = nn.Parameter(torch.zeros(1, num_patches, num_features))
        self.pos_drop       = nn.Dropout(p=drop_rate)

        #-----------------------------------------------#
        #   900, 768 -> 900, 768, repeated 12 times
        #-----------------------------------------------#
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
        self.blocks = nn.Sequential(
            *[
                Block(
                    dim         = num_features, 
                    num_heads   = num_heads, 
                    mlp_ratio   = mlp_ratio, 
                    qkv_bias    = qkv_bias, 
                    drop        = drop_rate,
                    attn_drop   = attn_drop_rate, 
                    drop_path   = dpr[i], 
                    norm_layer  = norm_layer, 
                    act_layer   = act_layer
                )for i in range(depth)
            ]
        )
        self.norm = norm_layer(num_features)
        self.head = nn.Linear(num_features, num_classes) if num_classes > 0 else nn.Identity()
        n_patches = (480 // 16) * (480 // 16)

        self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches, 768))
        self.conv_more = Conv2dReLU(
            768,
            2048,
            kernel_size=3,
            padding=1,
            use_batchnorm=True,
        )
    def forward_features(self, x):
        x = self.patch_embed(x)
        # cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        # x = torch.cat((cls_token, x), dim=1)
        #
        # cls_token_pe = self.pos_embed[:, 0:1, :]
        # img_token_pe = self.pos_embed[:, 1: , :]
        # print(img_token_pe.shape)
        #
        # img_token_pe = img_token_pe.view(1, *self.old_feature_shape, -1).permute(0, 3, 1, 2)
        # print(img_token_pe.shape)
        # img_token_pe = F.interpolate(img_token_pe, size=self.new_feature_shape, mode='bicubic', align_corners=False)
        # print(img_token_pe.shape)
        # img_token_pe = img_token_pe.permute(0, 2, 3, 1).flatten(2)
        # print(img_token_pe.shape)
        # pos_embed = torch.cat([cls_token_pe, img_token_pe], dim=1)
        # x = x.flatten(2)
        # x = x.transpose(-1, -2)
        x = x + self.position_embeddings
        x = self.pos_drop(x)
        x = self.blocks(x)
        x = self.norm(x)
        return x

    def forward1(self, hidden_states):
        B, n_patch, hidden = hidden_states.size()  # reshape from (B, n_patch, hidden) to (B, hidden, h, w)
        h, w = int(np.sqrt(n_patch)), int(np.sqrt(n_patch))
        x = hidden_states.permute(0, 2, 1)
        x = x.contiguous().view(B, hidden, h, w)
        x = self.conv_more(x)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.forward1(x)
        # print(x.shape)
        return x

    def freeze_backbone(self):
        backbone = [self.patch_embed, self.cls_token, self.pos_embed, self.pos_drop, self.blocks[:8]]
        for module in backbone:
            try:
                for param in module.parameters():
                    param.requires_grad = False
            except:
                module.requires_grad = False

    def Unfreeze_backbone(self):
        backbone = [self.patch_embed, self.cls_token, self.pos_embed, self.pos_drop, self.blocks[:8]]
        for module in backbone:
            try:
                for param in module.parameters():
                    param.requires_grad = True
            except:
                module.requires_grad = True

    
def vit(input_shape=[480, 480], pretrained=False, num_classes=2):
    model = VisionTransformer(input_shape)
    if pretrained:
        model.load_state_dict(torch.load("model_data/vit-patch_16.pth"))

    if num_classes!=1000:
        model.head = nn.Linear(model.num_features, num_classes)
    return model
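A minimal usage example of the model exactly as written above. Note that, with the class-token branch commented out in forward_features and forward ending in conv_more, this particular version returns a 2048-channel feature map rather than classification logits:

import torch

model = vit(input_shape=[480, 480], pretrained=False, num_classes=2)
model.eval()                          # disable dropout / drop path for the shape check
x = torch.randn(1, 3, 480, 480)
with torch.no_grad():
    y = model(x)
print(y.shape)                        # torch.Size([1, 2048, 30, 30]): a 30x30 feature map from conv_more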
