SAM+RS:SAM-Assisted Remote Sensing Imagery Semantic Segmentation with Object and Boundary Constraint

论文

上图表明,c:samg - generated Object (SGO)和d:samg - generated Boundary (SGB)可以提供详细的对象和边界信息。为了充分利用它们的潜力,同时尽量减少对一般语义分割模型的修改,我们提出了一种新的损失函数,即对象一致性损失,并进一步引入边界保留损失来帮助模型训练。

这个方法不需要对语义分割模型、训练策略或伪标签生成进行特定设计,只需要添加两个新的损失函数。这是在不依赖语义信息的情况下实现的,主要关注两个关键的角度:对象和边界。通过直接利用SAM的原始输出来改进分割结果,而不需要额外的类提示。

所提出的框架的示意图如图3所示。图3 (a)展示了传统的语义分割方法。相比之下,如图3 (b)所示,我们的方法包含了一个使用SAM的额外阶段。具体来说,我们使用SAM直接创建SGO和SGB。这些输出分别参与计算目标一致性损失和边界保留损失。

首先利用SAM在网格提示设置的所有可能位置生成整个图像的分割掩码,将分割掩码视为对象,存储在一个列表中,我们设置一个阈值K=50来限制数量,还建立一个阈值S=50来限制单个对象可以包含的像素数量,有效地过滤掉非常小的分割掩码。由此可以得到一个SGO,未分割为对象的像素和边界赋值为0。SGO的数据组织如图4 (a)所示。同时,从SGO得到边界先验图。这个过程包括勾画出列表内每个对象的外部边界,并将这些边界合并成一个综合的边界先验图,即SGB。

1)对象一致性损失:对象一致性损失的目的是保持给定输入图像中对象之间的像素一致性。给定输入X,语义分割模型的输出记为p。为了计算对象一致性损失,我们遍历Yo中的所有对象。数据流如图5所示。对于每个目标,我们首先提取其掩码Mi,即Yo中像素值等于i的区域,然后通过以下方法获得目标特征:

⊙代表哈达玛积,目标特征F io表示基于第i个目标的面积过滤的模型预测(其实就是获取对应位置的特征,其余设置为0)。接下来,我们可以计算出目标的平均特征为

其中,G计算空间维度中所有像素的总和,并将其重塑为原始形状,Ni为第i个对象的点数。为了避免分母为零,要加一个额外的1。F i avg表示第i个对象中所有像素的期望平均值。因此,我们可以计算所有对象的对象一致性损失Lobj:

2)边界保留损失:

实操中,在总损失函数中加入分割的交叉熵损失+对象一致性损失+边界保留损失。

源码

class ObjectLoss(nn.Module):
    def __init__(self, max_object=50):
        super().__init__()
        self.max_object = max_object

    def forward(self, pred, gt, device):
        num_object = int(torch.max(gt)) + 1
        num_object = min(num_object, self.max_object)
        total_object_loss = 0

        for object_index in range(1, num_object):
            mask = torch.where(gt == object_index, 1, 0).unsqueeze(1).to(device)  # 当前object设置为1,便于统计数量
            num_point = mask.sum(2).sum(2).unsqueeze(2).unsqueeze(2).to(device)
            avg_pool = mask / (num_point + 1)  # +1防止0,mask中object的1变为1/num

            object_feature = pred.mul(avg_pool)  # 哈达玛积

            avg_feature = object_feature.sum(2).sum(2).unsqueeze(2).unsqueeze(2).repeat(1, 1, gt.shape[1], gt.shape[2])
            avg_feature = avg_feature.mul(mask)

            object_loss = torch.nn.functional.mse_loss(num_point * object_feature, avg_feature, reduction='mean')
            # 点num*输出特征图对应区域*1/num  and 对象所有feature
            total_object_loss = total_object_loss + object_loss

        return total_object_loss


class BoundaryLoss(nn.Module):  # todo
    def __init__(self, theta0=3, theta=5):
        super().__init__()

        self.theta0 = theta0
        self.theta = theta

    def forward(self, pred, gt):
        """
        Input:
            - pred: the output from model (before softmax)
                    shape (N, C, H, W)
            - gt: ground truth map
                    shape (N, H, w)
        Return:
            - boundary loss, averaged over mini-bathc
        """

        n, c, _, _ = pred.shape
        # softmax so that predicted map can be distributed in [0, 1]
        pred = torch.softmax(pred, dim=1)
        # one-hot vector of ground truth
        one_hot_gt = one_hot(gt, c)
        # boundary map
        gt_b = F.max_pool2d(
            1 - one_hot_gt, kernel_size=self.theta0, stride=1, padding=(self.theta0 - 1) // 2)
        gt_b -= 1 - one_hot_gt

        pred_b = F.max_pool2d(
            1 - pred, kernel_size=self.theta0, stride=1, padding=(self.theta0 - 1) // 2)
        pred_b -= 1 - pred

        # extended boundary map
        gt_b_ext = F.max_pool2d(
            gt_b, kernel_size=self.theta, stride=1, padding=(self.theta - 1) // 2)

        pred_b_ext = F.max_pool2d(
            pred_b, kernel_size=self.theta, stride=1, padding=(self.theta - 1) // 2)

        # reshape
        gt_b = gt_b.view(n, c, -1)
        pred_b = pred_b.view(n, c, -1)
        gt_b_ext = gt_b_ext.view(n, c, -1)
        pred_b_ext = pred_b_ext.view(n, c, -1)

        # Precision, Recall
        P = torch.sum(pred_b * gt_b_ext, dim=2) / (torch.sum(pred_b, dim=2) + 1e-7)
        R = torch.sum(pred_b_ext * gt_b, dim=2) / (torch.sum(gt_b, dim=2) + 1e-7)

        # Boundary F1 Score
        BF1 = 2 * P * R / (P + R + 1e-7)

        # summing BF1 Score for each class and average over mini-batch
        loss = torch.mean(1 - BF1)

        return loss

看论文源码,学习了一个比较好用的找边界的函数:

skimage.segmentation.find_boundaries(label_img, connectivity=1, mode='thick', background=0)

如果一个像素的任何相邻像素具有不同的标签,则该像素被视为边界像素。连通性控制哪些像素被认为是邻居。连接性为 1(默认)意味着共享边(2D)或面(3D)的像素将被视为邻居。 

mode:{‘thick’, ‘inner’, ‘outer’, ‘subpixel’}

thick:任何未完全被相同标签的像素包围的像素(定义为连通性) 被标记为边界。这会产生 2 个像素厚的边界。

inner:勾勒像素就在里面对象,保持背景像素不变。

outer:对象边界周围背景中的轮廓像素。当两个物体接触时,它们的边界也会被标记。

subpixel:返回一个加倍的图像,带有像素之间在适当的地方标记为边界的原始像素。

土地覆盖分类实验

以下是两幅数据展示,分别对应影像、GT、SGO、SGB。

  • 8
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值