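Paper link: arxiv.org/pdf/2312.15731.pdf
Reproduction notes:
Set up the environment following the reference DCAMA environment provided by the authors.
A few extra packages and one dataset also need to be set up. The PyTorch version is 1.5.1. The timm library is missing: pip install timm
The Python environment is missing the cv2 module: pip install opencv-python
The Python environment is missing the setproctitle module: pip install setproctitle
The VOC2012 dataset must also be added.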
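Before launching training, a quick check can list which of the extra packages above are still missing. This is a minimal sketch. The import-name to pip-name mapping is taken from the notes above, and the helper name `missing_packages` is hypothetical. Note that `cv2` is the import name of the `opencv-python` package.

```python
import importlib.util

# Import name -> pip package name for the extra dependencies listed above.
EXTRA_DEPS = {
    "timm": "timm",
    "cv2": "opencv-python",  # the module is imported as cv2
    "setproctitle": "setproctitle",
}


def missing_packages(deps=EXTRA_DEPS):
    """Return the pip names of dependencies whose import name cannot be found."""
    return [pip_name for module, pip_name in deps.items()
            if importlib.util.find_spec(module) is None]


if __name__ == "__main__":
    missing = missing_packages()
    if missing:
        print("pip install " + " ".join(missing))
    else:
        print("All extra dependencies are installed.")
```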
Key ideas:
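The paper's main advance is that it builds on several existing model frameworks and adds a PAM module to them. Given a query image $I_q$ and a support set $\{I_s^k, M_s^k\}_{k=1}^{K}$, the encoder first extracts the query feature $F_q$ and the support features $F_s$, as in earlier methods. $F_s$, $M_s$ and $F_q$ are then fed into the proposed PAM, where PEM produces the class-specific features $F^*_s$ and $F^*_q$. Next, $F^*_s$ and $F^*_q$ are passed to LAM to learn information specific to the new task, producing $\hat{F}_s$ and $\hat{F}_q$. These are injected into the original features $F_s$ and $F_q$, which then go to the downstream decoder for more accurate segmentation. During fine-tuning, only the PAM parameters are updated and the rest of the network is frozen, so the base segmentation model can adapt efficiently to new classes.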
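Freezing everything except PAM comes down to toggling `requires_grad` and giving the optimizer only the trainable parameters. The following is a minimal PyTorch sketch. It assumes the PAM submodule is stored under an attribute called `pam`, which is a hypothetical name, and uses a toy three-layer model in place of the real encoder/decoder.

```python
import torch
import torch.nn as nn


def freeze_all_but(model: nn.Module, trainable_prefix: str = "pam"):
    """Set requires_grad=True only for parameters under `trainable_prefix`.

    `pam` is a hypothetical attribute name; use whatever name the PAM
    submodule actually has in the model being fine-tuned.
    """
    for name, param in model.named_parameters():
        param.requires_grad = name.startswith(trainable_prefix + ".")
    return [p for p in model.parameters() if p.requires_grad]


class ToyModel(nn.Module):
    """Stand-in for encoder + PAM + decoder; layer sizes are arbitrary."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(8, 8)
        self.pam = nn.Linear(8, 8)
        self.decoder = nn.Linear(8, 2)


model = ToyModel()
trainable = freeze_all_but(model)
# Only the PAM parameters are handed to the optimizer.
optimizer = torch.optim.SGD(trainable, lr=1e-3)
```

Passing only the trainable tensors to the optimizer also keeps its state (e.g. momentum buffers) limited to the small PAM.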
The PAM pipeline is described below:
The dataset is usually first split into $D_{tr}$ (training set) and $D_{ts}$ (test set).
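With VOC2012, few-shot segmentation work commonly follows the PASCAL-5^i protocol. The 20 object classes are split into 4 folds of 5. One fold's classes form $D_{ts}$ and the remaining 15 form $D_{tr}$, so training and test classes are disjoint. The sketch below assumes that protocol and classes numbered 1 to 20; the function name is hypothetical.

```python
def pascal5i_split(fold, num_classes=20, num_folds=4):
    """Return (train_classes, test_classes) for one PASCAL-5^i fold.

    Classes are assumed to be numbered 1..num_classes; fold i holds out
    classes 5i+1 .. 5i+5 for testing.
    """
    per_fold = num_classes // num_folds
    test = list(range(fold * per_fold + 1, (fold + 1) * per_fold + 1))
    train = [c for c in range(1, num_classes + 1) if c not in test]
    return train, test


train_classes, test_classes = pascal5i_split(fold=0)
```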
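- Prototype enhancement module (PEM) (the code uses the PAM in the FPTrans model as the example):
- Temporary prototype: first compute the temporary prototype
$$P_t = \mathrm{Mean}(F_s \circ M_s) \tag{1}$$
where $\circ$ denotes spatial-wise multiplication and $\mathrm{Mean}$ averages over the spatial dimensions only at the non-zero positions of $M_s$.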
def extract_semantic_prototype(self, s_f, s_y):
    """
    Extract temporary class prototypes from support features and masks.
    Input:
        s_f: torch.Tensor [B, S, N, D], support features
        s_y: torch.Tensor [B, S * N, 1], support masks
    Output:
        semantic_prototype: torch.Tensor [B, D], temporary prototypes
        sign_fore_per_batch: torch.Tensor [B], 1 if the image contains a foreground region, else 0
    """
    B, S, N, D = s_f.shape
    # number of foreground (non-zero mask) positions per sample
    num_fore_per_batch = torch.count_nonzero(s_y.reshape(B, -1), dim=1)
    s_y = s_y.repeat(1, 1, D)
    semantic_prototype = s_y * s_f.reshape(B, -1, D)
    # mean over all S*N positions, rescaled into a mean over foreground positions only (Eq. 1)
    semantic_prototype = semantic_prototype.mean(1) * (N * S) / (num_fore_per_batch.unsqueeze(1) + 1e-4)
    # ones_like already keeps the device of its input
    one = torch.ones_like(num_fore_per_batch)
    sign_fore_per_batch = torch.where(num_fore_per_batch > 0.5, one, num_fore_per_batch)
    return semantic_prototype, sign_fore_per_batch
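A standalone version of Eq. (1) can be checked on toy tensors. This is a sketch, not the repository code, and the function name is hypothetical. It counts foreground positions with `(mask != 0).sum()` rather than `torch.count_nonzero`. As far as I know, `count_nonzero` only appeared in PyTorch 1.7, so it may be missing from the 1.5.1 environment mentioned above.

```python
import torch


def masked_mean_prototype(feat, mask, eps=1e-4):
    """Eq. (1): average `feat` over the positions where `mask` is non-zero.

    feat: [B, L, D] flattened support features (L = S * N positions)
    mask: [B, L, 1] binary support mask
    returns: [B, D] temporary prototypes P_t
    """
    B = feat.shape[0]
    num_fore = (mask.reshape(B, -1) != 0).sum(dim=1, keepdim=True)  # [B, 1]
    # Summing the masked features and dividing by the foreground count is the
    # same as mean(1) * L / count in extract_semantic_prototype.
    return (feat * mask).sum(dim=1) / (num_fore + eps)


feat = torch.tensor([[[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]])  # [1, 3, 2]
mask = torch.tensor([[[1.0], [1.0], [0.0]]])                       # last position is background
proto = masked_mean_prototype(feat, mask)
```

The large values at the masked-out position do not leak into $P_t$, which comes out as roughly `[2, 3]`. An all-zero mask yields a zero prototype instead of a division error.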
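- Training stage: when adapting to target class $i$, $P_i$ is looked up in the class prototype bank $P$ by its index $i$ and refined to make the class prototype more representative:
$$P_i = (1-\alpha)\times l_2(P_t) + \alpha \times l_2(P_i) \tag{2}$$
where $l_2$ and $\alpha$ denote L2 normalization and the momentum ratio, respectively.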
def updata_prototype_bank(self, semantic_prototype, class_idx, sign_fore_per_batch):
    """
    Update prototypes in the class prototype bank during training.
    Input:
        semantic_prototype: torch.Tensor [B, D]
        class_idx: list, len(class_idx) = B
        sign_fore_per_batch: torch.Tensor [B], whether this image contains a foreground region
    Output:
        new_semantic_prototype: torch.Tensor [B, D], the updated prototypes for feature enhancement
    """
    B, D = semantic_prototype.shape
    # L2-normalize each class column of the bank [D, class_num] and the temporary prototypes
    self.prototype = nn.functional.normalize(self.prototype, dim=0)
    semantic_prototype = nn.functional.normalize(semantic_prototype, dim=1)
    new_semantic_prototype_list = []
    for i in range(B):
        semantic_prototype_per = semantic_prototype[i, :]
        class_idx_per = class_idx[i]
        if sign_fore_per_batch[i] == 1:
            # Eq. (2): momentum update, with self.momentum playing the role of alpha
            new_semantic_prototype_per = self.prototype[:, class_idx_per] * self.momentum + (1 - self.momentum) * semantic_prototype_per
            self.prototype[:, class_idx_per] = new_semantic_prototype_per
        else:
            # no foreground in this image: keep the stored prototype unchanged
            new_semantic_prototype_per = self.prototype[:, class_idx_per]
        new_semantic_prototype_list.append(new_semantic_prototype_per)
    new_semantic_prototype = torch.stack(new_semantic_prototype_list, dim=0)
    return new_semantic_prototype
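A single application of Eq. (2) on a two-dimensional toy prototype shows what the momentum ratio does. This is a sketch: the helper name is hypothetical, and $\alpha = 0.9$ is an illustrative value, not necessarily the paper's setting.

```python
import torch
import torch.nn.functional as F


def momentum_update(bank_proto, temp_proto, alpha):
    """Eq. (2) for one class: P_i = (1 - alpha) * l2(P_t) + alpha * l2(P_i)."""
    return (1 - alpha) * F.normalize(temp_proto, dim=0) + alpha * F.normalize(bank_proto, dim=0)


bank_proto = torch.tensor([1.0, 0.0])  # stored P_i
temp_proto = torch.tensor([0.0, 2.0])  # temporary P_t from the current episode
updated = momentum_update(bank_proto, temp_proto, alpha=0.9)
```

The result is about `[0.9, 0.1]`. It moves only slightly toward the new episode, and its norm drops below 1 because it is a convex combination of two unit vectors. That is presumably why `updata_prototype_bank` re-normalizes `self.prototype` at the start of every call.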
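- Testing stage: at test time, the prototype $P_i$ is selected by similarity matching between $P_t$ and $P$:
$$P_i = P_{\arg\max\left(l_2(P)\cdot l_2(P_t)\right)} \tag{3}$$
In this way, each $P_i$ represents the semantics of its class well in the current feature space.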
def select_prototype_bank(self, semantic_prototype, prototype_bank):
    """
    Select prototypes from the class prototype bank during testing.
    Input:
        semantic_prototype: torch.Tensor, shape = [B, D]
        prototype_bank: torch.Tensor, shape = [D, class_num]
    Output:
        new_semantic_prototype: torch.Tensor [B, D], the prototypes for feature enhancement
    """
    B, D = semantic_prototype.shape
    prototype_bank = nn.functional.normalize(prototype_bank, dim=0)
    semantic_prototype = nn.functional.normalize(semantic_prototype, dim=1)
    similar_matrix = semantic_prototype @ prototype_bank  # [B, class_num]
    idx = similar_matrix.argmax(1)
    new_semantic_prototype_list = []
    for i in range(B):
        new_semantic_prototype_per = prototype_bank[:, idx[i]]
        new_semantic_prototype_list.append(new_semantic_prototype_per)
    new_semantic_prototype = torch.stack(new_semantic_prototype_list, dim=0)
    return new_semantic_prototype
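The per-sample loop in `select_prototype_bank` can be replaced by advanced indexing on the bank's columns. The sketch below uses a hypothetical name and, as far as I can tell, returns the same normalized prototypes. It also returns the chosen class indices, which is handy for checking which class each test episode was matched to.

```python
import torch
import torch.nn.functional as F


def select_prototype_vectorized(semantic_prototype, prototype_bank):
    """Eq. (3) without the Python loop.

    semantic_prototype: [B, D] temporary prototypes P_t
    prototype_bank:     [D, class_num] class prototype bank P
    returns: ([B, D] selected L2-normalized prototypes, [B] selected class indices)
    """
    bank = F.normalize(prototype_bank, dim=0)
    query = F.normalize(semantic_prototype, dim=1)
    idx = (query @ bank).argmax(dim=1)  # [B] best-matching class per sample
    return bank[:, idx].t(), idx        # column idx[b] becomes row b


bank = torch.eye(3) * 2  # three classes with orthogonal prototypes
temp = torch.tensor([[0.0, 5.0, 0.1], [3.0, 0.0, 0.0]])
selected, chosen = select_prototype_vectorized(temp, bank)
```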
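- Feature enhancement: first, compute the similarity map $s \in \mathbb{R}^{k\times h\times w}$ between $F_s$ and $P_i$, i.e. $s = l_2(F_s)\cdot l_2(P_i)$, where L2 normalization and the dot product are taken along the $d$ dimension. Applying ReLU6, $E_m = \mathrm{ReLU6}(s)$, suppresses dissimilar positions and keeps similar ones while limiting the influence of excessively large values. With the enhancement matrix $E_m$, the class-specific feature $F^*_s$ is generated as:
$$F^*_s = E_m \circ F_s + F_s \tag{5}$$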
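Note that the query feature $F_q$ is enhanced by the same procedure with the same prototype $P_i$; the support feature $F_s$ is used above only as the example.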
def enhanced_feature(self, feature, new_semantic_prototype, sign_fore_per_batch):
    """
    Input:
        feature: torch.Tensor
            [B, S, N, D]
        new_semantic_prototype: torch.Tensor
            [B, D]
        sign_fore_per_batch: torch.Tensor
            [B], 1 if the support image contains foreground, else 0
    Outputs:
        enhanced_feature: torch.Tensor
            [B, S, N, D]
    """
    B, D = new_semantic_prototype.shape
    # cosine similarity: L2-normalize along the feature dimension D
    feature_sim = nn.functional.normalize(feature, p=2, dim=-1)
    new_semantic_prototype = nn.functional.normalize(new_semantic_prototype, p=2, dim=1)
    similarity_matrix_list = []
    for i in range(B):
        feature_sim_per = feature_sim[i, :, :, :]
        new_semantic_prototype_per = new_semantic_prototype[i, :]
        similarity_matrix_per = feature_sim_per @ new_semantic_prototype_per  # [S, N]
        similarity_matrix_list.append(similarity_matrix_per)
    similarity_matrix = torch.stack(similarity_matrix_list, dim=0)  # [B, S, N]
    # scale by sqrt(dim); samples without foreground get an all-zero similarity map
    similarity_matrix = (similarity_matrix * self.dim ** 0.5) * sign_fore_per_batch.unsqueeze(-1).unsqueeze(-1)
    # Eq. (5): F* = Em o F + F, with Em = act_enhance(similarity)
    enhanced_feature = self.act_enhance(similarity_matrix).unsqueeze(-1).repeat(1, 1, 1, D) * feature + feature
    return enhanced_feature
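The computation in `enhanced_feature` can be written out as equations. This assumes `self.act_enhance` is the ReLU6 described in the text, with $d$ = `self.dim` and the foreground gate $\delta \in \{0, 1\}$ = `sign_fore_per_batch`. It is reconstructed from the code, not copied from the paper. Note the $\sqrt{d}$ scale and the gate $\delta$, which the prose does not mention. When the support image has no foreground, $\delta = 0$, so $E_m = 0$ and $F^*_s = F_s$. Otherwise each position is scaled by $1 + E_m \in [1, 7]$.

```latex
s = \delta\,\sqrt{d}\,\big(l_2(F_s)\cdot l_2(P_i)\big), \qquad
E_m = \mathrm{ReLU6}(s), \qquad
F^*_s = E_m \circ F_s + F_s
```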
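These steps yield the enhanced features $F^*_s$ and $F^*_q$. However, the model's ability to encode task-specific information is still insufficient. Moreover, there is a distribution gap between the enhanced features $F^*_s$ ($F^*_q$) and the original features $F_s$ ($F_q$), and this gap grows as the layers get deeper. PAM therefore includes a learnable module to address these problems.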
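- Learnable adaptive module (LAM): very few samples are available for training, so to mitigate overfitting this module has very few parameters, a negligible number compared with the whole model. It consists of a down-projection linear layer with parameters $W_{down}\in\mathbb{R}^{d\times \frac{d}{\gamma}}$ that compresses the feature dimension and an up-projection linear layer with parameters $W_{up}\in\mathbb{R}^{\frac{d}{\gamma}\times d}$ that restores it. Here $\gamma$ is the hidden-dimension ratio, and a ReLU layer between the two projections adds non-linearity. The task-specific features are then obtained as $\hat{F}_s = \mathrm{ReLU}(F^*_s W_{down})\,W_{up}$, and likewise for $\hat{F}_q$: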
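# stack the support features (flattened to [B*S, N, D]) and the query features ([B, N, D]) along dim 0
registered_feas = torch.cat([enhanced_feat_sup.reshape(-1, N, D), enhanced_feat_q.squeeze(1)], dim=0)
# LAM: down-projection -> activation -> up-projection -> dropout
registered_feas = self.proj_drop(self.linears_up(self.act(self.linears_down(registered_feas))))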
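The snippet above relies on layers defined elsewhere in the model. The sketch below packages them as a standalone bottleneck module so its size can be checked. The class name, $d = 768$ and $\gamma = 4$ are assumptions for illustration only. The attribute names mirror those in the snippet.

```python
import torch
import torch.nn as nn


class LearnableAdapter(nn.Module):
    """Hypothetical standalone LAM: a d -> d/gamma -> d bottleneck."""

    def __init__(self, dim, gamma=4, drop=0.0):
        super().__init__()
        hidden = dim // gamma
        self.linears_down = nn.Linear(dim, hidden)  # W_down: compress the feature dimension
        self.act = nn.ReLU()                        # non-linearity between the projections
        self.linears_up = nn.Linear(hidden, dim)    # W_up: restore the feature dimension
        self.proj_drop = nn.Dropout(drop)

    def forward(self, x):  # x: [..., dim]
        return self.proj_drop(self.linears_up(self.act(self.linears_down(x))))


lam = LearnableAdapter(dim=768, gamma=4)
num_params = sum(p.numel() for p in lam.parameters())
out = lam(torch.randn(2, 10, 768))
```

With weights and biases, the count is $2d\cdot\frac{d}{\gamma} + \frac{d}{\gamma} + d$, which gives 295,872 for these values. The output keeps the input shape, so it can be added back onto the original features.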
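- Ideas for improvement (constraint: accuracy may only be raised by changing the PAM module):
1. Can a different activation function give higher accuracy and better performance?
Candidates: Swish, Mish, ELU (Exponential Linear Unit), PReLU (Parametric ReLU), GELU (Gaussian Error Linear Unit), Bent Identity
2. Can data augmentation in the PEM training stage make the class prototypes more representative?
Candidates: geometric transforms, color-space transforms, random cropping, occlusion and masking, elastic deformation, noise injection, synthetic data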
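Both ideas can be prototyped with small helpers. The sketch below covers two things. First, a factory for candidate activations that could replace the ReLU in LAM or the ReLU6 in PEM. Second, a toy augmentation of a support image/mask pair. All names are hypothetical. `nn.ELU`, `nn.PReLU` and `nn.GELU` are standard `torch.nn` modules. Swish, Mish and Bent Identity are written out by hand, so the sketch does not rely on the newer `nn.SiLU`/`nn.Mish`. The key point for augmentation is that geometric transforms must be applied to the image and the mask together, while photometric ones (colour, noise) touch the image only.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Swish(nn.Module):
    def forward(self, x):
        return x * torch.sigmoid(x)


class Mish(nn.Module):
    def forward(self, x):
        return x * torch.tanh(F.softplus(x))


class BentIdentity(nn.Module):
    def forward(self, x):
        return (torch.sqrt(x * x + 1) - 1) / 2 + x


def make_activation(name):
    """Build a candidate activation by name (hypothetical helper)."""
    table = {
        "relu": nn.ReLU, "relu6": nn.ReLU6, "swish": Swish, "mish": Mish,
        "elu": nn.ELU, "prelu": nn.PReLU, "gelu": nn.GELU,
        "bent_identity": BentIdentity,
    }
    return table[name]()


def augment_support(image, mask, p_flip=0.5, noise_std=0.05):
    """Toy augmentation for a support pair: image [C, H, W], mask [H, W].

    The horizontal flip is applied to image and mask together so the mask
    still marks the same pixels; Gaussian noise is added to the image only.
    """
    if torch.rand(1).item() < p_flip:
        image = torch.flip(image, dims=[-1])
        mask = torch.flip(mask, dims=[-1])
    if noise_std > 0:
        image = image + noise_std * torch.randn_like(image)
    return image, mask
```

When swapping the PEM activation, keep in mind that ReLU6 both zeroes negative similarities and caps values at 6. Unbounded candidates (Swish, GELU, Bent Identity) remove that cap. Candidates with negative outputs (ELU, PReLU, Swish, Mish) make $E_m$ negative at dissimilar positions, which shrinks those features instead of leaving them unchanged.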