CS224W作业class GraphSage(MessagePassing)

2023年CS22Wassignment中的所有colab答案以及注释已经上传到github:https://github.com/yuyu990116/CS224W-assignment
CS224W课程地址:http://web.stanford.edu/class/cs224w/

class GraphSage(MessagePassing):
    
    def __init__(self, in_channels, out_channels, normalize = True,
                 bias = False, **kwargs):  
        super(GraphSage, self).__init__(**kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.normalize = normalize
        self.lin_l=Linear(in_channels,out_channels) 这一层的hv
        self.lin_r=Linear(in_channels,out_channels) 聚合上一层的hu
        self.reset_parameters()

    def reset_parameters(self):
        self.lin_l.reset_parameters()
        self.lin_r.reset_parameters()

    def forward(self, x, edge_index, size = None):
        out=self.propagate(edge_index,x=(x,x),size=size)
        #out=self.propagate(edge_index,x=x,size=size)
        x=self.lin_l(x)
        out=self.lin_r(out)
        out=out+x
        if self.normalize:
            out=F.normalize(out)
        return out

    def message(self, x_j):
    	out=x_j
    	return out
    	
    def aggregate(self, inputs, index, dim_size = None):
    	node_dim = self.node_dim
	out=torch_scatter.scatter(inputs,index,node_dim,dim_size=dim_size,reduce='mean')
		return out

scatter函数的参数含义如下:

  • inputs:输入的张量,可以是任意形状的张量。
  • index:用于指定在哪个维度上进行聚合的索引张量。索引张量的形状应与inputs的形状匹配,除了指定的维度上的大小应与聚合操作后的维度大小相同。
  • node_dim:指定在哪个维度上进行聚合操作。
  • dim_size:指定聚合操作后的维度大小。
  • reduce:指定聚合操作的方式,可选择的值包括’mean’、‘sum’、‘min’、'max’等。

scatter函数的作用是根据指定的索引,在指定的维度上对输入进行聚合操作。聚合操作可以是求均值、求和、取最小值、取最大值等,具体取决于reduce参数的值。

聚合操作的结果将保存在out中,out的形状与inputs相同,除了指定的维度上的大小将变为dim_size

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值