scatter_add()函数

Pytorch 的 Tensor 用法

官方解释:https://pytorch.org/docs/stable/tensors.html?highlight=scatter_add#torch.Tensor.scatter_add_

函数参数:scatter_add_(dim,  indexTensor,  otherTensor) → 输出Tensor

函数用法:selfTensor.scatter_add_(dim,  indexTensor,  otherTensor)

要求:

  1. selfindex and other should have same number of dimensions.
  2. index.size(d) <= other.size(d) for all dimensions d
  3. index.size(d) <= self.size(d) for all dimensions d != dim.
  4. as for gather(), the values of index must be between 0 and self.size(dim) - 1
  5. all values in a row along the specified dimension dim must be unique.

示例代码:final_dist = vocab_dist_.scatter_add(1,  enc_batch_extend_vocab,  attn_dist_)

该函数将 otherTensor 的所有值加到 selfTensor 中,加入位置由 indexTensor 指明。

self[ index[i][j][k] ][ j ][ k ] += other[ i ][ j ][ k ]  # if dim == 0

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值