假设存在如下标签矩阵,包含4个类,0,1,2,3。label 的维度为3行 5列。
>>> label = torch.randint(1,4,(3,5))
>>> label
tensor([[2, 2, 1, 3, 1],
[1, 2, 1, 1, 1],
[1, 3, 3, 1, 3]])
>>>
(1)扩展label维度,因为label含有4个类的标签,因此扩展label的维度 至 3x 5 x 4,即三个维度,并且在第三维度上,重复label。
>>> import numpy as np
>>> label_expand = np.tile(np.expand_dims(label,axis=2),(1,1,4))
>>> label_expand[:,:,0]
array([[2, 2, 1, 3, 1],
[1, 2, 1, 1, 1],
[1, 3, 3, 1, 3]])
>>> label_expand[:,:,1]
array([[2, 2, 1, 3, 1],
[1, 2, 1, 1, 1],
[1, 3, 3, 1, 3]])
>>> label
tensor([[2, 2, 1, 3, 1],
[1, 2, 1, 1, 1],
[1, 3, 3, 1, 3]])
>>>
(2) 复制label_expand 到 labels,将labels所有大于或等于0的元素的值重新赋值为1,然后沿着第三个维度,进行累加:np.cumsum()
>>> labels[labels>=0]=1
>>> labels = np.cumsum(labels,axis=2)
>>> labels[:,:,0]
array([[1, 1, 1, 1, 1],
[1, 1, 1, 1, 1],
[1, 1, 1, 1, 1]])
>>>
(4) 对labels, label_expand 每个类的通道进行观察,对label_expand+1之后,就可以在每个类通道中找到对应位置,进而将第2个通道(类1通道)中其他不是标签1的位置重置为0,仅标签1的位置重置为1。
通过上述分析,可以将labels 中不等于 label_expand + 1的位置的元素重置为0,然后将label中非0元素重置为1。这样就可以实现对应类通道只保存对应类标签的信息,其他类重置为0。
>>> labels[labels!=label_expand+1]=0
>>> labels[labels!=0]=1
>>> labels[:,:,0]
array([[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0]])
>>> labels[:,:,1]
array([[0, 0, 1, 0, 1],
[1, 0, 1, 1, 1],
[1, 0, 0, 1, 0]])
>>> labels[:,:,2]
array([[1, 1, 0, 0, 0],
[0, 1, 0, 0, 0],
[0, 0, 0, 0, 0]])
>>> label
tensor([[2, 2, 1, 3, 1],
[1, 2, 1, 1, 1],
[1, 3, 3, 1, 3]])
>>>