scatter函数:
scatter_(dim, index, src) → Tensor
Parameters:
- dim (int) – the axis along which to index index (LongTensor) – the
- indices of elements to scatter, can be either empty or the same size
of src. When empty, the operation returns identity - src (Tensor) – the source element(s) to scatter, incase value is not specified
- value (float) – the source element(s) to scatter, incase src is not specified
(1)维度dim:整数,可以是0,1,2,3…
(2)索引数组index:索引数组是一个tensor,其中的数据类型是整数,表示位置
(3)原数组input:也是一个tensor,其中的数据类型任意
gather函数:
torch.gather(input, dim, index, out=None) → Tensor
Parameters:
- input (Tensor) – 源张量
- dim (int) – 索引的轴
- index (LongTensor) – 聚合元素的下标(index需要是torch.longTensor类型)
- out (Tensor, optional) – 目标张量