这里用到了Pytorch的scatter_函数:
scatter_(dim, index, src) → Tensor
Writes all values from the tensor
srcintoselfat the indices specified in theindextensor. For each value insrc, its output index is specified by its index insrcfordimension != dimand by the corresponding value inindexfordimension = dim.
For a 3-D tensor,selfis updated as:self[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2
注意要保证self和index维度一致
对于分类问题,标签可以是类别索引值也可以是one-hot表示。以10类别分类为例,lable=[3] 和label=[0, 0, 0, 1, 0, 0, 0, 0, 0, 0]是一致的.
>>>class_num = 10
>>>batch_size = 4
>>>label = torch.LongTensor(batch_size, 1).random_() % class_num
3
0
0
8
>>>one_hot = torch.zeros(batch_size, class_num).scatter_(1, label, 1)
0 0 0 1 0 0 0 0 0 0
1 0 0 0 0 0 0 0 0 0
1 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 1 0
自己做图像分割时,把图像标签转成one-hot编码形式,原图像(1,512,512),生成one-hot(class_nums, 512, 512):
gt_onehot = torch.zeros((class_nums, gt.shape[1], gt.shape[2]))
gt_onehot.scatter_(0, gt.long(), 1)
参考:
1.https://pytorch.org/docs/stable/tensors.html?highlight=scatter_#torch.Tensor.scatter_
2.PyTorch——Tensor_把索引标签转换成one-hot标签表示
2038

被折叠的 条评论
为什么被折叠?



