关于Vit Transformer中depth参数的理解

  • depth: int.
    Number of Transformer blocks.

而transformer block在文中具体是指:

也即是,如果depth = 8,就是设置了8层transformer encoder

### 关于Vision Transformer (ViT) 的代码实现 在探索 Vision Transformer (ViT) 实现的过程中,可以基于简化版的PyTorch实现来了解其工作原理[^2]。下面展示了一个简化的 ViT 模型结构: ```python import torch from torch import nn, optim class PatchEmbedding(nn.Module): """将图像分割成多个patch并嵌入""" def __init__(self, img_size=224, patch_size=16, embed_dim=768): super().__init__() self.proj = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size) def forward(self, x): x = self.proj(x).flatten(2).transpose(1, 2) return x class Block(nn.Module): """创建单个Transformer编码器层""" def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop_rate=0.): super().__init__() self.norm1 = nn.LayerNorm(dim) self.attn = nn.MultiheadAttention(dim, num_heads, dropout=drop_rate, batch_first=True) # ...其余初始化省略... def forward(self, x): attn_output, _ = self.attn(query=self.norm1(x), key=x, value=x) x = x + attn_output # 剩余前向传播逻辑... return x class VisionTransformer(nn.Module): """构建完整的Vision Transformer模型""" def __init__(self, img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0., class_token=True): super().__init__() self.patch_embed = PatchEmbedding(img_size=img_size, patch_size=patch_size, embed_dim=embed_dim) if class_token: self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) self.blocks = nn.Sequential(*[ Block( dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop_rate=drop_rate ) for i in range(depth)]) self.norm = nn.LayerNorm(embed_dim) def forward_features(self, x): B = x.shape[0] x = self.patch_embed(x) cls_tokens = self.cls_token.expand(B, -1, -1) x = torch.cat((cls_tokens, x), dim=1) x = self.blocks(x) x = self.norm(x) return x[:, 0] def forward(self, x): x = self.forward_features(x) return x ``` 此段代码展示了如何定义 `PatchEmbedding` 类用于处理输入图片到补丁序列转换;`Block` 类代表了每个变压器编码单元的核心组件;而整个架构则由 `VisionTransformer` 完整表示。 为了使类token参与训练,在实例化过程中加入了可学习参数 `cls_token` 并将其与patch embeddings连接在一起作为输入传递给标准的Transformer Encoder 中[^3]。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值