[小样本图像分割]SG-One: Similarity Guidance Network for One-Shot Semantic Segmentation

SG-One:用于一次性语义分割的相似引导网络(IEEE Transactions on Cybernetics 2020)

这篇文章提到的Masked Global Average Pooling比较重要,后面的很多小样本分割文章都有用到,甚至是作为融合支持图像与其标注的标配!

论文地址

摘要

在这篇文章中,主要提出了一个简单而有效的相似性引导网络来解决One-Shot分割问题。我们的目的是通过和查询图像同一类别的一个带有密集标记的支持图像来预测一个查询图像的分割掩码。为了获得支持图像的鲁棒代表特征,我们首先采用掩码平均池化(masked average pooling)策略,只考虑属于支持图像的像素来产生引导特征。然后使用余弦相似度建立引导特征和查询图像中的像素特征之间的关系。这样就可以使用产生的相似映射中所蕴涵的概率来指导对象的分割过程。

存在的问题及解决方案

过去的方案一般是训练一对并行的网络,分别用于提取有标注的支持图像和查询图像的特征。然后融合这些特征生成目标物体的概率映射。该网络的目的实际上是学习高级特征空间内的标注支持图像和查询图像之间的关系。这些方法提供了一个优势,即已见过类的训练参数可以直接用于测试未见过的类,而无需进行微调。然而,这些方法存在一些缺点:1)参数使用两个平行的网络冗余,这是容易过度拟合,导致浪费计算资源和;2)仅仅通过乘法来结合支持图像和查询图像不足以指导网络进行高质量的分割。

为了解决上述问题,本文提出了一种用于One-Shot分割的相似性引导网络。SG-One的基本思想是通过有效地融合支持对象和查询图像的特征之间的像素级相似度来指导分割过程。特别是,我们首先提取了输入的支持图像和查询图像的高级特征映射。高层特征通常是抽象的表示,且属于同一类别对象的像素嵌入表达在高级特征中更为接近。背景像素的嵌入表达通常被抑制并且这些嵌入表达与前景对象的嵌入表达距离较远。因此,我们提出了一种掩码平均池化方法来从支持图像中获取表达向量。掩码平均池化可以排除背景噪声的影响,提取出与对象相关的特征。然后我们通过计算每个支持图像表达向量和查询图像在每个像素上的余弦相似度来得到引导映射。如果查询图像中对象像素对应的特征向量与支持图像中提出的表达向量接近,那么引导映射的对应得分较高。否则,如果像素属于背景,引导映射的分数就会很低。所生成的引导映射被用于向分割过程提供所需区域的引导信息。具体来说,查询图像的位置特征向量乘以相应的相似度值。该策略在支持图像及其掩码的引导下,能够有效激活查询图像的目标对象区域。

OSLSM和co-FCN通过改变网络的输入结构或输入图像的统计分布来合并支持图像的分割掩码。不同的是,我们使用掩码平均池化操作从中间特征映射中提取代表向量,而不是改变输入。我们的方法既不损害网络的输入结构,也不损害输入数据的统计。仅对目标区域进行平均可以避免背景的影响。否则,当背景像素占主导地位时,学习到的特征会偏向背景内容。其次,OSLSM和co-FCN直接将表达向量相乘到查询图像的特征图中,用于预测分割掩码。SG-One计算查询图像的每个像素处的表达向量和特征之间的相似度,利用相似度图指导分割分支寻找目标对象区域。
在这里插入图片描述

方法

掩码平均池化(Masked Average Pooling)

OSLSM建议通过增加二进制掩码来支持图像,从而从支持图像中删除背景像素。co-FCN提出将支持图像与正负掩码连接,构建五个通道的输入块。然而,这两种方法都有两个缺点。首先,将背景像素擦除为零将改变支持图像集的统计分布。如果使用统一的网络对查询图像和擦除后的支持图像进行处理,输入数据的方差将大大增加。其次,将支持图像与其掩码拼接在一起,破坏了网络的输入结构,也会阻碍统一网络的实现。

我们建议使用掩码平均池化来提取支持对象的代表向量。假设我们有一张RGB支持图像 I ∈ R 3 × w × h I\in R^{3\times w\times h} IR3×w×h和它的分割掩码 Y ∈ { 0 , 1 } w × h Y\in \left\{ 0,1 \right\} ^{w\times h} Y{0,1}w×h w w w h h h表示图像的宽和高。 I I I输出的特征图为 F ′ ∈ R c × w ′ × h ′ F^{\prime}\in R^{c\times w\prime\times h\prime} FRc×w×h c c c表示通道数, w ′ w\prime w h ′ h\prime h表示特征图的宽和高。通过双线性插值将特征图大小调整为和mask Y Y Y相同。重置大小的特征图表示为 F ∈ R c × w × h F\in R^{c\times w\times h} FRc×w×h,然后通过对第 i i i副特征图上目标区域的像素进行平均,计算出向量 v v v的第 i i i个元素 v i v_i vi
在这里插入图片描述
全卷积网络能够保持输入像素的相对位置。因此,我们希望通过掩码平均池化,在不考虑背景内容的情况下,提取出对象区域的特征。同时,我们认为在我们的方法中输入上下文区域有助于学习更好的对象特征。掩码平均池化保持了网络的输入结构不变,使我们能够在统一的网络中处理支持和查询图像。

相似性引导(Similarity Guidance)

掩码平均池化提取了参考对象的代表向量 ν = ( v 1 , v 2 , . . . , v c ) \nu =\left( v1,v2,...,v_c \right) ν=(v1,v2,...,vc),假定查询图像 I q u e I^{que} Ique的特征图为 F q u e ∈ R c × w ′ × h ′ F^{que}\in R^{c\times w\prime\times h\prime} FqueRc×w×h。我们使用余弦距离来度量 ν \nu ν F q u e F^{que} Fque中每个像素的相似度:
在这里插入图片描述
最终相似映射 S S S整合了查询对象和支持图像的特征。我们使用映射 S = { s x , y } S=\left\{s_{x,y}\right\} S={sx,y}作为指导,引导分割分支分割所需的对象区域。我们没有明确地优化余弦相似度。特别地,我们明智地将相似性引导图与来自分割分支的查询图像的特征图相乘。然后,我们对引导特征图进行优化,使其符合相应的ground-truth掩码。

相似性引导方法

在这里插入图片描述
SG-One主要包含三个部分,即1)stem;2)相似性指导;3)分割分支。stem其实就是个用来提取特征的全卷积网络(ResNet,VGG啥的)。将提取的查询图像和支持图像的特征输入相似指导分支。我们将参考对象的特征与查询图像的特征相结合,利用该分支生成相似引导图。对于支持图像的特征,我们实现了三个卷积块来提取高度抽象和语义的特征,然后是一个掩码平均池化来获取代表向量。提取的支持图像的代表向量应包含特定对象的高级语义特征。对于查询图像的特征,我们重用这三个块,并利用余弦相似层计算查询图像的每个像素处的代表向量和特征之间的相似度。分割分支用于在生成的相似图的指导下发现查询图像的目标对象区域。

实验结果

在这里插入图片描述
在这里插入图片描述
图3还展示了一些失败的案例,作者把失败归因于:1)目标对象和背景像素太相似,例如公共汽车和普通小汽车;2)目标区域具有相比于已从支持图像中发掘的判别信息十分不符的特征,例如狗身上的背心。
在这里插入图片描述
图5给出了不同类别的相似性同。参考支持对象,查询图像中同类别的对象将被高亮显示,干扰对象和背景则被压制。

SG-One在视频分割中的应用

视频分割和图像分割的关键区别在于支持和查询图像的信息源。例如,给定一个在草地上跳舞的女孩的视频剪辑,前景目标(女孩)和背景环境(草地)在不同的帧之间变化不是很严重。
在这里插入图片描述
图6说明了这两个任务之间的区别。在视频分割任务中,目标对象与背景环境在整个视频剪辑中保持一致。而在我们的图像分割任务中,支持图像和查询图像的对象和环境是完全不同的,背景信息和后续信息都不能应用。

我们将SG-One网络应用于DAVIS2016上的One-Shot视频分割任务。在视频分割中,我们尽量不使用背景相似性,也不使用帧间连续的对象线索来寻求公平的比较结果。我们的SG-One获得了最好的精度,达到了57.3%,超过了基线方法。该模型在仅引用一幅标注图像的情况下,具有更好的鲁棒性和分割给定查询图像的能力。

结论

我们提出SG-One可以有效地分割未见类别的语义像素在仅使用一个标注的样本情况下。我们提出了掩码平均池化的方法来提取更健壮的对象相关的代表性特征。大量的实验表明,掩蔽平均池化方法更方便,能够结合上下文信息来学习更好的代表性向量。我们通过使用一个统一的网络来减少模型参数以降低过拟合的风险。我们所提出的网络同样可以直接应用于多分类图像的分割。我们提出了一个纯端到端网络,它不需要任何预处理或后处理步骤。更重要的是,SG-One提高了One-Shot语义分割的性能,超越了基准方法。最后,我们分析了一次性视频分割和我们的一次性图像语义分割问题之间的关系。实验表明,在公平的比较条件下,所提出的SG-One算法在视频对象分割方面具有优越性。与此形成鲜明对比的是,一个One-Shot图像分割任务提供的图像的目标对象或背景并不连续。查询图像中的对象和背景与支持图像存在较大差异。例如,在我们的One-Shot图像分割任务中,我们可能需要分割一个站在草地上的老人和一个躺在床上的小女孩,因为他们都属于同一类别,即是一个人。其次,得益于视频中的序列线索,视频分割方法可以从连续的帧中计算出帧间的相似度,并通过在线更新提高性能。

  • 2
    点赞
  • 21
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
以下是基于匹配网络的One-shot样本分类的MATLAB代码示例: ``` % 加载数据集 load('omniglot.mat'); % 设定超参数 num_classes = 5; % 类别数量 num_samples = 1; % 每个类别的样本数量 num_epochs = 10; % 训练轮数 learning_rate = 0.001; % 学习率 % 初始化模型参数 input_shape = size(X_train{1}); filters = [32, 64, 128, 256]; kernel_sizes = [3, 3, 3, 3]; strides = [1, 2, 2, 2]; pool_sizes = [2, 2, 2, 2]; fc_size = 512; weights = init_weights(input_shape, filters, kernel_sizes, strides, pool_sizes, fc_size); % 训练模型 for epoch = 1:num_epochs for i = 1:size(X_train, 2) % 从训练集中选择一个类别 class_idx = randi(num_classes); class_samples = X_train{class_idx}; % 从该类别中选择两个样本 sample_idxs = randperm(size(class_samples, 2), num_samples+1); support_set = class_samples(:, sample_idxs(1:end-1)); query = class_samples(:, sample_idxs(end)); % 计算支持集和查询样本的嵌入向量 support_set_embed = forward_pass(support_set, weights); query_embed = forward_pass(query, weights); % 计算支持集和查询样本之间的相似度 similarities = compute_cosine_similarity(query_embed, support_set_embed); % 计算损失并进行反向传播 loss = compute_loss(similarities); gradients = backward_pass(similarities); weights = update_weights(weights, gradients, learning_rate); end fprintf('Epoch %d: Loss = %f\n', epoch, loss); end % 测试模型 num_correct = 0; for i = 1:size(X_test, 2) class_idx = randi(num_classes); class_samples = X_test{class_idx}; sample_idxs = randperm(size(class_samples, 2), num_samples+1); support_set = class_samples(:, sample_idxs(1:end-1)); query = class_samples(:, sample_idxs(end)); support_set_embed = forward_pass(support_set, weights); query_embed = forward_pass(query, weights); similarities = compute_cosine_similarity(query_embed, support_set_embed); if similarities(1) == max(similarities) num_correct = num_correct + 1; end end accuracy = num_correct / size(X_test, 2); fprintf('Accuracy: %f\n', accuracy); ``` 这里的`init_weights`函数用于初始化模型参数,`forward_pass`函数用于计算嵌入向量,`compute_cosine_similarity`函数用于计算相似度,`compute_loss`函数用于计算损失,`backward_pass`函数用于反向传播计算梯度,`update_weights`函数用于更新参数。在这个示例中,我们使用了Omniglot数据集进行训练和测试,其中每个类别只有一个样本。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值