VGGT:Visual Geometry Grounded Transformer&&代码逐行解析(一)

本篇文章主要注重于对代码前半部分即预测头之前的代码进行解读,预测头部分原理将在之后更新

图1
  1. 代码理解

由图1.我们可以看出该VGGT模型实际为一个大模型,它的思路并不是很复杂。下面我解释代码。

models/vggt

 这是VGGT代码下models里面vggt代码的前半段,这部分主要为继承heads里面的预测头。

这一部分为vggt的核心,第一个if判断图像的形状,如果 images 是 4 维的,unsqueeze(0) 会在第 0 维(即最前面)增加一个新的维度,使 images 从形状 [batch_size, channels, height, width] 变成 [1, batch_size, channels, height, width]。这种操作通常是为了将数据批量化,使得即使只有一张图片,也可以像处理多张图片一样处理(将其看作批量大小为 1 的情况)。

第二个if为如果 query_points 不为None,且是一个 2 维张量(意味着它包含多个查询点的坐标),则通过 unsqueeze(0) 将其形状变为 [1, num_points, 2],即将其看作只有一个批次的查询点

下一行代码通过调用 self.aggregator(images),将 images 作为输入传递给 aggregator 方法。aggregator 返回两个值:

  • aggregated_tokens_list:包含多个张量的列表,这些张量为通过注意力机制处理后的图像或特征
  • patch_start_idx:表示patch tokens的开始位置,为一个整数

即对应图一中

 接下来,我们将阅读models/aggregator代码来具体了解过程。阅读深度学习这方面代码时,我们应注重于def forward,此为该方法的具体路线。

models/aggregator

在这一行代码中,传入了图像数据, 形状为 [B, S, 3, H, W],其中:B 表示批量大小(batch size),即每次输入的图像数量、S 表示序列长度(sequence length),如果是视频或图像序列,这个值就是序列中的帧数或图像数、3 是颜色通道数,代表 RGB 三个颜色通道(每张图像有三个颜色通道)、H 是图像的高度(height),即图像在垂直方向的像素数、W 是图像的宽度(width),即图像在水平方向的像素数。

这一部分代码为对图像形状进行简单处理,第一个if很简单就不多说了。

214行代码为对输入图像进行归一化处理。通过减去每个通道的均值并除以标准差,图像的像素值将被调整为与预训练 ResNet 模型兼容的范围。这种归一化可以帮助提高模型的训练效果,使得模型更加稳定和高效。而resnet_mean与renet_end数值aggregator第20行与第21行已经定义了

217行将原始图像的批次和序列长度S 合并为一个新的批次大小,使得输入到后续网络的形状为 [B * S, C_in, H, W],即将多个图像展平为一个“样本”,以便后续处理。

218行代码调用self_patch_embed函数对images进行处理,我们转而了解self_patch_embed函数作用。

models/aggregator

在aggregator第150-185行定义了此函数,可以看出进行一个判断后,进行不同操作,而patch_embed里面是否含有“conv”在aggregator第68行已经给出结果,patch_embed="dinov2_vitl14_reg"。所以我们执行171行的vit_large模型。我们转而继续了解vit_large模型

layers/vision_transformer

查看上下文可知 vit_large与其他几个vit_的不同为不同的参数,通过查看DinoVisionTransformer我们可以得知

layers/vision_transformer

当输入一组输入图像和掩码的前向传播,具体步骤如下:

  • 对于每一对图像和掩码,使用 prepare_tokens_with_masks 函数进行预处理。
  • 然后依次经过网络中的各个块(self.blocks),每个块是一个Transformer的层,可能会进行检查点操作(checkpoint),以减少内存使用。
  • 最终输出每一层的归一化后的特征,并且返回
  • "x_norm_clstoken":获取归一化后的 x 的第一个元素(通常是分类 token 或者特殊的标记)。
  • "x_norm_regtokens":获取归一化后的 x 中的前几个注册 token(可能是特定的标记或输入的部分)。
  • "x_norm_patchtokens":获取归一化后的 x 中的剩余部分,即 Patch tokens(通常是输入图像的局部信息块)。
  • "x_prenorm":保存原始的、未归一化的数据。
  • "masks":保存对应的掩码。

 当输入的是一个图像是,可以参考def forward_features。

现在我们回过头来,看models/aggregator代码

第220-221行代码判断 patch_tokens 是否是一个字典 (dict) 类型。如果是字典,则从字典中获取名为 "x_norm_patchtokens" 的键对应的值,并将这个值赋给 patch_tokens。所以,patch_tokens 将是字典中 "x_norm_patchtokens" 这一项的数据。

223行代码patch_tokens.shape 返回的是一个三维的形状信息,代码从中获取了第二和第三个维度的大小并分别赋值给 PC,忽略了第一个维度的大小。

在这两行代码里面,我们首先观察sself.camera_token与self.register_token。这两个参数也在aggregator里面有定义。而slice_expand_and_flatten函数的作用为处理形状为 (1, 2, X, C) 的专用 token,用于多帧处理:

1) 第一个位置(索引=0)仅用于第一帧。

2) 第二个位置(索引=1)用于其余的所有帧(S-1 帧)。

3) 扩展这两个位置以匹配批次大小 B。

4) 将它们拼接成 (B, S, X, C) 形状,其中每个序列的第一位置 token 是 1 个,后面跟着 (S-1) 个第二位置 token。

5) 将形状展平为 (B*S, X, C) 以进行处理。

所以camera_token最终为(B*S,1,embed_dim),register_token最终为(B*S,num_register_tokens,embed_dim)

129行中,torch.randn(1, 2, 1, embed_dim):这行代码创建了一个形状为 (1, 2, 1, embed_dim) 的张量,张量的元素是从标准正态分布(均值为 0,方差为 1)中随机生成的。nn.Parameter 是 PyTorch 中的一个特殊类型,用来定义一个模型的可训练参数。通常在神经网络中,通过 nn.Parameter 定义的变量都会被视为模型的参数,并在训练过程中参与梯度更新。 

130行同理。embed_dim与num_register_tokens在aggregator前面第49行与第53行已定义过。

这行代码的目的是将 camera_tokenregister_tokenpatch_tokens 这三个张量在列方向上拼接起来,形成一个新的张量 tokens,其形状是沿维度 1 拼接后的结果,应该为(B*S, S+num_register_tokens-1, embed_dim)

在233行中判断self.rope是否为None,而在aggregator 第70行中已经得知self.rope不为None,使用了Rotary Position Embedding(ROPE)的技术来为输入的 tokens 生成位置嵌入。

第234行则为使用self.position_getter函数来获取位置嵌入。B * S 表示批次大小和序列长度的乘积,表示每个 token 需要的位置嵌入数量。H // self.patch_size 和 W // self.patch_size 是输入图像的高度和宽度,除以 patch 的大小,用来计算位置嵌入的空间维度。device=images.device 将位置嵌入的设备设为图像所在的设备。

第236行中判断self.patch_start_idx>0,由self.patch_start_idx定义知,符合条件。将 pos 中的所有位置嵌入值加 1。这一步的目的是确保位置嵌入的偏移量,以防止特殊 token 使用位置嵌入时的冲突。创建一个形状为 (B * S, self.patch_start_idx, 2) 的全零张量,用来表示特殊 tokens(比如 camera 和 register tokens)的位置嵌入torch.zeros(...).to(images.device).to(pos.dtype) 将这个张量放到与 images 相同的设备上,并确保它与 pos 张量的数据类型一致。使用 torch.cat 函数在 dim=1 上将 pos_special(表示特殊 token 的位置嵌入)和 pos(表示普通 token 的位置嵌入)拼接起来。这样,拼接后的 pos 将包含所有 token 的位置嵌入,其中特殊 tokens 的位置嵌入为零。所以最终pos的shape为(B*S, patch_start_id+H/patch_size+W/patch_size, 2)

这一部分frame_idx 可能局部注意力机制的索引,global_idx 可能用于全局注意力机制的索引,output_list 用于存储结果或处理后的数据。

这部分代码由内外两个循环构成,外部循环第250行self.aa_block_num 表示循环的次数,假设它是一个整数,这个循环会执行 aa_block_num 次。而self.aa_block_num 在第125行有定义,读者可自行计算。

内部循环第251行的循环表示attn_type 依次取 self.aa_order 中的元素,通常这表示不同的注意力类型 "frame" 和 "global",并使用_process_frame_attention方法来进行注意力计算。现在我们关注一下_process_frame_attention

第278行代码判断 tokens形状是否为(B * S, P, C),若不是则调整为(B * S, P, C),在第244行中定义了P与C为tokens的后两个维度,所以P=S+num_register_tokens-1 C=embed_dim

第281行代码判断是否pos不是空的,并且形状为(B * S, P, 2),则调整形状为(B * S, P, 2)

第284行代码intermediates 是一个空列表,用来存储每个处理步骤的中间结果。在框架注意力中,这些中间结果通常是 tokens 在经过每个处理步骤后的状态。

第287行代码self.aa_block_size 是一个配置参数,表示在处理框架时需要处理的块的数量,在第60行中定义为1,即每次只处理一块。self.frame_blocks[frame_idx](tokens, pos=pos) 使用 frame_idx 索引从 frame_blocks 中取出一个处理块,并将 tokens 和 pos(如果存在)传递给它进行处理。frame_idx += 1 在每次处理后,递增 frame_idx,表示下一个框架块的索引。intermediates.append(tokens.view(B, S, P, C)) 将当前的 tokens 调整为 (B, S, P, C) 的形状,并将其添加到 intermediates 列表中,以便后续使用或返回。

现在我们回过头来看第264行代码。对 frame_intermediatesglobal_intermediates 中的每一对张量,沿着最后一个维度(特征维度)进行拼接,并将结果保存到 output_list 中。最终,output_list 将包含多个拼接后的张量,每个张量的形状为 [B, S, P, 2C],其中 2C 是原来 C 的两倍。

 第269-271行代码目的就是在不再需要这些变量的时候,将它们删除,以便腾出内存空间,尤其是在处理大量数据时,合理释放内存是很重要的。

最后返回output_list与 self.patch_start_idx

如有错误,请及时告知于我,谢谢。

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值