论文:arXiv
code:
detectron2
SegmenTron
关键词:语义分割,实例分割,边缘细化,迭代细分
1. 核心思想
-
上采样困难点修补。因为即使是双线性上采样,还是对物体的边缘有损伤,因为边缘细节部分的像素点,相对于平滑区域少很多,这样在上采样之后就容易出现物体边缘不够好的问题。论文提出了困难点,在图像平面中灵活,自适应的选择点(points)来预测分割标签。直观上,这些点应该更密集地位于高频区域附近,如对象边界,类似于光线追踪中的抗锯齿问题(anti-aliasing,也译为边缘柔化、消除混叠等),用这些points来替换上采样之后的结果,这样对上采样有提升
-
上采样修补。这篇论文讲了一个很好听的故事,即:把语义分割以及实例分割问题(统称图像分割问题)当做一个渲染问题来解决。故事虽然这么讲,但本质上这篇论文其实是一个新型上采样方法,针对物体边缘的图像分割进行优化,使其在难以分割的物体边缘部分有更好的表现。
-
cascade思想。直接上采样x4,x8倍,有锯齿效果,边缘不好,那我一次上采样x2,而且每次用bilinear上采样之后,我还修补一下
2. PointRend训练
前面也说了,PointRend本质上是一种对上采样的修补,或者说一种新型上采样方法。其训练和测试的流程是不一样的,这里分开来讲。
是一个通用模块, 可以合并到SOTA backbone 中,如 MaskR-CNN, Deeplabv3的后端再做进一步优化的。相当于一个deepsup的分支。
这里以语义分割deeplabv3+,cityscapes数据集为实例。
总结起来,PointRend就是对困难点的语义分割结果进行监督学习
2.1 整体流程
- deeplabv3+得到 粗糙的语义结果 coarse out(shape:b×19×h×w)
c1, _, _, c4 = self.backbone.encoder(x)
out = self.backbone.head(c4, c1)
- PointRend关键点采样与语义标签预测
a. 从out中随机选择N×N×K×beta个困难采样点,加上N×N×K×(1-beta)个随机采样点,得到采样点。
N = x.shape[-1] // 16
points = sampling_points(out, N * N, self.k, self.beta)
这里困难点的界定还是有文章可以做的。原文是:依据分类结果最大置信度和第二大置信度之间差值来选择的。
b. 根据采样点对:粗糙的语义结果coarse out, 和第二浅层特征res2 进行采样,得到困难点的特征表示
coarse = point_sample(out, points)
fine = point_sample(res2, points)
feature_representation = torch.cat([coarse, fine], dim=1)
c. 对困难点用MLP做语义分类
rend = self.mlp(feature_representation)
d. 根据采样点,对gt采样,得到 困难点语义的 label
gt_points = point_sample(
gt.float().unsqueeze(1),
result["points"],
mode="nearest"
).squeeze_(1).long()
points_loss = F.cross_entropy(result["rend"], gt_points, ignore_index=self.ignore_index)
总结起来就是:deeplabv3+语义分割学习 + 困难点语义的辅助监督
3. PointRend测试
所以,测试的时候,才是核心思想的体现:upsample修补,Cascade级联
3.1 整体流程:
while out.shape < target.shape:
- 对deeplabv3+的粗糙语义结果上采样2倍,得到 out2
out = F.interpolate(out, scale_factor=2, mode="bilinear", align_corners=False)
- 从 out2 中得到 8096 个困难采样点
H_step, W_step = 1 / H, 1 / W
N = min(H * W, N)
uncertainty_map = -1 * (mask[:, 0] - mask[:, 1])
_, idx = uncertainty_map.view(B, -1).topk(N, dim=1)
points = torch.zeros(B, N, 2, dtype=torch.float, device=device)
points[:, :, 0] = W_step / 2.0 + (idx % W).to(torch.float) * W_step
points[:, :, 1] = H_step / 2.0 + (idx // W).to(torch.float) * H_step
return idx, points
- 根据 8096 个采样点, 从 out2 和 res2 采样,得到困难点的特征表示,然后再做语义分类
coarse = point_sample(out, points)
fine = point_sample(res2, points)
feature_representation = torch.cat([coarse, fine], dim=1)
rend = self.mlp(feature_representation)
- 根据 8096 个采样点 和 采样点的语义结果,对 out2 中的结果进行替换(上采样修补)
B, C, H, W = out.shape
points_idx = points_idx.unsqueeze(1).expand(-1, C, -1)
out = (out.reshape(B, C, -1)
.scatter_(2, points_idx, rend)
.view(B, C, H, W))
后面一直循环,直到 out.shape >= target.shape
这里嗅出了2020的CascadePSP,SFNET的味道