TORCH.TENSOR.SCATTER_
Tensor.scatter_(dim, index, src, reduce=None) → Tensor
Writes all values from the tensor src into self at the indices specified in the index tensor. For each value in src, its output index is specified by its index in src for dimension != dim and by the corresponding value in index for dimension = dim.
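Here is a small 2-D illustration, assuming PyTorch is installed; the tensor values are arbitrary. With dim=0, src[0][j] is sent to row index[0][j] of column j. Because index has only one row and four columns, only src[0][0:4] is written, and every other entry of self stays zero.

```python
import torch

src = torch.arange(1, 11).reshape(2, 5)  # [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]
index = torch.tensor([[0, 1, 2, 0]])     # shape (1, 4)

# dim=0: self[index[i][j]][j] = src[i][j] for every (i, j) covered by index
out = torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src)
print(out)
# tensor([[1, 0, 0, 4, 0],
#         [0, 2, 0, 0, 0],
#         [0, 0, 3, 0, 0]])
```

The second row of src is never read, because index only has one row along dim 0.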
For a 3-D tensor, self is updated as:

    self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
    self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
    self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2
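To make the update rule concrete, the sketch below spells it out as a plain loop over every position of index and checks it against scatter_ for dim = 0, 1 and 2. The helper name scatter_reference is hypothetical and only for illustration. The indices are deliberately chosen distinct along dim, because if several src values target the same position the result of scatter_ is not guaranteed to be deterministic.

```python
import itertools
import torch

def scatter_reference(self_t, dim, index, src):
    """Hypothetical helper: loop version of the update rule above.

    Copies self_t, then for every position pos of index writes src[pos]
    to pos with its coordinate `dim` replaced by index[pos].
    """
    out = self_t.clone()
    for pos in itertools.product(*(range(n) for n in index.shape)):
        target = list(pos)
        target[dim] = int(index[pos])  # index value replaces coordinate `dim`
        out[tuple(target)] = src[pos]
    return out

torch.manual_seed(0)
self_t = torch.zeros(4, 3, 5)
src = torch.randn(4, 3, 5)

for dim in range(3):
    # argsort of random numbers along `dim` yields a permutation of
    # 0..self_t.size(dim)-1 at every other coordinate, so the values are in
    # range and distinct along `dim` (no output position is written twice).
    perm = torch.argsort(torch.rand(self_t.shape), dim=dim)
    index = perm[:2, :2, :2].contiguous()  # shape (2, 2, 2) fits in src and self
    expected = self_t.clone().scatter_(dim, index, src)
    assert torch.equal(expected, scatter_reference(self_t, dim, index, src))

print("loop reference matches scatter_ for dim = 0, 1, 2")
```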
This is the reverse operation of the manner described in gather().
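One way to see the relationship with gather(): gather reads out[index[i][j]][j] along dim 0, which are exactly the slots that scatter_ wrote to, so gathering with the same index recovers the part of src that index covers. This sketch assumes index has no repeated target positions; if two src values were scattered to the same slot, only one would survive and gather could not recover both.

```python
import torch

src = torch.arange(1, 11).reshape(2, 5)
index = torch.tensor([[0, 1, 2, 0]])  # targets (0,0), (1,1), (2,2), (0,3): all distinct

out = torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src)
# gather along dim 0 reads out[index[i][j]][j], the slots scatter_ filled
back = out.gather(0, index)
print(back)  # tensor([[1, 2, 3, 4]])
```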
self, index and src (if it is a Tensor) should all have the same number of dimensions. It is also required that index.size(d) <= src.size(d) for all dimensions d, and that index.size(d) <= self.size(d) for all dimensions d != dim. Note that index and src do not broadcast.
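The shape rules can be checked directly. In this sketch (the raises helper is just a local convenience, not part of PyTorch), an index that is smaller than src in every dimension is accepted, while an index with the wrong number of dimensions, or one wider than self in a dimension other than dim, is rejected with a RuntimeError instead of being broadcast.

```python
import torch

self_t = torch.zeros(3, 5)
src = torch.arange(1.0, 11.0).reshape(2, 5)

def raises(fn):
    """Local helper: True if calling fn() raises RuntimeError."""
    try:
        fn()
    except RuntimeError:
        return True
    return False

# Allowed: index (1, 2) is no larger than src (2, 5) in any dimension and no
# larger than self (3, 5) in dimension 1, the only dimension != dim.
# It writes src[0][0] to self[2][0] and src[0][1] to self[0][1].
ok = self_t.clone().scatter_(0, torch.tensor([[2, 0]]), src)

# Rejected: a 1-D index with 2-D self and src; index is not broadcast.
bad_ndim = raises(lambda: self_t.clone().scatter_(0, torch.tensor([0, 1]), src))

# Rejected: index.size(1) == 6 exceeds self.size(1) == 5. src is widened to
# 6 columns so that only the constraint against self is violated.
wide_src = torch.ones(2, 6)
bad_size = raises(
    lambda: self_t.clone().scatter_(0, torch.zeros(1, 6, dtype=torch.long), wide_src)
)
print(ok, bad_ndim, bad_size)
```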
Moreover, as for gather(), the values of index must be between 0 and self.size(dim) - 1 inclusive.
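A common application of scatter_ is building one-hot encodings of integer class labels, and it shows the index bound in action: every label must lie in [0, num_classes - 1], where num_classes plays the role of self.size(1). This is a minimal sketch assuming labels is a 1-D int64 tensor. It passes the Python float 1.0 in place of src, which the scatter_(dim, index, value) overload accepts, and compares the result with torch.nn.functional.one_hot.

```python
import torch
import torch.nn.functional as F

labels = torch.tensor([2, 0, 3, 1])
num_classes = 4  # must exceed labels.max(): index values must be < self.size(1)

# index must have as many dimensions as self, so reshape (N,) to (N, 1);
# row i then gets 1.0 written at column labels[i].
one_hot = torch.zeros(len(labels), num_classes).scatter_(1, labels.unsqueeze(1), 1.0)

# Same result as the built-in helper, which returns int64 (hence .float()).
assert torch.equal(one_hot, F.one_hot(labels, num_classes).float())

# A label equal to num_classes falls outside [0, self.size(1) - 1] and is rejected.
rejected = False
try:
    torch.zeros(1, num_classes).scatter_(1, torch.tensor([[num_classes]]), 1.0)
except (RuntimeError, IndexError):
    rejected = True

print(one_hot, rejected)
```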