PyTorch Tensor usage: scatter_add_
Official documentation: https://pytorch.org/docs/stable/tensors.html?highlight=scatter_add#torch.Tensor.scatter_add_
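Signature: scatter_add_(dim, indexTensor, otherTensor) → Tensor
Usage: selfTensor.scatter_add_(dim, indexTensor, otherTensor)
The trailing underscore marks an in-place method: selfTensor is modified and is also the return value. Newer versions of the docs name the third argument src.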
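A minimal sketch of the call on a 1-D tensor, assuming a recent PyTorch install. The names self_t, index and other are illustrative stand-ins for selfTensor, indexTensor and otherTensor. It shows the in-place update and that the returned tensor shares storage with self.

```python
import torch

# selfTensor: the destination, modified in place
self_t = torch.zeros(5)
# indexTensor: target position of each element of other (must be int64)
index = torch.tensor([0, 1, 1, 4])
# otherTensor: the values to add
other = torch.tensor([1.0, 2.0, 3.0, 4.0])

out = self_t.scatter_add_(0, index, other)

print(self_t)                               # tensor([1., 5., 0., 0., 4.])
print(out.data_ptr() == self_t.data_ptr())  # True: the method returns self
```

Position 1 receives both 2.0 and 3.0, so it ends up as 5.0.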
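Requirements:
- self, index and other must have the same number of dimensions.
- index.size(d) <= other.size(d) for all dimensions d, and index.size(d) <= self.size(d) for all dimensions d != dim.
- As with gather(), every value of index must lie between 0 and self.size(dim) - 1 inclusive.
- The docs quoted here also say that all values in a row along dimension dim must be unique. That does not hold for scatter_add_: repeated indices are allowed, and their values are summed into the same position.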
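The sketch below exercises these rules on a 2-D tensor with dim=1; the tensors and their values are made up for illustration. index is narrower than src in dimension 1, which the size rules allow, and row 0 sends two values to the same column, which shows the accumulation. The last part assumes that an out-of-range index makes the CPU kernel raise an error; the exact exception type may differ across versions, so both RuntimeError and IndexError are caught.

```python
import torch

self_t = torch.zeros(2, 4)
src = torch.tensor([[1.0, 2.0, 3.0],
                    [4.0, 5.0, 6.0]])
# index is (2, 2): index.size(1) <= src.size(1) holds, so only src[:, :2] is read.
# Row 0 targets column 3 twice, so 1.0 and 2.0 are summed there.
index = torch.tensor([[3, 3],
                      [0, 2]])
self_t.scatter_add_(1, index, src)
print(self_t)
# tensor([[0., 0., 0., 3.],
#         [4., 0., 5., 0.]])

# An index outside [0, self.size(dim) - 1] is rejected.
raised = False
try:
    torch.zeros(3).scatter_add_(0, torch.tensor([3]), torch.tensor([1.0]))
except (RuntimeError, IndexError) as err:
    raised = True
    print("rejected:", err)
print(raised)  # True
```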
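Example: final_dist = vocab_dist_.scatter_add(1, enc_batch_extend_vocab, attn_dist_)
(scatter_add, without the trailing underscore, is the out-of-place version: it returns a new tensor and leaves vocab_dist_ unchanged.)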
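The variable names suggest a pointer-generator style copy mechanism: the attention weight on each source token is added to the vocabulary distribution at that token's id. That reading, the shapes and the numbers below are assumptions made for illustration, not taken from the original code.

```python
import torch

batch, vocab_size, src_len = 2, 6, 4

# Assumed shapes (not stated in the original):
#   vocab_dist_            (batch, vocab_size)  generation probability per word
#   attn_dist_             (batch, src_len)     attention weight per source token
#   enc_batch_extend_vocab (batch, src_len)     vocabulary id of each source token
vocab_dist_ = torch.full((batch, vocab_size), 0.05)
attn_dist_ = torch.tensor([[0.1, 0.2, 0.3, 0.4],
                           [0.25, 0.25, 0.25, 0.25]])
enc_batch_extend_vocab = torch.tensor([[2, 5, 2, 0],
                                       [1, 1, 1, 3]])

# dim=1: for every (i, j), final_dist[i][ids[i][j]] += attn_dist_[i][j]
# Row 0: id 2 appears twice and gains 0.1 + 0.3; id 5 gains 0.2; id 0 gains 0.4.
final_dist = vocab_dist_.scatter_add(1, enc_batch_extend_vocab, attn_dist_)

print(final_dist)
print(vocab_dist_[0, 2])  # still 0.05: the out-of-place call left it untouched
```

Because a word can occur several times in the source, the summing of repeated indices is what makes scatter_add fit this use.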
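The function adds every value of otherTensor into selfTensor at the positions given by indexTensor: along dimension dim, the target position is the corresponding value in indexTensor; along every other dimension, it is the element's own position.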
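For a 3-D tensor, self is updated as:
self[index[i][j][k]][j][k] += other[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] += other[i][j][k]  # if dim == 1
self[i][j][index[i][j][k]] += other[i][j][k]  # if dim == 2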
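To check the 3-D update rule, the sketch below implements it with plain loops and compares the result against PyTorch for dim = 0, 1 and 2. scatter_add_reference is a hypothetical helper written for this note, not a PyTorch API, and the random shapes are chosen only so that the size requirements hold.

```python
import torch

def scatter_add_reference(self_t, dim, index, other):
    """Loop version of the 3-D rule (hypothetical helper, not part of PyTorch)."""
    out = self_t.clone()
    I, J, K = index.shape
    for i in range(I):
        for j in range(J):
            for k in range(K):
                pos = [i, j, k]
                pos[dim] = int(index[i, j, k])  # only the coordinate along dim is replaced
                out[tuple(pos)] += other[i, j, k]
    return out

torch.manual_seed(0)
self_t = torch.randn(3, 4, 5)
other = torch.randn(3, 4, 5)
results = []
for dim in range(3):
    # index is (2, 3, 4): no larger than other in any dimension,
    # and its values lie in [0, self_t.size(dim) - 1].
    index = torch.randint(0, self_t.size(dim), (2, 3, 4))
    expected = scatter_add_reference(self_t, dim, index, other)
    actual = self_t.clone().scatter_add_(dim, index, other)
    results.append(torch.allclose(actual, expected, atol=1e-6))
print(results)  # [True, True, True]
```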