将针对全连接网络和全卷积网络输出的形式不同,将one hot编码分两种情况:
-
第一种针对网络输出是二维,即全连接层的输出形式, [Batchsize, Num_class]
-
第二种针对输出是四维特征图,即分割网络的输出形式,[Batchsize, Num_class, H,W]
第二种方式:[Batchsize, Num_class, H,W]
input.scatter_(dim, index, src) → Tensor
参数:
- input: 我们需要插入数据的起源tensor;也就是想要改变内部数据的tensor
- dim:我们想要从哪个维度去改input数据
- index:给出改的元素索引,也就是位置,说在“坐标”可能好理解一点
- src:准备好的插入的数据
def get_one_hot(self,label):
size = label.size()
# create one-hot vector for label map
label_nc = self.cfg.DATASET["label_nc"]
oneHot_size = (size[0], label_nc, size[2], size[3])
input_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()
input_label = input_label.scatter_(1, label.data.long().cuda(), 1.0)
return input_label