Adaptive FSS论文分析 复现要点 关键介绍 启发思想

论文链接:arxiv.org/pdf/2312.15731.pdf

github:GitHub - jingw193/AdaptiveFSS: Adaptive FSS has been Accepted by AAAI 2024. A Novel Few-Shot Segmentation Framework via Prototype Enhancement

复现要点:

环境依据作者给予的参考DCAMA环境

除此之外需要配置补充环境和数据集:PyTorch版本为1.5.1,缺少timm库:pip install timm
Python环境中缺少名为cv2的模块:pip install opencv-python
Python环境中缺少setproctitle模块:pip install setproctitle
需要补充数据集:VOC2012

关键介绍:

作者先进方面在于融合了多个模型框架,同时增加了PAM模块,作者冻结了其他模块,仅更新PAM模块,作者给定查询图像Iq和支持集{I k s, Mk s} k k=1,编码器首先按照前面的方法提取查询特征Fq和支持特征f。然后,将F, Ms和Fq输入到我们提出的PAM中,通过PEM获得类特定特征F∗s和F∗q。进一步,我们将F∗s和F∗q输入到LAM中以学习新任务的特殊信息,生成了F和Fq。然后,在原始特征f和Fq中注入ff和Fq,用于下游解码器,实现更精确的分割。在微调时,我们只更新PAM的参数,而冻结网络的其余部分,使基本分割模型能够有效地适应新的类别。

以下主要介绍作者先进的PAM流程:

通常先将数据集分为Dtr(训练数据集)和Dts(测试数据集)

  • 原型增强模块PEM:(代码仅以FPTrans模型中的PAM为例
  1. 临时原型:首先求解临时原型Pt = Mean(Fs ◦ Ms)(1)其中◦表示空间的乘法,Mean表示仅在Ms的非零位置计算空间维度的平均值
     def extract_semantic_prototype(self, s_f, s_y):
              """
              extract temporary class prototype according to support features and masks
              input:
                s_f: torch.Tensor
                    [B, S, N, D], support features
                s_y: torch.Tensor
                    [B, S * N, 1], support masks
              output:
                semantic_prototype: torch.Tensor
                    [B, D], temporary prototypes
                sign_fore_per_batch: torch.Tensor
                    [B], the signal of whether including foreground region in this image
              """
              B, S, N, D = s_f.shape
              num_fore_per_batch = torch.count_nonzero(s_y.reshape(B, -1), dim=1)
              s_y = s_y.repeat(1, 1, D)
              semantic_prototype = s_y * s_f.reshape(B, -1, D)
              semantic_prototype = semantic_prototype.mean(1) * (N * S) / (num_fore_per_batch.unsqueeze(1)+1e-4)
              one = torch.ones_like(num_fore_per_batch).cuda()
              sign_fore_per_batch =  torch.where(num_fore_per_batch > 0.5, one, num_fore_per_batch)
              return semantic_prototype, sign_fore_per_batch
          

  2. 训练阶段:在训练阶段,在适应目标类i时,Pi根据i从P处定位,进行精确改进类原型的表征性。Pi = (1 − α) × l2(Pt) + α × l2(Pi), (2)。其中l2和α分别表示l2归一化和动量比。
          def updata_prototype_bank(self, semantic_prototype, class_idx, sign_fore_per_batch):
              """
              updata prototype in class prototype bank during traning
              input:
                semantic_prototype: torch.Tensor
                    [B, D]
                class_id: list
                    len(class_id) = B
                sign_fore_per_batch: torch.Tensor
                    [B], the signal of whether including foreground region in this image
              output:
                new_semantic_prototype: torch.Tensor
                    [B, D], the updated prototypes for feature enhancement
              """  
              B, D = semantic_prototype.shape
              self.prototype = nn.functional.normalize(self.prototype, dim=0)
              semantic_prototype = nn.functional.normalize(semantic_prototype, dim=1)
              new_semantic_prototype_list = []
              for i in range(B):
                   semantic_prototype_per = semantic_prototype[i,: ]
                   class_idx_per = class_idx[i]
                   if sign_fore_per_batch[i] == 1:
                        new_semantic_prototype_per = self.prototype[:, class_idx_per] * self.momentum + (1 - self.momentum) * semantic_prototype_per
                        self.prototype[:, class_idx_per] = new_semantic_prototype_per
                   else: 
                        new_semantic_prototype_per = self.prototype[:, class_idx_per]    
                   new_semantic_prototype_list.append(new_semantic_prototype_per)
              new_semantic_prototype = torch.stack(new_semantic_prototype_list, dim=0)
              return new_semantic_prototype

  3. 测试阶段:在测试阶段,通过Pt和P之间的相似性匹配选择原型Pi。Pi = Pargmax(l2(P )·l2(Pt)), (3)这样,每个Pi都可以很好地表示当前特征表示空间中对应类别的语义。
          def select_prototype_bank(self, semantic_prototype, prototype_bank):
              """
              select prototypes in class prototype bank during testing
              input:
                semantic_prototype: torch.Tensor
                    shape = [B, D]
                prototype_bank: torch.Tensor
                    shape = [D, class_num]
              output:
                new_semantic_prototype: torch.Tensor
                    [B, D], the prototypes for feature enhancement
              """ 
              B, D = semantic_prototype.shape
              prototype_bank = nn.functional.normalize(prototype_bank, dim=0)
              semantic_prototype = nn.functional.normalize(semantic_prototype, dim=1)
              similar_matrix = semantic_prototype @ prototype_bank  # [B, class_num]
              idx = similar_matrix.argmax(1)
    
              new_semantic_prototype_list = []
              for i in range(B):
                  new_semantic_prototype_per = prototype_bank[:, idx[i]]
                  new_semantic_prototype_list.append(new_semantic_prototype_per)
              new_semantic_prototype = torch.stack(new_semantic_prototype_list, dim=0)
              return new_semantic_prototype

  4. 特征增强:首先,我们计算Fs与Pi之间的相似映射s∈R k×h×w,即

其中在d维使用L2归一化和点乘。通过使用ReLU6,可以抑制不同的位置,保留相似的点,同时避免过多值的影响。利用增强矩阵Em,我们可以生成类特定特征F * s:

F∗s = Em ◦ Fs + Fs.   (5)

需要强调的是,对于相同的原型Pi,上述过程也增强了查询Fq特征。以上仅以支持特性f为例进行说明。

      def enhanced_feature(self, feature, new_semantic_prototype, sign_fore_per_batch):
          """ 
            Input:
                feature: torch.Tensor
                    [B, S, N, D]
                new_semantic_prototype: torch.Tensor
                    [B, D]
            Outputs:
                enhanced_feature: torch.Tensor
                    [B, S, N, D]
            """
          B, D = new_semantic_prototype.shape
          feature_sim = nn.functional.normalize(feature, p=2, dim=-1)
          new_semantic_prototype = nn.functional.normalize(new_semantic_prototype, p=2, dim=1)
          similarity_matrix_list = []
          for i in range(B):
              feature_sim_per = feature_sim[i,:, :, :]
              new_semantic_prototype_er = new_semantic_prototype[i, :]
              similarity_matrix_per = feature_sim_per @ new_semantic_prototype_er
              similarity_matrix_list.append(similarity_matrix_per)
          similarity_matrix = torch.stack(similarity_matrix_list, dim=0)

          similarity_matrix = (similarity_matrix * self.dim ** 0.5) *  sign_fore_per_batch.unsqueeze(-1).unsqueeze(-1)

          enhanced_feature = self.act_enhance(similarity_matrix).unsqueeze(-1).repeat(1, 1, 1, D) * feature + feature


          return enhanced_feature

经过上述处理,可以得到增强特征F∗s和F∗q。然而,模型对特定于任务的信息进行编码的能力是不够的。此外,增强特征F∗s (F∗q)与原始特征F (Fq)之间存在分布差异问题,并且随着层数的加深,其分布差异会逐渐增大。因此,我们在PAM中采用了一个可学习的模块来解决这些问题。


  • 可学习自适应模块LAM:由于可用于训练的样本非常有限,因此该模块的参数很少,与整个模型相比可以忽略,以减轻过拟合现象。具体来说,它包括一个参数为Wdown∈R dx d γ的下投影线性层,用于压缩特征维度;一个参数为Wup∈R d γ ×d的上投影线性层,用于恢复特征维度。γ表示隐维比,在两层之间放置一个ReLU层以补充非线性特性。针对特定任务的特性,可得到
    registered_feas = torch.cat([enhanced_feat_sup.reshape(-1, N, D), enhanced_feat_q.squeeze(1)], dim=0)
    registered_feas=self.proj_drop(self.linears_up(self.act(self.linears_down(registered_feas))))

  • 启发思想:只能通过更改PAM模块,以提升更高的准确率

1.是否能通过更换激活函数达到更高的准确率和更好的性能?

考虑点:Swish、Mish、ELU (Exponential Linear Unit)、PReLU (Parametric ReLU)、GELU (Gaussian Error Linear Unit)、Bent Identity

2.是否能通过在PEM的训练阶段,采用数据增强技术来丰富类原型的代表性?

考虑点:几何变换、颜色空间变换、随机裁剪、遮挡和遮蔽、弹性变形、噪声注入、合成数据

  • 52
    点赞
  • 39
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值