VisionTransformer(ViT)详细架构图

这是原版的架构图,少了很多东西。

这是我根据源码总结出来的详细版

在这里插入图片描述

有几点需要说明的,看架构图能看懂就不用看注释了。

(1)输入图片必须是 224x224x3 的,如果不是就把它缩放到这个尺寸。

(2)Tranformer要的是嵌入向量的序列,大概是SeqLen, HidSize形状的二维数组,然后图像是H, W, C的三维数组,想把它塞进去必须经过一步转换,这是嵌入模块做的事情。

简单来讲就是切成大小为16*16*3的片段(Patch)然后每个片段都经过一步线性映射转换为长度768的一维向量。这一步在代码中通过一个Conv2d来一次性完成。

我们的这个卷积层,包含768 个大小为16*16*3的卷积核,步长等于卷积核大小。也就是说,它相当于把图像切成16*16*3的片段,然后每个片段和每个卷积核相乘并求和得到一个值。每个片段一共产生768个值,顺序排列得到一个一维向量,就是它的嵌入向量,然后所有片段的嵌入向量再顺序排列,得到整个图片的嵌入序列,就是这样。

(3)之后会在序列开头添加一个特殊的嵌入向量,是<CLS>,这个嵌入向量没有其它意义,只代表输出的这个位置的嵌入,应该计算为整个图像的类别嵌入。

(4)之后会添加位置嵌入,不是编码,因为它是可以学习的,也就是不锁定梯度。很多 Tranformer 都是位置嵌入,因为它是锁梯度的。

(5)位置嵌入之后会有个Dropout层,在论文原图中没有,似乎很多Bert或者GPT变体都会有这个东西。

(6)之后经过 12 个 TF 块,这个块和 Bert 是一样的,没有啥魔改。

(7)TF块之后会有个LayerNorm,原图里没有,这个也是很多变体里面出现过的。

(8)因为我们要分类,或者说论文中采用分类任务,需要取类别嵌入,也就是SeqLen维度的第一个元素。

(9)之后经过一个线性+Tanh,论文里面说只有预训练时期需要这个,迁移的时候可以直接扔掉。

(10)之后是线性+Softmax,用于把类别嵌入转化成图像属于各类的概率。

### Vision Transformer (ViT) 的结构组成和设计原理 Vision Transformer (ViT)[^6] 是一种基于 Transformer 架构的视觉模型,旨在通过将图像分割为固定大小的 patches 并将其作为序列输入到 Transformer 中来进行图像分类和其他计算机视觉任务。以下是 ViT 的主要组成部分及其设计原理: #### 图像分块与嵌入 为了适应 Transformer 对序列数据的要求,ViT 将输入图像划分为多个不重叠的小块(patches),并将这些小块视为 token 序列的一部分。每个 patch 被展平并映射到一个固定的维度向量表示,这一过程通常通过线性投影实现。 接着,在所有 patch 嵌入之前加入一个可学习的位置嵌入(positional embedding)。位置嵌入的作用是让 Transformer 模型能够感知到各个 patch 在原始图像中的相对顺序[^7]。 #### 类标记(Class Token) 类似于 BERT 使用特殊标记 `[CLS]` 来代表整个句子的语义信息,ViT 在序列的第一个位置引入了一个特殊的类标记(class token)。该标记在整个 Transformer 编码过程中与其他 patch 向量一起被更新,并最终用于生成图像的全局特征表示[^8]。 #### 多头自注意力机制 Transformer 的核心组件——多头自注意力机制(Multi-head Self-Attention Mechanism)允许模型捕捉不同区域之间的关系。具体来说,对于每一个 patch 或者 class token,它会计算自己与其他所有 tokens 之间的重要性权重,从而动态调整它们对当前 token 表征的影响程度。这种灵活的关系建模能力使得 ViT 可以有效地提取复杂的模式和上下文关联[^9]。 #### Feed Forward Network 和残差连接 除了自注意层之外,每一层还包含前馈神经网络(Feed Forward Network, FFN),以及跳过连接(skip connection)和层归一化(Layer Normalization)。FFNs 提供了非线性的转换功能;而跳跃连接有助于缓解梯度消失问题并促进深层架构的学习效率提升[^10]。 #### 输出阶段 经过若干次堆叠后的 Transformer 层之后,取最后一个隐藏状态对应于 class token 部分送入全连接层完成最终的任务目标定义,如分类概率分布预测等操作[^11]。 ```python import torch from torchvision import transforms from PIL import Image from transformers import ViTModel, ViTFeatureExtractor def load_image(image_path): image = Image.open(image_path).convert('RGB') feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k') inputs = feature_extractor(images=image, return_tensors="pt") return inputs['pixel_values'] model = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k') image_tensor = load_image("example.jpg") outputs = model(pixel_values=image_tensor) print(outputs.last_hidden_state[:, 0]) # Class token representation ```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值