2023年CS22Wassignment中的所有colab答案以及注释已经上传到github:https://github.com/yuyu990116/CS224W-assignment
CS224W课程地址:http://web.stanford.edu/class/cs224w/
class GraphSage(MessagePassing):
def __init__(self, in_channels, out_channels, normalize = True,
bias = False, **kwargs):
super(GraphSage, self).__init__(**kwargs)
self.in_channels = in_channels
self.out_channels = out_channels
self.normalize = normalize
self.lin_l=Linear(in_channels,out_channels) 这一层的hv
self.lin_r=Linear(in_channels,out_channels) 聚合上一层的hu
self.reset_parameters()
def reset_parameters(self):
self.lin_l.reset_parameters()
self.lin_r.reset_parameters()
def forward(self, x, edge_index, size = None):
out=self.propagate(edge_index,x=(x,x),size=size)
#out=self.propagate(edge_index,x=x,size=size)
x=self.lin_l(x)
out=self.lin_r(out)
out=out+x
if self.normalize:
out=F.normalize(out)
return out
def message(self, x_j):
out=x_j
return out
def aggregate(self, inputs, index, dim_size = None):
node_dim = self.node_dim
out=torch_scatter.scatter(inputs,index,node_dim,dim_size=dim_size,reduce='mean')
return out
scatter
函数的参数含义如下:
inputs
:输入的张量,可以是任意形状的张量。index
:用于指定在哪个维度上进行聚合的索引张量。索引张量的形状应与inputs
的形状匹配,除了指定的维度上的大小应与聚合操作后的维度大小相同。node_dim
:指定在哪个维度上进行聚合操作。dim_size
:指定聚合操作后的维度大小。reduce
:指定聚合操作的方式,可选择的值包括’mean’、‘sum’、‘min’、'max’等。
scatter
函数的作用是根据指定的索引,在指定的维度上对输入进行聚合操作。聚合操作可以是求均值、求和、取最小值、取最大值等,具体取决于reduce
参数的值。
聚合操作的结果将保存在out
中,out
的形状与inputs
相同,除了指定的维度上的大小将变为dim_size
。