scatter_max 这个函数还挺有用的,pointnet见到的
from torch_scatter import scatter_max
x, _ = scatter_max(x, data.batch, dim=0)
例子
from torch_scatter import scatter_max
import torch
if __name__ == "__main__":
t1 = torch.tensor([[0, 2], [2, 2], [3, 4], [7, 8], [3, 5]])
t2 = torch.tensor([0, 0, 0, 1, 1])
out, _ = scatter_max(t1, t2, dim=0)
print(out)
输出结果:
tensor([[3,4],
[7,8]])
1.dim=0代表查询维度
2.根据 t2 把 t1 的每个元素归组,这里前三个是一组,后俩是一组。
3.将第一组和第二组中最大元素挑出来。
4.输出:第一组 [3, 4] ,第二组 [7, 8] 。