论文标题:SegFix: Model-Agnostic Boundary Refinement for Segmentation
论文地址:https://arxiv.org/pdf/2007.04269.pdf
代码地址:https://github.com/openseg-group/openseg.pytorch
![](https://img-blog.csdnimg.cn/direct/c175563d2f2745558937b89d47e76dd8.png)
使用两种 loss 进行监督：边界分支和方向分支各对应一个 loss。
将连续的方向离散为八个类别，把方向回归问题转化为分类问题。
核心代码解析：
1、对每个类别的掩码计算距离变换，再用 Sobel 算子对距离图求梯度，得到每个像素的方向（之后再把方向离散为类别）
# For every class id 1..len(label_list): build a binary mask, compute its
# distance transform, and derive a per-pixel direction field from the Sobel
# gradient of that transform. Results are accumulated into `depth_map` and
# `dir_map`. Those two, plus `labelmap`, `label_list`, `args`, `sobel_ker`
# and `ksize`, are defined by the surrounding code (not shown here).
for label_id in range(1, len(label_list) + 1):
    # Binary mask of the current class: 1 inside, 0 elsewhere (same dtype as labelmap).
    labelmap_i = (labelmap == label_id).astype(labelmap.dtype)
    # Skip classes covering fewer than 100 pixels.
    if labelmap_i.sum() < 100:
        continue

    # Distance from each foreground pixel to the nearest background pixel.
    if args.metric == 'euc':
        depth_i = distance_transform_edt(labelmap_i)
    elif args.metric == 'taxicab':
        depth_i = distance_transform_cdt(labelmap_i, metric='taxicab')
    else:
        raise RuntimeError(f"Unsupported distance metric: {args.metric!r}")
    depth_map += depth_i

    # Gradient of the distance map via convolution with `sobel_ker`.
    # NOTE(review): sobel_ker is assumed to be (2, 1, ksize, ksize), giving an
    # (h, w, 2) direction field after the permute; confirm against its definition.
    # Index the batch dim with [0] rather than squeeze() so a 1-pixel-high or
    # 1-pixel-wide map keeps its spatial axes.
    depth_t = torch.from_numpy(depth_i).float().view(1, 1, *depth_i.shape)
    dir_i = torch.nn.functional.conv2d(
        depth_t, sobel_ker, padding=ksize // 2
    )[0].permute(1, 2, 0).numpy()

    # Necessary: the kernel spans neighbouring pixels, so background pixels
    # bordering this mask also receive non-zero gradients. Zero them so they
    # are not added into dir_map on top of other classes' directions.
    dir_i[(labelmap_i == 0), :] = 0
    dir_map += dir_i
2、按偏移量平移标签图：用 `grid_sample` 根据每个像素的偏移对标签重新采样
def shift(x, offset):
    """Resample a 2-D map by a per-pixel offset field.

    Each output pixel takes the bilinearly interpolated value of ``x`` at its
    own coordinate plus the offset. Sampling outside the image is clamped to
    the border. The result is rounded and cast to uint8.

    Args:
        x: (h, w) numpy array, e.g. a label map.
        offset: (2, h, w) numpy array. ``offset[0]`` is added to
            ``gen_coord_map(h, w)[0]`` and ``offset[1]`` to
            ``gen_coord_map(h, w)[1]``. Presumably these are the row and
            column coordinates (gen_coord_map is defined elsewhere; confirm).

    Returns:
        (h, w) numpy array of dtype uint8.
    """
    h, w = x.shape
    x = torch.from_numpy(x).unsqueeze(0)            # (1, h, w)
    offset = torch.from_numpy(offset).unsqueeze(0)  # (1, 2, h, w)
    coord_map = gen_coord_map(h, w)
    # Maps pixel index 0 -> -1 and (size - 1) -> +1. This is the pixel-centre
    # normalisation that grid_sample calls align_corners=True.
    norm_factor = torch.FloatTensor([(w - 1) / 2, (h - 1) / 2])
    grid_h = offset[:, 0] + coord_map[0]
    grid_w = offset[:, 1] + coord_map[1]
    # grid_sample expects the last dim in (x, y) = (column, row) order, and a
    # grid of the same dtype as the (float32) input.
    grid = (torch.stack([grid_w, grid_h], dim=-1) / norm_factor - 1).float()
    # align_corners=True must be explicit: the default became False in
    # PyTorch 1.3, which would shift every sample by half a pixel relative to
    # the normalisation above. [0, 0] (not squeeze) keeps (h, w) even when
    # h or w is 1.
    x = F.grid_sample(
        x.unsqueeze(1).float(), grid,
        padding_mode='border', mode='bilinear', align_corners=True,
    )[0, 0].numpy()
    x = np.round(x)
    return x.astype(np.uint8)
3、标签编码/解码：在原始标签 id 与连续的训练 id 之间相互转换（不在列表中的 id 记为 255）
class LabelTransformer:
    """Convert between raw label ids and contiguous training ids.

    ``label_list[i]`` is the raw id corresponding to training id ``i``.
    Any value not covered by the mapping becomes 255.
    """

    # Raw ids kept for training, in training-id order (training id == index).
    label_list = [7, 8, 11, 12, 13, 17, 19, 20,
                  21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33]

    @staticmethod
    def encode(labelmap):
        """Map raw label ids to training ids.

        Args:
            labelmap: array-like of raw label ids (converted with np.array).

        Returns:
            Integer array of the same shape. Pixels whose id is not in
            ``label_list`` are 255.
        """
        labelmap = np.array(labelmap)
        # dtype=int: `np.int` was only an alias of the builtin and was
        # removed in NumPy 1.24 (AttributeError on current NumPy).
        encoded_labelmap = np.full(labelmap.shape, 255, dtype=int)
        for train_id, class_id in enumerate(LabelTransformer.label_list):
            encoded_labelmap[labelmap == class_id] = train_id
        return encoded_labelmap

    @staticmethod
    def decode(labelmap):
        """Map training ids back to raw label ids.

        Args:
            labelmap: array-like of training ids (converted with np.array).

        Returns:
            uint8 array of the same shape. Pixels whose value is not a
            training id in ``0 .. len(label_list) - 1`` are 255.
        """
        labelmap = np.array(labelmap)
        decoded_labelmap = np.full(labelmap.shape, 255, dtype=np.uint8)
        for train_id, class_id in enumerate(LabelTransformer.label_list):
            decoded_labelmap[labelmap == train_id] = class_id
        return decoded_labelmap