WeakTr代码精细解析

class WeakTr(VisionTransformer):
    def __init__(self, depth=12, num_heads=6, reduction=4, pool="avg", 
                 embed_dim=384, AdaptiveAttentionFusion=None, 
                 feat_reduction=None, *args, **kwargs):
        super().__init__(embed_dim=embed_dim, depth=depth, num_heads=num_heads, *args, **kwargs)
        self.head = nn.Conv2d(self.embed_dim, self.num_classes, kernel_size=3, stride=1, padding=1)  
        # 添加一个卷积层作为分类头,用于将Transformer的输出转换为类别预测。
        self.avgpool = nn.AdaptiveAvgPool2d(1)  
        # 添加一个自适应平均池化层,用于减少特征图的空间维度。
        self.head.apply(self._init_weights)  
        # 应用自定义的权重初始化方法到分类头。
        num_patches = self.patch_embed.num_patches  
        # 获取模型中patch的总数。
        self.cls_token = nn.Parameter(torch.zeros(1, self.num_classes, self.embed_dim)) 
         # 创建一个类别标记(cls_token),用于表示图像的类别。
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_classes, self.embed_dim))
        # 创建位置编码,用于表示输入图像的每个patch的位置信息。

        trunc_normal_(self.cls_token, std=.02)
        # 使用截断正态分布初始化类别标记。
        trunc_normal_(self.pos_embed, std=.02)
        # 使用截断正态分布初始化位置编码。
        print(self.training)
        # 打印模型是否处于训练模式(True或False)

        aaf_params = dict(channel=depth*num_heads, reduction=reduction)
        # 创建一个参数字典,包含注意力融合模块的参数,包括通道数(channel)和缩减率(reduction)。
        if feat_reduction is not None:  # 如果提供了特征缩减参数,更新aaf_params字典
            aaf_params["feat_reduction"] = feat_reduction      
            aaf_params["feats_channel"] = embed_dim//num_heads       
            aaf_params["pool"] = pool
            
        self.adaptive_attention_fusion = AdaptiveAttentionFusion(**aaf_params)
        # 根据提供的参数创建一个AdaptiveAttentionFusion实例,这个实例将用于后续的注意力融合操作。

    def interpolate_pos_encoding(self, x, w, h):
        npatch = x.shape[1] - self.num_classes
        # 计算输入特征矩阵 x 中的补丁(patches)数量,即特征矩阵的列数减去类别(classes)的数量。
        N = self.pos_embed.shape[1] - self.num_classes
        # 计算模型中预定义的位置编码矩阵 self.pos_embed 中的补丁数量,同样是列数减去类别数量,即“后n个”。
        if npatch == N and w == h:
            return self.pos_embed
           # 如果输入特征矩阵的补丁数量与模型预定义的位置编码矩阵的补丁数量相同,
           # 并且目标图像是正方形(宽度和高度相等),则不需要进行插值,直接返回预定义的位置编码。
        class_pos_embed = self.pos_embed[:, 0:self.num_classes]
        # 从位置编码矩阵中提取类别位置编码的部分。
        patch_pos_embed = self.pos_embed[:, self.num_classes:]
        # 从位置编码矩阵中提取与补丁相关的位置编码部分。
        dim = x.shape[-1]
        # 获取输入特征矩阵 x 的最后一个维度的大小,这通常代表特征的维度个数。

        w0 = w // self.patch_embed.patch_size[0]
        # 计算目标宽度 w 相对于模型中定义的补丁大小的缩放因子。
        h0 = h // self.patch_embed.patch_size[0]
        # 计算目标高度 h 相对于模型中定义的补丁大小的缩放因子。

        w0, h0 = w0 + 0.1, h0 + 0.1
        # 为了避免浮点数计算中的精度问题,给缩放因子加上一个小的数值(0.1),以确保插值时不会超出矩阵边界。
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
            scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
            mode='bicubic',
        )
        # 对补丁位置编码进行双三次插值(bicubic interpolation),
        # 以适应目标图像的尺寸。patch_pos_embed 被重塑为一个四维张量,然后进行插值。
        assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
        # 断言插值后的位置编码矩阵的尺寸与预期的尺寸相符,确保插值操作正确。
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        # 调整插值后的位置编码矩阵的形状,以便与类别位置编码拼接。
        # permute 方法用于改变张量的维度顺序,view 方法用于重塑张量的形状。
        return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
        # 将类别位置编码和插值后的补丁位置编码在第二个维度上拼接起来,返回最终的位置编码矩阵。

    def forward_features(self, x, n=12):
        B, nc, w, h = x.shape
        # 解构输入张量 x 的形状,其中 B 是批量大小(batch size),nc 是通道数(通常对应于类别数),w 是宽度,h 是高度。
        x = self.patch_embed(x)
        # 将输入张量 x 通过 patch_embed 层,一个卷积层,用于将输入图像分割成多个补丁(patches),
        # 并为每个补丁生成一个固定大小的嵌入表示。

        cls_tokens = self.cls_token.expand(B, -1, -1)
        # 创建一个类别嵌入(class token)的张量,并将其扩展到与输入张量 x 的批量大小 B 相匹配。
        # cls_token 是一个特殊的嵌入,用于表示整个图像。
        x = torch.cat((cls_tokens, x), dim=1)
        # 将类别嵌入 cls_tokens 和补丁嵌入 x 在第二个维度(通常是特征维度)上拼接起来,形成新的输入张量。
        x = x + self.interpolate_pos_encoding(x, w, h)
        # 对拼接后的特征矩阵 x 应用位置编码插值,以适应输入图像的尺寸。这里调用了之前解释过的 interpolate_pos_encoding 方法。
        x = self.pos_drop(x)
        # 对位置编码后的特征矩阵 x 应用位置编码的dropout,这有助于防止过拟合。
        attn_weights = []
        # 存储每一层注意力机制的权重。
        attn_feats = []
        # 存储每一层注意力机制的输出特征。

        for i, blk in enumerate(self.blocks): # 遍历模型中的所有注意力块(blocks)
            x, weights_i, feat = blk(x)
            # 对当前的输入 x 应用注意力块 blk,得到输出 x、注意力权重 weights_i 和特征 feat。
            attn_feats.append(feat)
            # 将当前注意力块的注意力权重 weights_i 添加到 attn_weights 列表中。
            attn_weights.append(weights_i)
            # 将当前注意力块的输出特征 feat 添加到 attn_feats 列表中。

        return x[:, 0:self.num_classes], x[:, self.num_classes:], attn_weights, attn_feats
        # 返回四个部分:
		# x[:, 0:self.num_classes]:类别嵌入部分,通常用于分类任务。
		# x[:, self.num_classes:]:剩余的特征部分,可能用于其他任务,如目标检测或分割。
		# attn_weights:所有注意力块的注意力权重列表。
		# attn_feats:所有注意力块的输出特征列表。

    def forward(self, x, return_att=False, attention_type='fused'):
        w, h = x.shape[2:]
        # 获取输入张量 x 的宽度 w 和高度 h。
        x_cls, x_patch, attn_weights, attn_feats = self.forward_features(x)
        # 调用 forward_features 方法来处理输入 x,
        # 得到类别嵌入 x_cls、补丁嵌入 x_patch、注意力权重 attn_weights 和注意力特征 attn_feats。
        n, p, c = x_patch.shape
        #  n 是批量大小,p 是补丁的数量,c 是通道数。
        if w != h:
        # 如果输入图像的宽度 w 和高度 h 不相等,对 x_patch 进行重塑以匹配模型的输入要求。
            w0 = w // self.patch_embed.patch_size[0]
            h0 = h // self.patch_embed.patch_size[0]
            x_patch = torch.reshape(x_patch, [n, w0, h0, c])
        else:  # 如果相等,则按照正方形矩阵的形式重塑。
            x_patch = torch.reshape(x_patch, [n, int(p ** 0.5), int(p ** 0.5), c])
        x_patch = x_patch.permute([0, 3, 1, 2])
        x_patch = x_patch.contiguous()
        # 调整 x_patch 的维度顺序,并确保张量在内存中是连续的。
        x_patch = self.head(x_patch)
        # 将重塑后的 x_patch 通过模型的头部(head)层,一个卷积层。
        coarse_cam_pred = self.avgpool(x_patch).squeeze(3).squeeze(2)
        # 对头部层的输出进行平均池化(avgpool),然后移除不必要的维度,得到粗略的注意力图(coarse CAM)。

        attn_weights = torch.stack(attn_weights)  # 12 * B * H * N * N
        attn_feats = torch.stack(attn_feats)  # 12 * B * N * C
        # 将所有注意力块的权重和特征堆叠起来,形成一个大的张量。

        attn_weights_detach = attn_weights.detach().clone()
        k, b, h, n, m = attn_weights_detach.shape
        attn_weights_detach = attn_weights_detach.permute([1, 2, 0, 3, 4]).contiguous()
        attn_weights_detach = attn_weights_detach.view(b, h * k, n, m)
        # 将注意力权重和特征张量从计算图中分离出来,以便后续处理。这些操作包括重塑、转置和重新排列张量的形状。

        attn_feats_detach = attn_feats.detach().clone()
        k, b, n, c = attn_feats_detach.shape
        attn_feats_detach = attn_feats_detach.view(k, b, n, -1, h)
        attn_feats_detach = attn_feats_detach.permute([1, 4, 0, 2, 3]).contiguous()
        attn_feats_detach = attn_feats_detach.view(b, h * k, n, -1)
        cross_attn_map, patch_attn_map = self.adaptive_attention_fusion(attn_feats_detach, attn_weights_detach)
        # 调用 adaptive_attention_fusion 方法,根据分离出的注意力权重和特征,
        # 计算交叉注意力图(cross ATTN map)和补丁注意力图(patch ATTN map)。

        coarse_cam = x_patch.detach().clone()  # B * C * 14 * 14
        coarse_cam = F.relu(coarse_cam)
        # 从 x_patch 分离出粗略的注意力图,并应用ReLU激活函数。

        n, c, h, w = coarse_cam.shape

        cross_attn = cross_attn_map.mean(1)[:, 0:self.num_classes, self.num_classes:].reshape([n, c, h, w])
        # 计算并重塑交叉注意力图,以便与粗略的注意力图相乘。

        if attention_type == 'fused':
            cams = cross_attn * coarse_cam  # B * C * 14 * 14
        elif attention_type == 'patchcam':
            cams = coarse_cam
        else:
            cams = cross_attn
        # 根据 attention_type 参数的值,选择如何计算最终的注意力图(CAM)。
        # 如果为 'fused',则将交叉注意力图与粗略的注意力图相乘;
        # 如果为 'patchcam',则直接使用粗略的注意力图;否则,使用交叉注意力图。

        patch_attn = patch_attn_map.mean(1)[:, self.num_classes:, self.num_classes:]
		# 计算精细的注意力图(fine CAM),这是通过将补丁注意力图与粗略的注意力图相乘得到的。
        fine_cam = torch.matmul(patch_attn.unsqueeze(1), cams.view(cams.shape[0], cams.shape[1], -1, 1)). \
            reshape(cams.shape[0], cams.shape[1], h, w)

        fine_cam_pred = self.avgpool(fine_cam).squeeze(3).squeeze(2)
		# 对精细的注意力图进行平均池化,然后移除不必要的维度,得到最终的精细注意力图预测。
        patch_attn = patch_attn.unsqueeze(0)
		# 将补丁注意力图添加一个批量维度。
        cls_token_pred = x_cls.mean(-1)
		# 计算类别嵌入的均值,作为分类预测。
        if return_att:
            return cls_token_pred, cams, patch_attn
        else:
            return cls_token_pred, coarse_cam_pred, fine_cam_pred
        # 根据 return_att 参数的值,返回分类预测、粗略的注意力图预测、精细的注意力图预测,或者额外包括注意力权重和特征。


@register_model  # 这是一个装饰器,用于将接下来的函数注册到模型注册表中。这样,当你调用 model_name() 时,它将返回一个预定义的模型实例。
def deit_small_WeakTr_patch16_224(pretrained=False, **kwargs):
    model = WeakTr(
        patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), AdaptiveAttentionFusion=AAF, **kwargs)
    model.default_cfg = _cfg()
    return model

@register_model
def deit_small_WeakTr_AAF_RandWeight_patch16_224(pretrained=False, **kwargs):
    model = WeakTr(
        patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), AdaptiveAttentionFusion=AAF_RandWeight,**kwargs)
    model.default_cfg = _cfg()
    return model
# 允许用户通过简单的函数调用来创建具有特定配置的模型实例,而不需要手动设置所有参数。
  • 6
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Env1sage

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值