Scatter函数
scatter_(dim, index, src, reduce=None) → Tensor
Writes all values from the tensor
src
intoself
at the indices specified in theindex
tensor. For each value insrc
, its output index is specified by its index insrc
fordimension != dim
and by the corresponding value inindex
fordimension = dim
.
简单理解就是将 src 张量中的元素散落到 self 张量中,具体选择哪个元素,选择的元素散落到哪个位置由index张量决定,具体的映射规则为:
# 其中 i,j,k 为index张量中元素坐标。
self[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2
参数
- dim(int) 指index数组元素替代的坐标(dim = 0 替代src中的横坐标)
- index (LongTensor) 可以为空,最大与src张量形状相同
- src(Tensor or float) 源张量
- reduce 聚集函数(src替换元素与self中被替换元素执行的操作,默认是替代,可以进行add,multiply等操作)
具体例子:
>>> src = torch.arange(1, 11).reshape((2, 5))
>>> src
tensor([[ 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10]])
>>> index = torch.tensor([[0, 1, 2, 0]])
>>> torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src)
tensor([[1, 0, 0