- mmsegmentation保存预测的结果,不覆盖原图
首先找到源码中对应的文件 base.py
对应的代码片段如下:
def show_result(self,
img,
result,
palette=None,
win_name='',
show=False,
wait_time=0,
out_file=None,
opacity=0.5):
"""Draw `result` over `img`.
Args:
img (str or Tensor): The image to be displayed.
result (Tensor): The semantic segmentation results to draw over
`img`.
palette (list[list[int]]] | np.ndarray | None): The palette of
segmentation map. If None is given, random palette will be
generated. Default: None
win_name (str): The window name.
wait_time (int): Value of waitKey param.
Default: 0.
show (bool): Whether to show the image.
Default: False.
out_file (str or None): The filename to write the image.
Default: None.
opacity(float): Opacity of painted segmentation map.
Default 0.5.
Must be in (0, 1] range.
Returns:
img (Tensor): Only if not `show` or `out_file`
"""
img = mmcv.imread(img)
img = img.copy()
seg = result[0]
if palette is None:
if self.PALETTE is None:
# Get random state before set seed,
# and restore random state later.
# It will prevent loss of randomness, as the palette
# may be different in each iteration if not specified.
# See: https://github.com/open-mmlab/mmdetection/issues/5844
state = np.random.get_state()
np.random.seed(42)
# random palette
palette = np.random.randint(
0, 255, size=(len(self.CLASSES), 3))
np.random.set_state(state)
else:
palette = self.PALETTE
palette = np.array(palette)
assert palette.shape[0] == len(self.CLASSES)
assert palette.shape[1] == 3
assert len(palette.shape) == 2
assert 0 < opacity <= 1.0
color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
for label, color in enumerate(palette):
color_seg[seg == label, :] = color
# convert to BGR
color_seg = color_seg[..., ::-1]
img = img * (1 - opacity) + color_seg * opacity
img = img.astype(np.uint8)
# if out_file specified, do not show image in window
if out_file is not None:
show = False
if show:
mmcv.imshow(img, win_name, wait_time)
if out_file is not None:
mmcv.imwrite(img, out_file)
# 关于mmsegmentation保存预测结果 https://zhuanlan.zhihu.com/p/380178024
# mmsegmentation默认将预测得到的mask覆盖在原始图片上进行显示或保存,为了直接输出灰度图,需要对源码进行修改
# 以二分类为例,可修改为
# if out_file is not None:
# seg = np.array(seg)
# seg[seg > 0] = 255
# mmcv.imwrite(seg, out_file)
# 对于二分类可如文末形式进行修改;对于多分类,可在文中标记修改处修改为mmcv.imwrite(color_seg, out_file),此时将保存通过调色板转换后的彩色图,若想保存灰度图,则在修改处修改为mmcv.imwrite(seg, out_file)即可
if not (show or out_file):
warnings.warn('show==False and out_file is not specified, only '
'result image will be returned')
return img
- python的函数装饰器
看了很多次 今天终于懂一点点了,记住了一些帮助理解的关键句
参考链接:http://c.biancheng.net/view/2270.html