Introduction to the Basic Structure of the Vision Transformer (ViT)

Original paper: https://arxiv.org/pdf/2010.11929.pdf

Chinese-language Zhihu write-up: [论文笔记] ViT - 知乎

Basic idea of the network:

The Transformer unifies the network structures used in NLP and CV. An NLP network takes a sequence as input and uses an Encoder and a Decoder to encode and decode the sequence information. A CV network starts from an image (assume a size of 224×224×3) and uses convolutions and an MLP to extract image information and produce a prediction (assuming a classification network). Applying the Transformer to images essentially means replacing the original convolutional layers (more precisely, replacing the Conv with an attention structure) to encode the image information.

The input of an attention layer is usually of shape N×T×C, where N is the batch size, T is the sequence length, and C is the feature dimension. ViT therefore slices the large image into 16×16 patches, converting N×224×224×3 into a sequence of shape N × (14×14) × (16×16×3): the 14×14 = 196 patches play the role of tokens, and the 16×16×3 = 768 values inside each patch are its features.

Basic network structure (figure taken from the original paper)

Main steps:

1) Patch splitting and linear projection, done either with Rearrange + Linear or with a Conv whose kernel size and stride equal the patch size. The image is sliced into 16×16 patches, and the linear projection maps each flattened patch to a dim-dimensional vector (dim is usually the same size as patch_dim). Output shape: N × 196 × 768. (The two variants are equivalent; see the sketch after the snippet below.)

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.Linear(patch_dim, dim),
        )
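
The Rearrange + Linear embedding above and the strided-Conv variant mentioned in step 1 are equivalent; here is a minimal sketch (not from the original post; the hyperparameters assume ViT-Base: 224×224 input, 16×16 patches, dim = 768) checking that both produce the same N × 196 × 768 sequence:

import torch
import torch.nn as nn
from einops.layers.torch import Rearrange

patch_size, dim = 16, 768
img = torch.randn(2, 3, 224, 224)            # N x C x H x W

# (a) Rearrange + Linear, as in the snippet above
to_patch_rearrange = nn.Sequential(
    Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
    nn.Linear(3 * patch_size * patch_size, dim),
)

# (b) Conv with kernel size = stride = patch size, then flatten the 14x14 grid into a sequence
to_patch_conv = nn.Sequential(
    nn.Conv2d(3, dim, kernel_size = patch_size, stride = patch_size),   # N x dim x 14 x 14
    Rearrange('b d h w -> b (h w) d'),
)

print(to_patch_rearrange(img).shape)   # torch.Size([2, 196, 768])
print(to_patch_conv(img).shape)        # torch.Size([2, 196, 768])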

2) Add the cls token and the positional embedding; both are learnable parameters. (The cls token borrows Bert's idea: it is treated as aggregating global information from all tokens, and the final classification is made from the cls token. The paper also notes that the cls token can be dropped and replaced by average pooling with the same effect.)

        # cls_token has shape (1, 1, dim); after repeat, cls_tokens has shape (b, 1, dim)
        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        x = torch.cat((cls_tokens, x), dim=1)
        # pos_embedding has shape (1, 14*14 + 1, dim): one position per patch plus the cls token
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

3) The Transformer structure encodes and extracts the image information, i.e., an Attention block + FFN;

Basic attention formula:

Attention(Q, K, V) = softmax(Q·Kᵀ / √d_k) · V

Why multi-head attention (and why each head is projected to a lower dimension): 【深度学习笔记】为什么transformer(Bert)的多头注意力要对每一个head进行降维? – Sniper

Basic FFN formula:

FFN(x) = max(0, x·W₁ + b₁)·W₂ + b₂   (the original Transformer uses ReLU; ViT uses GELU, as in the FeedForward code below)
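
A minimal sketch of the two formulas above on ViT-Base-shaped tensors (single head, random weights, purely illustrative; the dimensions 197 × 768 and mlp_dim = 3072 are assumptions matching ViT-Base):

import math
import torch
import torch.nn.functional as F

x = torch.randn(2, 197, 768)                           # N x T x C (196 patches + cls token)
Wq, Wk, Wv = (torch.randn(768, 768) for _ in range(3))

# Attention(Q, K, V) = softmax(Q·K^T / sqrt(d_k)) · V
q, k, v = x @ Wq, x @ Wk, x @ Wv
attn = torch.softmax(q @ k.transpose(-1, -2) / math.sqrt(768), dim = -1)
out = attn @ v
print(out.shape)                                       # torch.Size([2, 197, 768])

# FFN: expand to mlp_dim, apply the nonlinearity, project back (ViT uses GELU)
W1, W2 = torch.randn(768, 3072), torch.randn(3072, 768)
ffn = F.gelu(out @ W1) @ W2
print(ffn.shape)                                       # torch.Size([2, 197, 768])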

Further reading:

详解Transformer (Attention Is All You Need) - 知乎

https://arxiv.org/pdf/1706.03762.pdf

The Illustrated Transformer – Jay Alammar – Visualizing machine learning one concept at a time.

详解Transformer中Self-Attention以及Multi-Head Attention_霹雳吧啦Wz-CSDN博客_multi-head self-attention

4) The MLP head produces the classification result.

Code:

1) Network code

import torch
from torch import nn
from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helper: turn a single int into an (h, w) pair
def pair(t):
    return t if isinstance(t, tuple) else (t, t)

class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        # Rearrange slices the image into patches of the given size;
        # the Linear then projects each flattened patch to the dim-dimensional embedding.
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.Linear(patch_dim, dim),
        )
        # learnable pos_embedding and cls_token
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )
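
    # Below is a minimal sketch of a matching forward pass (not part of the original
    # snippet); it wires the modules above together as described in steps 1-4.
    def forward(self, img):
        # 1) patchify + linear projection: (N, 3, 224, 224) -> (N, 196, dim)
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        # 2) prepend the cls token and add the positional embedding
        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        x = torch.cat((cls_tokens, x), dim = 1)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        # 3) Transformer encoder (Attention + FFN blocks)
        x = self.transformer(x)

        # 4) mean-pool all tokens or take the cls token, then classify with the MLP head
        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
        x = self.to_latent(x)
        return self.mlp_head(x)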

2) Transformer building-block code

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        # depth stacked blocks, each a pre-norm Attention followed by a pre-norm FeedForward
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
            ]))
    def forward(self, x):
        for attn, ff in self.layers:
            # residual connection around each sub-block
            x = attn(x) + x
            x = ff(x) + x
        return x

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        # scale factor: divide by sqrt(dim_head) so the Q·K^T products do not get too large,
        # which would push softmax into regions with vanishing gradients
        self.scale = dim_head ** -0.5
        self.attend = nn.Softmax(dim = -1)
        # the input is multiplied by the Wq, Wk, Wv weight matrices to produce Q, K, V
        # (fused here into a single linear layer)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()
    def forward(self, x):
        # project to Q, K, V in one linear layer, then split into heads: (b, n, h*d) -> (b, h, n, d)
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        # scaled dot-product attention scores: (b, h, n, n)
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn = self.attend(dots)

        # weighted sum of the values, then merge the heads back: (b, h, n, d) -> (b, n, h*d)
        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
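
A quick usage sketch of the complete model (not from the original post; the hyperparameter values below are chosen for illustration, roughly ViT-Base, and it assumes the forward pass sketched above):

model = ViT(
    image_size = 224,
    patch_size = 16,
    num_classes = 1000,
    dim = 768,
    depth = 12,
    heads = 12,
    mlp_dim = 3072,
    dim_head = 64,
    dropout = 0.1,
    emb_dropout = 0.1,
)

img = torch.randn(2, 3, 224, 224)   # N x C x H x W
logits = model(img)
print(logits.shape)                 # torch.Size([2, 1000])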

References:

Vision Transformer详解_霹雳吧啦Wz-CSDN博客_vit详解

ViT(vision transformer)原理快速入门_不精,不诚,不足以动人-CSDN博客
