论文
上图表明,c:samg - generated Object (SGO)和d:samg - generated Boundary (SGB)可以提供详细的对象和边界信息。为了充分利用它们的潜力,同时尽量减少对一般语义分割模型的修改,我们提出了一种新的损失函数,即对象一致性损失,并进一步引入边界保留损失来帮助模型训练。
这个方法不需要对语义分割模型、训练策略或伪标签生成进行特定设计,只需要添加两个新的损失函数。这是在不依赖语义信息的情况下实现的,主要关注两个关键的角度:对象和边界。通过直接利用SAM的原始输出来改进分割结果,而不需要额外的类提示。
所提出的框架的示意图如图3所示。图3 (a)展示了传统的语义分割方法。相比之下,如图3 (b)所示,我们的方法包含了一个使用SAM的额外阶段。具体来说,我们使用SAM直接创建SGO和SGB。这些输出分别参与计算目标一致性损失和边界保留损失。
首先利用SAM在网格提示设置的所有可能位置生成整个图像的分割掩码,将分割掩码视为对象,存储在一个列表中,我们设置一个阈值K=50来限制数量,还建立一个阈值S=50来限制单个对象可以包含的像素数量,有效地过滤掉非常小的分割掩码。由此可以得到一个SGO,未分割为对象的像素和边界赋值为0。SGO的数据组织如图4 (a)所示。同时,从SGO得到边界先验图。这个过程包括勾画出列表内每个对象的外部边界,并将这些边界合并成一个综合的边界先验图,即SGB。
1)对象一致性损失:对象一致性损失的目的是保持给定输入图像中对象之间的像素一致性。给定输入X,语义分割模型的输出记为p。为了计算对象一致性损失,我们遍历Yo中的所有对象。数据流如图5所示。对于每个目标,我们首先提取其掩码Mi,即Yo中像素值等于i的区域,然后通过以下方法获得目标特征:
⊙代表哈达玛积,目标特征F io表示基于第i个目标的面积过滤的模型预测(其实就是获取对应位置的特征,其余设置为0)。接下来,我们可以计算出目标的平均特征为
其中,G计算空间维度中所有像素的总和,并将其重塑为原始形状,Ni为第i个对象的点数。为了避免分母为零,要加一个额外的1。F i avg表示第i个对象中所有像素的期望平均值。因此,我们可以计算所有对象的对象一致性损失Lobj:
2)边界保留损失:
实操中,在总损失函数中加入分割的交叉熵损失+对象一致性损失+边界保留损失。
源码
class ObjectLoss(nn.Module):
def __init__(self, max_object=50):
super().__init__()
self.max_object = max_object
def forward(self, pred, gt, device):
num_object = int(torch.max(gt)) + 1
num_object = min(num_object, self.max_object)
total_object_loss = 0
for object_index in range(1, num_object):
mask = torch.where(gt == object_index, 1, 0).unsqueeze(1).to(device) # 当前object设置为1,便于统计数量
num_point = mask.sum(2).sum(2).unsqueeze(2).unsqueeze(2).to(device)
avg_pool = mask / (num_point + 1) # +1防止0,mask中object的1变为1/num
object_feature = pred.mul(avg_pool) # 哈达玛积
avg_feature = object_feature.sum(2).sum(2).unsqueeze(2).unsqueeze(2).repeat(1, 1, gt.shape[1], gt.shape[2])
avg_feature = avg_feature.mul(mask)
object_loss = torch.nn.functional.mse_loss(num_point * object_feature, avg_feature, reduction='mean')
# 点num*输出特征图对应区域*1/num and 对象所有feature
total_object_loss = total_object_loss + object_loss
return total_object_loss
class BoundaryLoss(nn.Module): # todo
def __init__(self, theta0=3, theta=5):
super().__init__()
self.theta0 = theta0
self.theta = theta
def forward(self, pred, gt):
"""
Input:
- pred: the output from model (before softmax)
shape (N, C, H, W)
- gt: ground truth map
shape (N, H, w)
Return:
- boundary loss, averaged over mini-bathc
"""
n, c, _, _ = pred.shape
# softmax so that predicted map can be distributed in [0, 1]
pred = torch.softmax(pred, dim=1)
# one-hot vector of ground truth
one_hot_gt = one_hot(gt, c)
# boundary map
gt_b = F.max_pool2d(
1 - one_hot_gt, kernel_size=self.theta0, stride=1, padding=(self.theta0 - 1) // 2)
gt_b -= 1 - one_hot_gt
pred_b = F.max_pool2d(
1 - pred, kernel_size=self.theta0, stride=1, padding=(self.theta0 - 1) // 2)
pred_b -= 1 - pred
# extended boundary map
gt_b_ext = F.max_pool2d(
gt_b, kernel_size=self.theta, stride=1, padding=(self.theta - 1) // 2)
pred_b_ext = F.max_pool2d(
pred_b, kernel_size=self.theta, stride=1, padding=(self.theta - 1) // 2)
# reshape
gt_b = gt_b.view(n, c, -1)
pred_b = pred_b.view(n, c, -1)
gt_b_ext = gt_b_ext.view(n, c, -1)
pred_b_ext = pred_b_ext.view(n, c, -1)
# Precision, Recall
P = torch.sum(pred_b * gt_b_ext, dim=2) / (torch.sum(pred_b, dim=2) + 1e-7)
R = torch.sum(pred_b_ext * gt_b, dim=2) / (torch.sum(gt_b, dim=2) + 1e-7)
# Boundary F1 Score
BF1 = 2 * P * R / (P + R + 1e-7)
# summing BF1 Score for each class and average over mini-batch
loss = torch.mean(1 - BF1)
return loss
看论文源码,学习了一个比较好用的找边界的函数:
skimage.segmentation.find_boundaries(label_img, connectivity=1, mode='thick', background=0)
如果一个像素的任何相邻像素具有不同的标签,则该像素被视为边界像素。连通性控制哪些像素被认为是邻居。连接性为 1(默认)意味着共享边(2D)或面(3D)的像素将被视为邻居。
mode:{‘thick’, ‘inner’, ‘outer’, ‘subpixel’}
thick:任何未完全被相同标签的像素包围的像素(定义为连通性) 被标记为边界。这会产生 2 个像素厚的边界。
inner:勾勒像素就在里面对象,保持背景像素不变。
outer:对象边界周围背景中的轮廓像素。当两个物体接触时,它们的边界也会被标记。
subpixel:返回一个加倍的图像,带有像素之间在适当的地方标记为边界的原始像素。
土地覆盖分类实验
以下是两幅数据展示,分别对应影像、GT、SGO、SGB。