# Palette used to turn class-id masks back into colour images:
# integer class id -> (R, G, B) tuple with 0-255 channel values.
# The class id is simply the position in the list below (0..31), so the
# order of entries is significant -- append new colours, never reorder.
# Serves as the default ``colormap`` argument of ``onehot_to_rgb``.
id2code = dict(enumerate([
    (64, 128, 64),    # 0
    (128, 128, 64),   # 1
    (0, 128, 192),    # 2
    (128, 0, 0),      # 3
    (0, 128, 64),     # 4
    (64, 0, 128),     # 5
    (64, 0, 192),     # 6
    (192, 128, 64),   # 7
    (192, 192, 128),  # 8
    (64, 64, 128),    # 9
    (128, 0, 192),    # 10
    (192, 0, 64),     # 11
    (192, 0, 128),    # 12
    (192, 0, 192),    # 13
    (128, 64, 64),    # 14
    (64, 192, 128),   # 15
    (64, 64, 0),      # 16
    (128, 64, 128),   # 17
    (128, 128, 192),  # 18
    (0, 0, 192),      # 19
    (192, 128, 128),  # 20
    (128, 128, 128),  # 21
    (64, 128, 192),   # 22
    (0, 0, 64),       # 23
    (0, 64, 64),      # 24
    (192, 64, 128),   # 25
    (128, 128, 0),    # 26
    (192, 128, 192),  # 27
    (64, 0, 64),      # 28
    (192, 192, 0),    # 29
    (0, 0, 0),        # 30
    (64, 192, 0),     # 31
]))
def onehot_to_rgb(onehot, colormap = id2code):
'''Function to decode encoded mask labels
Inputs:
onehot - one hot encoded image matrix (height x width x num_classes)
colormap - dictionary of color to label id
Output: Decoded RGB image (height x width x 3)