python 分割任务的独热编码

"""
@file_name: one_hot.py
@date: 2020/6/16
"""
import cv2
import numpy as np


def encode_onehot(label, color_list):
    """
        Convert a segmentation image label array to one-hot format
        by replacing each pixel value with a vector of length num_classes

        # Arguments
            label: The 2D array segmentation image label
            label_values

        # Returns
            A 2D array with the same width and hieght as the input, but
            with a depth size of num_classes
        """
    # https://stackoverflow.com/questions/46903885/map-rgb-semantic-maps-to-one-hot-encodings-and-vice-versa-in-tensorflow
    # https://stackoverflow.com/questions/14859458/how-to-check-if-all-values-in-the-columns-of-a-numpy-matrix-are-the-same
    semantic_map = []
    for color in color_list:  # rgb颜色列表, 不同的color对应不同的channel
        equality = np.equal(label, color)  # ->[h,w,3]
        class_map = np.all(equality, axis=-1).astype(int)  # 三个通道都为true,结果才是true。[h,w,3] -> [h,w]
        semantic_map.append(class_map)  # 当前channel保存当前color类别
    semantic_map = np.stack(semantic_map, axis=-1)  # [h,w,c]
    return semantic_map


class EncodeMaskToOneHot(object):
    def __init__(self, num_class=1, color_values=None):
        self.num_class = num_class
        self.color_values = color_values

    def __call__(self, semantic_map):
        if self.num_class > 1:
            semantic_map = encode_onehot(semantic_map, self.color_values)
        else:
            gray_map = cv2.cvtColor(semantic_map, cv2.COLOR_BGR2GRAY)
            semantic_map = np.where(gray_map > 0, 1.0, 0.0)[..., None]
        return semantic_map


if __name__ == '__main__':
    mask_path = './17_mask.png'
    mask = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED)
    # 单类别
    encoder = EncodeMaskToOneHot(num_class=1, color_values=None)
    encoded_mask = encoder(mask)

    # 多类别
    _color_bgr_list = [[0, 0, 0], [255, 156, 8], [100, 0, 255], [100, 0, 255], [255, 255, 0], [0, 100, 255]]
    encoder = EncodeMaskToOneHot(num_class=5, color_values=_color_bgr_list)
    encoded_mask = encoder(mask)

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Mr.Q

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值