SegFix: Refining Boundaries by Predicting Boundaries and Directions

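This article introduces SegFix, a method that turns the prediction of boundary directions into an eight-way classification task, uses the Sobel operator to derive the direction at each boundary pixel, and refines boundaries by computing offsets. It also includes the class methods used to encode and decode labels.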

Paper title: SegFix: Model-Agnostic Boundary Refinement for Segmentation
Paper: https://arxiv.org/pdf/2007.04269.pdf
Code: https://github.com/openseg-group/openseg.pytorch

Training is supervised by two losses: one for the boundary prediction and one for the direction prediction.
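
The two supervision signals can be sketched roughly as follows. This is only an illustration, not the repository's loss code: the function name `boundary_direction_loss`, the tensor shapes, the equal weighting of the two terms, and the choice to restrict the direction term to ground-truth boundary pixels are all assumptions made for this example.

```python
import torch
import torch.nn.functional as F

def boundary_direction_loss(boundary_logits, direction_logits,
                            boundary_gt, direction_gt):
    """
    boundary_logits:  N x 1 x H x W  (hypothetical boundary branch output)
    direction_logits: N x 8 x H x W  (hypothetical direction branch output)
    boundary_gt:      N x H x W, values in {0, 1}
    direction_gt:     N x H x W, long tensor with values in [0, 8)
    """
    # Binary cross-entropy on the boundary map
    boundary_loss = F.binary_cross_entropy_with_logits(
        boundary_logits.squeeze(1), boundary_gt.float())

    # Direction labels are only meaningful on boundary pixels, so every
    # other pixel is marked with the ignore index (assumption, see above)
    masked_dir_gt = direction_gt.clone()
    masked_dir_gt[boundary_gt == 0] = -1
    # Note: this term is NaN if the batch contains no boundary pixel at all
    direction_loss = F.cross_entropy(
        direction_logits, masked_dir_gt, ignore_index=-1)

    # Equal weighting is an assumption of this sketch
    return boundary_loss + direction_loss
```

Because non-boundary pixels are ignored, the direction branch is free to output anything there without being penalised.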

Quantizing the direction into eight classes turns the regression problem into a classification problem.
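
To make the regression-to-classification idea concrete, here is a minimal sketch of how a continuous direction vector could be binned into eight classes. The function name `direction_to_class` and the exact bin layout (eight 45-degree bins, with bin 0 centred on angle 0) are assumptions for illustration; the repository may partition the angles differently.

```python
import numpy as np

NUM_DIRECTIONS = 8  # assumption: eight evenly spaced 45-degree bins

def direction_to_class(dy, dx, num_classes=NUM_DIRECTIONS):
    """Map direction vectors (dy, dx) to discrete class indices."""
    angle = np.arctan2(dy, dx)          # radians in [-pi, pi]
    angle = np.mod(angle, 2 * np.pi)    # move to [0, 2*pi)
    bin_width = 2 * np.pi / num_classes
    # Adding half a bin before flooring centres bin 0 on angle 0;
    # the final modulo wraps angles just below 2*pi back into bin 0
    idx = np.floor((angle + bin_width / 2) / bin_width).astype(np.int64)
    return idx % num_classes
```

With this layout a vector pointing along +x falls in class 0, and each following class is rotated by another 45 degrees.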

Code walkthrough:

1. Use the Sobel operator to turn the direction of boundary pixels into a classification problem

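import numpy as np
import torch
from scipy.ndimage import distance_transform_edt, distance_transform_cdt

# labelmap, label_list, args, sobel_ker, ksize, depth_map and dir_map come
# from the surrounding script and are not shown in this excerpt.
for label_id in range(1, len(label_list) + 1):
    # Binary mask of the pixels that carry the current label
    labelmap_i = labelmap.copy()
    labelmap_i[labelmap_i != label_id] = 0
    labelmap_i[labelmap_i == label_id] = 1

    # Skip regions with fewer than 100 pixels
    if labelmap_i.sum() < 100:
        continue

    # Distance from each pixel inside the region to the nearest pixel outside it
    if args.metric == 'euc':
        depth_i = distance_transform_edt(labelmap_i)
    elif args.metric == 'taxicab':
        depth_i = distance_transform_cdt(labelmap_i, metric='taxicab')
    else:
        raise RuntimeError(f'unknown metric: {args.metric}')
    depth_map += depth_i

    # Sobel filtering of the distance map gives a per-pixel gradient vector
    # (h x w x C, one channel per Sobel kernel)
    dir_i = torch.nn.functional.conv2d(
        torch.from_numpy(depth_i).float().view(1, 1, *depth_i.shape),
        sobel_ker, padding=ksize // 2).squeeze().permute(1, 2, 0).numpy()

    # The following line is necessary: zero the gradients outside the current
    # region so they do not leak into the directions of neighbouring regions
    dir_i[(labelmap_i == 0), :] = 0

    dir_map += dir_i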
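
The excerpt above relies on a `sobel_ker` that is not shown. The sketch below builds one plausible version and runs the same distance-transform-plus-Sobel step on a toy mask. The 3x3 kernel size, the channel order (channel 0 for the vertical gradient, channel 1 for the horizontal one), and the toy mask are assumptions for illustration, so the repository's actual kernel may differ in size or scaling.

```python
import numpy as np
import torch
import torch.nn.functional as F
from scipy.ndimage import distance_transform_edt

# Assumed 3x3 Sobel pair, stacked as (out_channels=2, in_channels=1, 3, 3)
sobel_x = torch.tensor([[-1., 0., 1.],
                        [-2., 0., 2.],
                        [-1., 0., 1.]])
sobel_ker = torch.stack([sobel_x.t(), sobel_x]).unsqueeze(1)

# Toy region: a 5x5 square inside a 9x9 image
mask = np.zeros((9, 9), dtype=np.uint8)
mask[2:7, 2:7] = 1

# Distance of every inside pixel to the nearest outside pixel
depth = distance_transform_edt(mask)

# The gradient of the distance map points from the boundary toward the interior
grad = F.conv2d(torch.from_numpy(depth).float().view(1, 1, *depth.shape),
                sobel_ker, padding=1)
grad = grad.squeeze(0).permute(1, 2, 0).numpy()   # h x w x 2
grad[mask == 0, :] = 0                            # keep only in-region vectors

print(grad[4, 2])  # left edge of the square: points right (positive x)
print(grad[4, 6])  # right edge of the square: points left (negative x)
```

On the left edge the horizontal component is positive and on the right edge it is negative, so in both cases the vector points into the square, which is the direction a boundary pixel should be shifted along.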

2. Compute the offsets

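import numpy as np
import torch
import torch.nn.functional as F

def shift(x, offset):
    """
    x: h x w
    offset: 2 x h x w
    """
    h, w = x.shape
    x = torch.from_numpy(x).unsqueeze(0)
    offset = torch.from_numpy(offset).float().unsqueeze(0)
    # gen_coord_map is not shown in this excerpt; from its use here it returns
    # the per-pixel row and column coordinate maps
    coord_map = gen_coord_map(h, w)
    # Scale pixel coordinates to the [-1, 1] range expected by grid_sample
    norm_factor = torch.FloatTensor([(w-1)/2, (h-1)/2])
    # Sampling location = pixel coordinate + offset (channel 0: rows, channel 1: columns)
    grid_h = offset[:, 0]+coord_map[0]
    grid_w = offset[:, 1]+coord_map[1]
    # grid_sample expects (x, y) order, i.e. (column, row)
    grid = torch.stack([grid_w, grid_h], dim=-1) / norm_factor - 1
    # align_corners=True matches the (size - 1) / 2 normalisation above
    x = F.grid_sample(x.unsqueeze(1).float(), grid, padding_mode='border', mode='bilinear', align_corners=True).squeeze().numpy()
    x = np.round(x)
    return x.astype(np.uint8)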
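
For integer offsets, the effect of `shift` can be reproduced with plain NumPy indexing, which makes the refinement step easier to follow. This is a simplified sketch under stated assumptions, not the repository's code: `shift_nearest` is a hypothetical name, offsets are assumed to be integers in (row, column) order, and out-of-range reads are clamped to the image border to mimic `padding_mode='border'`.

```python
import numpy as np

def shift_nearest(x, offset):
    """
    x: h x w label map
    offset: 2 x h x w integer offsets, channel 0 = rows, channel 1 = columns
    Each output pixel takes the label found at (row + dy, col + dx).
    """
    h, w = x.shape
    offset = np.asarray(offset, dtype=np.int64)
    rows, cols = np.meshgrid(np.arange(h), np.arange(w), indexing="ij")
    # Clamp to the image so reads past the edge reuse the border pixel
    src_r = np.clip(rows + offset[0], 0, h - 1)
    src_c = np.clip(cols + offset[1], 0, w - 1)
    return x[src_r, src_c]

# Column 1 is mislabelled as class 2; an offset of -1 along the columns
# makes those pixels copy the label of their left neighbour
pred = np.array([[1, 2, 2, 2]] * 3, dtype=np.uint8)
offset = np.zeros((2, 3, 4), dtype=np.int64)
offset[1, :, 1] = -1
refined = shift_nearest(pred, offset)
```

In this toy case `refined` has class 1 in the first two columns: the boundary pixels were relabelled with the prediction of the interior pixel they point to.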

3. Remap the labels

import numpy as np

class LabelTransformer:

    # Cityscapes label IDs of the 19 evaluated classes, in train-ID order
    label_list = [7, 8, 11, 12, 13, 17, 19, 20,
                  21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33]

    @staticmethod
    def encode(labelmap):
        """Map raw label IDs to contiguous train IDs 0..18; everything else becomes 255."""
        labelmap = np.array(labelmap)

        shape = labelmap.shape
        # np.int was removed in NumPy 1.24, so use an explicit integer type
        encoded_labelmap = np.ones(
            shape=(shape[0], shape[1]), dtype=np.int64) * 255
        for i in range(len(LabelTransformer.label_list)):
            class_id = LabelTransformer.label_list[i]
            encoded_labelmap[labelmap == class_id] = i

        return encoded_labelmap

    @staticmethod
    def decode(labelmap):
        """Map train IDs 0..18 back to raw label IDs; everything else becomes 255."""
        labelmap = np.array(labelmap)

        shape = labelmap.shape
        decoded_labelmap = np.ones(
            shape=(shape[0], shape[1]), dtype=np.uint8) * 255
        for i in range(len(LabelTransformer.label_list)):
            class_id = i
            decoded_labelmap[labelmap ==
                             class_id] = LabelTransformer.label_list[i]

        return decoded_labelmap
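
The same mapping can also be expressed with lookup tables, which avoids the Python loop over classes. The following is an alternative sketch rather than the repository's implementation: the names `ENCODE_LUT`, `DECODE_LUT`, `encode_lut` and `decode_lut` are made up for this example, it assumes every label value lies in [0, 255], and it returns `uint8` arrays in both directions.

```python
import numpy as np

LABEL_LIST = [7, 8, 11, 12, 13, 17, 19, 20,
              21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33]
IGNORE = 255

# raw label ID -> train ID (255 for every ID not in LABEL_LIST)
ENCODE_LUT = np.full(256, IGNORE, dtype=np.uint8)
ENCODE_LUT[LABEL_LIST] = np.arange(len(LABEL_LIST), dtype=np.uint8)

# train ID -> raw label ID (255 for every train ID outside 0..18)
DECODE_LUT = np.full(256, IGNORE, dtype=np.uint8)
DECODE_LUT[:len(LABEL_LIST)] = LABEL_LIST

def encode_lut(labelmap):
    # Fancy indexing applies the table to every pixel at once
    return ENCODE_LUT[np.asarray(labelmap, dtype=np.uint8)]

def decode_lut(labelmap):
    return DECODE_LUT[np.asarray(labelmap, dtype=np.uint8)]

raw = np.array([[7, 0, 33],
                [8, 26, 5]])
enc = encode_lut(raw)
dec = decode_lut(enc)
```

Here the raw IDs 0 and 5 are not in the list, so they encode to 255 and stay 255 after decoding, while the listed IDs survive the round trip unchanged.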