torch_scatter

官方文档

文章目录

scatter

用一张官网的图
在这里插入图片描述

import torch
from torch_scatter import scatter

index = torch.tensor([0,0,1,0,2])
input = torch.tensor([[1,1],[1,1],[2,2],[1,1],[1,1]])
result = scatter(input,index,dim=0,reduce="sum")
"""
tensor([[3, 3],
        [2, 2],
        [1, 1]])
"""
index =00102
input =[1,1][1,1][2,2][1,1][1,1]

index = 0 的 input有 :[1,1] [1,1] [1,1] ,sum为[3,3]
index = 1 的 input有 :[2,2] ,sum为[2,2]
index = 2 的 input有 :[1,1],sum为[1,1]

inputshape[5,2],由于函数中dim=0,而index有3个不用的值,index所以将5换成3.result的形状应为[5 3,2]
故:

  • result[0] = [3,3]
  • result[1] = [2,2]
  • result[2] = [1,1]
index = torch.tensor([0,0,1])
input = torch.tensor([[1,1,1],[1,1,2],[2,2,3],[21,10,9]])
scatter(input,index,dim=1,reduce="sum")
"""
tensor([[ 2,  1],
        [ 2,  2],
        [ 4,  3],
        [31,  9]])
"""

index = 0 的 input有 : [ 1 , 1 , 2 , 21 ] T [1,1,2,21]^T [1,1,2,21]T [ 1 , 1 , 2 , 10 ] T [1,1,2,10]^T [1,1,2,10]T ,sum为 [ 2 , 2 , 4 , 31 ] T [2,2,4,31]^T [2,2,4,31]T
index = 1 的 input有 : [ 1 , 2 , 3 , 9 ] T [1,2,3,9]^T [1,2,3,9]T,sum为 [ 1 , 2 , 3 , 9 ] T [1,2,3,9]^T [1,2,3,9]T

在这里插入图片描述
reduce参数可选值有:

  • sum 求和
  • mul 乘法
  • mean 平均
  • min 最小
  • max 最大
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值