这里用到了Pytorch的scatter_函数:
scatter_(dim, index, src) → Tensor
Writes all values from the tensor
src
intoself
at the indices specified in theindex
tensor. For each value insrc
, its output index is specified by its index insrc
fordimension != dim
and by the corresponding value inindex
fordimension = dim
.
For a 3-D tensor,self
is updated as:self[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2
注意要保证self和index维度一致
对于分类问题,标签可以是类别索引值也可以是one-hot表示。以10类别分类为例,lable=[3] 和label=[0, 0, 0, 1, 0, 0, 0, 0, 0, 0]是一致的.
>>>class_num = 10
>>>batch_size = 4
>>>label = torch.LongTensor(batch_size, 1).random_() % class_num
3
0
0
8
>>>one_hot = torch.zeros(batch_size, class_num).scatter_(1, label, 1)
0 0 0 1 0 0 0 0 0 0
1 0 0 0 0 0 0 0 0 0
1 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 1 0
自己做图像分割时,把图像标签转成one-hot编码形式,原图像(1,512,512),生成one-hot(class_nums, 512, 512):
gt_onehot = torch.zeros((class_nums, gt.shape[1], gt.shape[2]))
gt_onehot.scatter_(0, gt.long(), 1)
参考:
1.https://pytorch.org/docs/stable/tensors.html?highlight=scatter_#torch.Tensor.scatter_
2.PyTorch——Tensor_把索引标签转换成one-hot标签表示