"""
@file_name: one_hot.py
@date: 2020/6/16
"""
import cv2
import numpy as np
def encode_onehot(label, color_list):
    """
    Convert a segmentation label image to one-hot format.

    Every pixel is replaced by a vector of length ``len(color_list)`` whose
    k-th entry is 1 when the pixel equals ``color_list[k]`` and 0 otherwise.

    # Arguments
        label: The segmentation label image. Either a colour image of shape
            [h, w, ch] whose pixels are compared against per-channel colour
            vectors (e.g. BGR triples), or a single-channel image of shape
            [h, w] whose pixels are compared against scalar class values.
        color_list: Sequence of colours / class values; entry k defines
            output channel k. Entries should be distinct, otherwise a pixel
            is marked as hot in several channels.

    # Returns
        An int array with the same height and width as ``label`` and a depth
        of ``len(color_list)``, i.e. shape [h, w, len(color_list)].

    # Raises
        ValueError: if ``color_list`` is empty.
    """
    # https://stackoverflow.com/questions/46903885/map-rgb-semantic-maps-to-one-hot-encodings-and-vice-versa-in-tensorflow
    # https://stackoverflow.com/questions/14859458/how-to-check-if-all-values-in-the-columns-of-a-numpy-matrix-are-the-same
    label = np.asarray(label)
    colors = list(color_list)
    if not colors:
        # np.stack would fail later with an opaque message; fail early instead.
        raise ValueError("color_list must contain at least one color")

    semantic_map = []
    for color in colors:  # each colour gets its own output channel
        equality = np.equal(label, color)
        if label.ndim == 2:
            # Single-channel label: there is no channel axis to reduce, so the
            # element-wise comparison already is the [h, w] class map.
            class_map = equality
        else:
            # Colour label: a pixel belongs to this class only if every
            # channel matches. [h, w, ch] -> [h, w]
            class_map = np.all(equality, axis=-1)
        semantic_map.append(class_map.astype(int))
    return np.stack(semantic_map, axis=-1)  # [h, w, c]
class EncodeMaskToOneHot(object):
    """
    Callable transform that turns a segmentation mask into a one-hot map.

    With ``num_class > 1`` every colour in ``color_values`` becomes its own
    output channel (via ``encode_onehot``). With ``num_class <= 1`` the mask
    is binarised: any pixel whose grayscale value is > 0 becomes 1.0, and
    the result gets a trailing channel axis of size 1.
    """

    def __init__(self, num_class=1, color_values=None):
        # Number of classes; only used to choose between the binary and the
        # multi-colour encoding path.
        self.num_class = num_class
        # Sequence of per-class colours (e.g. BGR triples); required when
        # num_class > 1.
        self.color_values = color_values

    def __call__(self, semantic_map):
        """
        Encode ``semantic_map``.

        # Arguments
            semantic_map: Mask image, [h, w, 3] (BGR), [h, w, 4] (BGRA),
                [h, w, 1] or [h, w] (already single-channel).

        # Returns
            Multi-class: int array of shape [h, w, len(color_values)].
            Binary: float array of shape [h, w, 1] with values 0.0 / 1.0.

        # Raises
            TypeError: if num_class > 1 but no color_values were given.
        """
        if self.num_class > 1:
            if self.color_values is None:
                raise TypeError("color_values must be provided when num_class > 1")
            return encode_onehot(semantic_map, self.color_values)

        semantic_map = np.asarray(semantic_map)
        if semantic_map.ndim == 3 and semantic_map.shape[-1] == 1:
            # Drop a singleton channel axis so the mask is treated as grayscale.
            semantic_map = semantic_map[..., 0]

        if semantic_map.ndim == 2:
            # Already single-channel: cvtColor(BGR2GRAY) would reject it.
            gray_map = semantic_map
        elif semantic_map.shape[-1] == 4:
            # BGRA input (e.g. PNG loaded with IMREAD_UNCHANGED).
            gray_map = cv2.cvtColor(semantic_map, cv2.COLOR_BGRA2GRAY)
        else:
            gray_map = cv2.cvtColor(semantic_map, cv2.COLOR_BGR2GRAY)
        # Non-zero grayscale -> foreground (1.0); add the channel axis: [h, w, 1].
        return np.where(gray_map > 0, 1.0, 0.0)[..., None]
if __name__ == '__main__':
    mask_path = './17_mask.png'
    # IMREAD_COLOR always yields a 3-channel BGR image, which is what the
    # BGR colour list below and the BGR2GRAY conversion expect
    # (IMREAD_UNCHANGED could return 1- or 4-channel arrays).
    mask = cv2.imread(mask_path, cv2.IMREAD_COLOR)
    if mask is None:
        # cv2.imread signals a missing/unreadable file by returning None.
        raise FileNotFoundError('could not read mask image: %s' % mask_path)

    # Single class (binary foreground/background)
    encoder = EncodeMaskToOneHot(num_class=1, color_values=None)
    encoded_mask = encoder(mask)

    # Multiple classes. Colours must be distinct, otherwise a pixel would be
    # hot in several channels; a duplicated [100, 0, 255] entry was removed so
    # the list matches the intended 5 classes.
    _color_bgr_list = [[0, 0, 0], [255, 156, 8], [100, 0, 255], [255, 255, 0], [0, 100, 255]]
    encoder = EncodeMaskToOneHot(num_class=len(_color_bgr_list), color_values=_color_bgr_list)
    encoded_mask = encoder(mask)
# Source article: "python 分割任务的独热编码" (One-hot encoding for segmentation tasks in Python)
# First published 2020-11-23 10:11:28