ref:
- https://pytorch.org/docs/1.4.0/tensors.html?highlight=index_add_#torch.Tensor.index_add_
- https://blog.csdn.net/weixin_44289071/article/details/103882658
torch.Tensor.index_add_能实现指定行或列的内容相加的功能,类似于tensorflow中tf.unsorted_segment_sum
函数,可以用在比如实例分割中进行特征聚合的步骤。比如一个N*C
的feature根据实例label可以将属于同一实例的点的特征聚合起来,得到Ins_num*C
的聚合特征。
1. 函数的参数
- dim:这个参数表明你要沿着哪个维度索引;
- index:包含索引的tensor;
- tensor:被索引出来去相加的tensor;
- 注意事项:
x
相加前后的shape保持不变,被索引的tensor在被索引的维度(第dim维)之外的维度上与tensor的对应维度必须保持一致,且index
中的值最大不能超过x
在被索引的维度上的最大维数,index
的长度必须和tensor[dim]
相同。假如x的shape
为(N, C)
,索引的维度为第0维(dim=0
),那么被索引的tensor的dim=1
的维度也必须为C
,index的值必须介于0
和C-1
之间,index
的长度必须和被索引的tensor的dim=0
的数字相同。
2. 使用示例
import torch
x = torch.ones(5