1. pytorch
参考:https://www.pythonf.cn/read/93750
背景:做多分类分割时,PyTorch 的 CrossEntropyLoss 直接接收类别索引形式的 mask(相当于内部完成了 onehot 编码),无需对 mask 进行 onehot 处理;但若将任务转为逐类别的二分类、loss 改为 BCE loss(如 BCEWithLogitsLoss)时,target 必须与网络输出形状一致,需要先手动对 mask 进行 onehot 编码,再传入 loss。
# 第一种方法:已亲测有效
def mask2onehot(mask, num_classes):
    """
    Convert a class-index segmentation mask to a channel-first one-hot array.

    A mask of shape (H, W) becomes an array of shape (K, H, W), where
    K = num_classes and the FIRST axis holds the one-hot encoding: channel i
    is 1 exactly where ``mask == i``. Masks of any shape are accepted; the
    class axis is always prepended.

    Args:
        mask: Integer class-index mask (ndarray or nested sequence).
        num_classes: Number of classes K; valid class ids are 0 .. K-1.

    Returns:
        np.ndarray of dtype uint8 and shape (K, *mask.shape). Pixels whose
        value lies outside 0 .. K-1 get an all-zero vector.
    """
    # asarray so that plain nested lists are compared element-wise instead of
    # with Python's list == int (which is always a single False).
    mask = np.asarray(mask)
    # Shape (K, 1, ..., 1): broadcasting against the mask yields one boolean
    # plane per class along a new leading axis, even when K == 0.
    class_ids = np.arange(num_classes).reshape((-1,) + (1,) * mask.ndim)
    return (mask == class_ids).astype(np.uint8)
def onehot2mask(mask):
    """
    Collapse a channel-first one-hot array back to a class-index mask.

    Args:
        mask: Array of shape (K, H, W); axis 0 holds per-class values.

    Returns:
        Class-index mask of shape (H, W) as uint8, taking for each pixel the
        channel with the largest value (first one on ties). Because of the
        uint8 cast, class ids >= 256 wrap around.
    """
    class_axis = 0
    winners = np.argmax(mask, axis=class_axis)
    return winners.astype(np.uint8)
# 第二种方法
def mask_to_onehot(mask, palette):
    """
    Convert a colour-coded mask (H, W, C) to a channel-last one-hot array (H, W, K).

    Channel k of the result is 1.0 wherever every one of the C channels of a
    pixel equals ``palette[k]``; C is usually 1 or 3 and K == len(palette).

    Args:
        mask: Array of shape (H, W, C).
        palette: Sequence of K colours to match against the last axis of mask.

    Returns:
        float32 array of shape (H, W, K).
    """
    per_colour = [np.all(np.equal(mask, colour), axis=-1) for colour in palette]
    return np.stack(per_colour, axis=-1).astype(np.float32)
def onehot_to_mask(mask, palette):
    """
    Convert a channel-last one-hot (or score) array (H, W, K) back to a colour mask.

    For every pixel the channel with the largest value (first one on ties)
    selects an entry of ``palette``.

    Args:
        mask: Array of shape (H, W, K); the last axis holds per-class values.
        palette: Sequence of K colours; entry i is written where class i wins.
            With length-C colour entries (C usually 1 or 3) the output is
            (H, W, C).

    Returns:
        np.ndarray of dtype uint8 with shape (H, W) + shape of one palette entry.
    """
    class_idx = np.argmax(mask, axis=-1)
    colour_codes = np.array(palette)
    # Index with the argmax result directly: casting the indices to uint8
    # first would wrap class ids >= 256 and silently pick the wrong colour.
    return colour_codes[class_idx].astype(np.uint8)
2. keras
参考:https://www.cnblogs.com/skyfsm/p/8330882.html
对应github代码:https://github.com/AstarLight/Satellite-Segmentation/blob/master/unet/unet_train.py
以下代码未亲自测试,正确性需自行验证。
# Keras variant: map raw label values to contiguous integer ids with
# scikit-learn's LabelEncoder, then one-hot encode them with Keras'
# to_categorical. `classes`, `n_label`, `train_data` and `train_label` are
# expected to come from the surrounding training script (not shown here).
labelencoder = LabelEncoder().fit(classes)  # fit() returns the encoder itself
train_data = np.array(train_data)
# flatten() turns all labels into one long 1-D vector, so the one-hot result
# presumably has shape (total_pixels, n_label); reshape it back if the model
# expects per-image targets -- TODO confirm against the linked training code.
train_label = to_categorical(
    labelencoder.transform(np.array(train_label).flatten()),
    num_classes=n_label,
)