【论文笔记】 VIT论文笔记,重构Patch Embedding和Attention部分

0 前言

相关链接:

  1. VIT论文:https://arxiv.org/abs/2010.11929
  2. VIT视频讲解:https://www.bilibili.com/video/BV15P4y137jb/?spm_id_from=333.999.0.0&vd_source=fff489d443210a81a8f273d768e44c30
  3. VIT源码:https://github.com/vitejs/vite
  4. VIT源码(Pytorch版本,非官方,挺多stars,应该问题不大):https://github.com/lucidrains/vit-pytorch

重点掌握:

  1. 如何将2-D的图像变为1-D的序列,操作:PatchEmbedding,并且加上learnbale embedding 和 Position Embedding
  2. Multi-Head Attention的写法,其中里面有2个Linear层进行维度变换~

VIT历史意义: 展示了在CV中使用纯Transformer结构的可能,并开启了视觉Transformer研究热潮。

1 总体代码

说明: 本文代码是针对VIT的Pytorch版本进行重构修改,若有不对的地方,欢迎交流~
原因: lucidrains的源码中调用了比较高级的封装,如einops包中的rerange等函数,写的确实挺好的,但不好理解shape的变化;
在这里插入图片描述

import torch
import torch.nn as nn

class PatchAndPosEmbedding(nn.Module):
    def __init__(self, img_size=256, patch_size=32, in_channels=3, embed_dim=1024, drop_out=0.):
        super(PatchAndPosEmbedding, self).__init__()

        num_patches = int((img_size/patch_size)**2)
        patch_size_dim = patch_size*patch_size*in_channels

        # patch_embedding, Note: kernel_size, stride
        # a
        self.patch_embedding = nn.Conv2d(in_channels=in_channels, out_channels=patch_size_dim, kernel_size=patch_size, stride=patch_size)
        self.linear = nn.Linear(patch_size_dim, embed_dim)

        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))   # 添加一个cls_token用于整合信息
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches+1, embed_dim)) # 给patch embedding加上位置信息

        self.dropout = nn.Dropout(drop_out)

    def forward(self, img):
        x = self.patch_embedding(img) # [B,C,H,W] -> [B, patch_size_dim, N, N] # N = Num_patches = (H*W)/Patch_size,
        x = x.flatten(2)
        x = x.transpose(2, 1)  # [B,N*N, patch_size_dim]
        x = self.linear(x)     # [B,N*N, embed_dim]  # patch_size_dim -> embed_dim = 3072->1024 to reduce the computation when encode.

        cls_token = self.cls_token.expand(x.shape[0], -1, -1)  # cls_token: [1,1 embed_dim] -> [B, 1, embed_dim]
        x = torch.cat([cls_token, x], dim=1) # [B,N*N, embed_dim] -> [B, N*N+1, embed_dim]
        x += self.pos_embedding  # [B, N*N+1, embed_dim]  Consider why not concat , but add?  Trade off due to the computation.

        out = self.dropout(x)

        return out

class Attention(nn.Module):
    def __init__(self, dim, heads=16, head_dim=64, dropout=0.):
        super(Attention, self).__init__()
        inner_dim = heads * head_dim  # 可以通过FC将 input_dim 映射到inner_dim作为注意力表示维度
        self.heads = heads
        self.scale = head_dim ** -0.5

        project_out = not (heads == 1 and head_dim == dim)

        # 构建 k,q,v,可根据VIT原项目中的rerange进行变化
        # 写法一:直接定义to_q, to_k, to_v
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_k = nn.Linear(dim, inner_dim, bias = False)
        self.to_v = nn.Linear(dim, inner_dim, bias = False)

        # 写法二:先定义qkv,在forward进行chunk拆开
        # self.to_qkv = nn.Linear(dim, inner_dim*3, bias = False)

        self.atten = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)

        self.to_out = nn.Sequential(nn.Linear(dim, inner_dim), nn.Dropout(dropout)) if project_out else nn.Identity()

    def forward(self, x):
        # 续上面写法1:
        q = self.to_q(x)  # [3,65,1024]
        k = self.to_k(x)  # [3,65,1024]
        v = self.to_v(x)  # [3,65,1024]

        # 续上面写法2:
        # toqkv = self.to_qkv(x)  # [3, 65, 3072]
        # q, k, v = toqkv.chunk(3, dim=-1)  # q, k, v.shape    [3,65,1024]

        q = q.reshape(q.shape[0], q.shape[1], self.heads, -1).transpose(1, 2)  # [3,65,1024]  -> [3,16,65,64]
        k = k.reshape(k.shape[0], k.shape[1], self.heads, -1).transpose(1, 2)  # [3,65,1024]  -> [3,16,65,64]
        v = v.reshape(v.shape[0], v.shape[1], self.heads, -1).transpose(1, 2)  # [3,65,1024]  -> [3,16,65,64]


        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        atten = self.atten(dots)
        atten = self.dropout(atten)

        out = torch.matmul(atten, v)    # [3, 16, 65, 64]
        out = out.transpose(1, 2)   #
        out = out.reshape(out.shape[0], out.shape[1], -1)   # [3, 65, 16*64]

        return self.to_out(out)

class MLP(nn.Module):  # 搭建2层FC, 使用GELU激活
    def __init__(self, dim, hidden_dim, dropout=0.):
        super(MLP, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.GELU(),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

class PreNorm(nn.Module):  # Encoder结构中先LayerNorm再进行Multihead-attention或MLP
    def __init__(self, dim, fn):
        super(PreNorm, self).__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

class Transformer(nn.Module):  # 整个Encoder结构
    def __init__(self, dim, depth, heads, head_dim, mlp_hidden_dim, dropout=0.):
        super(Transformer, self).__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList([
                    PreNorm(dim, Attention(dim, heads, head_dim, dropout=dropout)),
                    PreNorm(dim, MLP(dim, mlp_hidden_dim, dropout=dropout))
                ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

class VIT(nn.Module):
    def __init__(self, num_classes=10, img_size=256, patch_size=32, in_channels=3,
                 embed_dim=1024, depth=6, heads=16, head_dim=64, mlp_hidden_dim=2048, pool='cls', dropout=0.1):
        super(VIT, self).__init__()

        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
        self.pool = pool

        self.patchembedding = PatchAndPosEmbedding(img_size, patch_size, in_channels, embed_dim, dropout)

        self.transformer = Transformer(embed_dim, depth, heads, head_dim, mlp_hidden_dim, dropout)

        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, num_classes)
        )
    def forward(self, x):
        x = self.patchembedding(x)
        x = self.transformer(x)

        x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0]

        out = self.mlp_head(x)

        return out


net = VIT()
x = torch.randn(3, 3, 256, 256)
out = net(x)
print(out, out.shape)

2 PatchandPosEmbedding

说明:
1.将256x256x3的图像分为32x32x3大小的patches,主要使用nn.Conv2d实现,主要是ernel_size==patch_sizestride==patch_size, 多看代码就能理解这个图了;
2.由于图像切分重排后失去了位置信息,并且Transformer的运算是与空间位置无关的,因此需要把位置信息编码放进网络,使用一个向量进行编码,即PosEmbedding;

问题:
1. 为什么要在Embedding时加上一个patch0,即代码中的cls_tocken?
原因:假设原始输出的9个向量(以图中假设),若随机选择其中一个用于分类,效果都不好。若全用的话,计算量太大;因此加上一个可学习的向量,即learnable embedding用于整合信息。
2.为什么Position Embedding是直接add,而不是concat?
原因:实际上add是concat的一种特例,而concat容易造成维度太大导致计算量爆炸,实际上,该部分的add是对计算量的一种妥协,但在论文中的Appendix部分可以看出,这种方法的定位效果还是不错的。
在这里插入图片描述
在这里插入图片描述

class PatchAndPosEmbedding(nn.Module):
    def __init__(self, img_size=256, patch_size=32, in_channels=3, embed_dim=1024, drop_out=0.):
        super(PatchAndPosEmbedding, self).__init__()

        num_patches = int((img_size/patch_size)**2)
        patch_size_dim = patch_size*patch_size*in_channels

        # patch_embedding, Note: kernel_size, stride
        # a
        self.patch_embedding = nn.Conv2d(in_channels=in_channels, out_channels=patch_size_dim, kernel_size=patch_size, stride=patch_size)
        self.linear = nn.Linear(patch_size_dim, embed_dim)

        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))   # 添加一个cls_token用于整合信息
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches+1, embed_dim)) # 给patch embedding加上位置信息

        self.dropout = nn.Dropout(drop_out)

    def forward(self, img):
        x = self.patch_embedding(img) # [B,C,H,W] -> [B, patch_size_dim, N, N] # N = Num_patches = (H*W)/Patch_size,
        x = x.flatten(2)
        x = x.transpose(2, 1)  # [B,N*N, patch_size_dim]
        x = self.linear(x)     # [B,N*N, embed_dim]  # patch_size_dim -> embed_dim = 3072->1024 to reduce the computation when encode.

        cls_token = self.cls_token.expand(x.shape[0], -1, -1)  # cls_token: [1,1 embed_dim] -> [B, 1, embed_dim]
        x = torch.cat([cls_token, x], dim=1) # [B,N*N, embed_dim] -> [B, N*N+1, embed_dim]
        x += self.pos_embedding  # [B, N*N+1, embed_dim]  Consider why not concat , but add?  Trade off due to the computation.

        out = self.dropout(x)

        return out

3. Attention

实现Attention机制,需要Q(Query),K(Key),V(Value)三个元素对注意力进行计算,实际上是对各个patches之间计算注意力值,公式为
在这里插入图片描述


class Attention(nn.Module):
    def __init__(self, dim, heads=16, head_dim=64, dropout=0.):
        super(Attention, self).__init__()
        inner_dim = heads * head_dim  # 可以通过FC将 input_dim 映射到inner_dim作为注意力表示维度
        self.heads = heads
        self.scale = head_dim ** -0.5

        project_out = not (heads == 1 and head_dim == dim)

        # 构建 k,q,v,可根据VIT原项目中的rerange进行变化
        # 写法一:直接定义to_q, to_k, to_v
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_k = nn.Linear(dim, inner_dim, bias = False)
        self.to_v = nn.Linear(dim, inner_dim, bias = False)

        # 写法二:先定义qkv,在forward进行chunk拆开
        # self.to_qkv = nn.Linear(dim, inner_dim*3, bias = False)

        self.atten = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)

        self.to_out = nn.Sequential(nn.Linear(dim, inner_dim), nn.Dropout(dropout)) if project_out else nn.Identity()

    def forward(self, x):
        # 续上面写法1:
        q = self.to_q(x)  # [3,65,1024]
        k = self.to_k(x)  # [3,65,1024]
        v = self.to_v(x)  # [3,65,1024]

        # 续上面写法2:
        # toqkv = self.to_qkv(x)  # [3, 65, 3072]
        # q, k, v = toqkv.chunk(3, dim=-1)  # q, k, v.shape    [3,65,1024]

        q = q.reshape(q.shape[0], q.shape[1], self.heads, -1).transpose(1, 2)  # [3,65,1024]  -> [3,16,65,64]
        k = k.reshape(k.shape[0], k.shape[1], self.heads, -1).transpose(1, 2)  # [3,65,1024]  -> [3,16,65,64]
        v = v.reshape(v.shape[0], v.shape[1], self.heads, -1).transpose(1, 2)  # [3,65,1024]  -> [3,16,65,64]


        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        atten = self.atten(dots)
        atten = self.dropout(atten)

        out = torch.matmul(atten, v)    # [3, 16, 65, 64]
        out = out.transpose(1, 2)   #
        out = out.reshape(out.shape[0], out.shape[1], -1)   # [3, 65, 16*64]

        return self.to_out(out)
<think>好的,我现在需要理解VIT(Vision Transformer)中的Embedding层里patch的概念。首先,VIT是用于计算机视觉的Transformer模型,所以应该自然语言处理中的Transformer有相似之处,但处理的是图像数据。 我记得在NLP中,Transformer的输入是词嵌入,也就是将单词转换为向量。而图像是像素组成的,如何将图像转换成类似的向量序列呢?这时候可能就需要用到patch的概念。Patch在中文里是“补丁”或“小块”的意思,可能指的是将图像分割成小块,每个小块作为一个处理单元。 接下来,我需要详细一点。比如,假设有一张224x224的RGB图像,如果分成16x16的patch,那么每个patch的大小就是16x16x3(长、宽、通道)。这样总共有多少个patch呢?224除以16是14,所以总共有14x14=196个patch。每个patch会被展平成一个向量,16x16x3=768,所以每个patch变成一个768维的向量。这可能就是嵌入层的输入。 然后,这些展平的向量会经过线性投影(Linear Projection)生成嵌入向量,再加上位置编码(Position Embedding),因为Transformer本身没有位置信息,需要额外添加位置信息。这样每个patch就类似于NLP中的一个词,整个图像就是一个由patch组成的序列。 可能还需要考虑为什么使用patch而不是单个像素。因为单个像素的信息量太小,而较大的patch可以捕捉局部特征,同时减少序列长度。比如,224x224的图像,如果每个patch是16x16,序列长度是196,这比处理每个像素(50176个)要高效得多。 另外,VITEmbedding层可能还包括可学习的分类token(类似BERT的[CLS]),用于最终的分类任务。这个token会被添加到patch序列的最前面,经过Transformer编码后,用这个token的输出进行分类。 需要注意的是,不同的VIT变种可能使用不同的patch大小,比如更小的patch可能捕获更细粒度的特征,但会增加计算量。此外,位置编码的方式也可能不同,比如使用可学习的位置向量或者固定的正弦编码。 总结一下,patchVIT中的作用是将图像分割成小块,每个小块转换为向量,作为Transformer的输入序列,类似于NLP中的词嵌入。这样可以利用Transformer处理序列数据的优势来处理图像,同时通过局部区域的聚合平衡计算复杂度特征提取能力。 可能存在的疑问点:如何确定patch的大小?位置编码的具体实现方式?如何处理不同分辨率的图像?这些都是需要进一步了解的,但用户的问题主要是理解patch的概念,所以重点应放在分割图像生成嵌入向量的过程上。</think>在 Vision Transformer (ViT) 中,**Patch** 是理解模型如何处理图像的关键概念。以下是逐步解释: --- ### 1. **直观理解:将图像拆解为“拼图块”** - **图像的本质**:一张 $H \times W \times C$ 的图像(如 $224 \times 224 \times 3$)本质是一个三维像素矩阵。 - **Patch 的定义**:将图像均匀分割为多个**局部小块**,每个小块称为一个 Patch。 - **例如**:若 Patch 大小为 $16 \times 16$,则 $224 \times 224$ 的图像会被分为 $\frac{224}{16} \times \frac{224}{16} = 14 \times 14 = 196$ 个 Patch。 --- ### 2. **Patch 的数学表达** - **展平操作**:每个 Patch 是尺寸为 $P \times P \times C$ 的局部区域(如 $16 \times 16 \times 3$),将其展平为一维向量: $$ \text{向量长度} = P \times P \times C = 16 \times 16 \times 3 = 768 $$ - **线性投影**:通过可学习的矩阵 $E$(即 Embedding 层),将每个展平后的 Patch 映射到 $D$ 维嵌入空间: $$ z = x_{\text{patch}} \cdot E \quad (\text{其中 } E \in \mathbb{R}^{768 \times D}) $$ --- ### 3. **为何使用 Patch?** - **降低计算复杂度**:直接处理像素级输入(如 $224^2=50,176$ 像素)会导致序列过长。通过 Patch 划分,将序列长度减少至可控范围(如 196)。 - **捕获局部语义**:每个 Patch 包含局部区域的纹理、边缘等视觉特征,类似 NLP 中“词”的概念。 - **适配 Transformer**:Transformer 擅长处理序列数据,Patch 将图像转换为“视觉词序列”,使 Transformer 能直接处理图像。 --- ### 4. **与位置编码的结合** - **问题**:Transformer 本身不具备空间位置感知能力。 - **解决方案**:为每个 Patch 嵌入添加**位置编码向量** $E_{\text{pos}}$,保留其空间位置信息: $$ z = [x_{\text{class}}; \, x_{\text{patch}}^1 E; \, x_{\text{patch}}^2 E; \, \dots] + E_{\text{pos}} $$ --- ### 5. **代码示例(简化版)** ```python # 输入图像: (B, C, H, W) image = torch.randn(1, 3, 224, 224) # 分割为 Patch: (B, num_patches, P*P*C) patches = image.unfold(2, 16, 16).unfold(3, 16, 16) patches = patches.reshape(1, 196, 768) # 线性投影 + 位置编码 embedding_layer = nn.Linear(768, 512) position_embed = nn.Parameter(torch.randn(1, 197, 512)) # 包含 class token # 最终输入序列 patch_embeddings = embedding_layer(patches) input_sequence = torch.cat([class_token, patch_embeddings], dim=1) + position_embed ``` --- ### 关键总结 - **Patch 的本质**:图像到序列的桥梁,将视觉信息转换为 Transformer 可处理的“视觉词”。 - **设计意义**:平衡计算效率与局部语义提取,是 ViT 区别于 CNN 的核心设计之一。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值