import torch
from torch_scatter import scatter
# Demonstrate torch_scatter.scatter: sum slices of `src` along dim 1 into the
# output slots named by `index`.
#
# torch.range is deprecated (its end point is inclusive, unlike Python's
# range); torch.arange(1, 25) yields the same 24 values 1..24. The explicit
# default dtype keeps the floating-point tensor torch.range returned; an
# integer-argument arange would otherwise produce int64.
src = torch.arange(1, 25, dtype=torch.get_default_dtype()).view(2, 6, 2)
# One target slot per position along dim 1 (len(index) == src.size(1) == 6).
index = torch.tensor([0, 1, 0, 1, 2, 1])
# Broadcasting in the first and last dim: the same `index` is applied to every
# (dim 0, dim 2) pair. With dim_size left unset, torch_scatter sizes dim 1 of
# the output as index.max() + 1 == 3, so `out` has shape (2, 3, 2):
#   out[:, 0] = src[:, 0] + src[:, 2]
#   out[:, 1] = src[:, 1] + src[:, 3] + src[:, 5]
#   out[:, 2] = src[:, 4]
out = scatter(src, index, dim=1, reduce="sum")
print(src)
print(out)
PyTorch学习（二十六）—— torch.scatter 的使用
tensor([[[ 1., 2.],
[ 3., 4.],
[ 5., 6.],
[ 7., 8.],
[ 9., 10.],
[11., 12.]],
[[13., 14.],
[15., 16.],
[17., 18.],
[19.