Implement the GraphSAGE layer directly
1.GraphSage
对于一个具有编码
h
v
l
−
1
h_v^{l-1}
hvl−1的中心节点
v
v
v,进行下一步状态更新的规则为:
h
v
(
l
)
=
W
l
⋅
h
v
(
l
−
1
)
+
W
r
⋅
A
G
G
(
{
h
u
(
l
−
1
)
,
∀
u
∈
N
(
v
)
}
)
h_v^{(l)} = W_l\cdot h_v^{(l-1)} + W_r \cdot AGG(\{h_u^{(l-1)}, \forall u \in N(v) \})
hv(l)=Wl⋅hv(l−1)+Wr⋅AGG({hu(l−1),∀u∈N(v)})
W
l
W_l
Wl 和
W
r
W_r
Wr为可学习的权重,
N
(
v
)
N(v)
N(v) 代表
v
v
v的邻接节点。
A
G
G
(
⋅
)
AGG(·)
AGG(⋅) 为消息聚合函数,当采用 mean aggregation时,有
A
G
G
(
{
h
u
(
l
−
1
)
,
∀
u
∈
N
(
v
)
}
)
=
1
∣
N
(
v
)
∣
∑
u
∈
N
(
v
)
h
u
(
l
−
1
)
AGG(\{h_u^{(l-1)}, \forall u \in N(v) \}) = \frac{1}{|N(v)|} \sum_{u\in N(v)} h_u^{(l-1)}
AGG({hu(l−1),∀u∈N(v)})=∣N(v)∣1u∈N(v)∑hu(l−1)
2.Implement
(1)实现方法
实现分三步,分别为
1)每一个邻居 u u u节点传递当前状态 u l − 1 u^{l-1} ul−1;
2)中心节点 v v v 使用聚合函数聚合收到的消息,在GraphSage中为简单求平均;
3)中心节点使用聚合消息更新自己的状态,在GraphSage中为残差。
(2)实现步骤
pytorch
提供了MessagePassing
父类,我们借此可以简洁实现消息传递。
class GraphSage(MessagePassing):
def __init__(self, in_channels, out_channels, normalize = True,
bias = False, **kwargs):
super(GraphSage, self).__init__(**kwargs)
self.in_channels = in_channels
self.out_channels = out_channels
self.normalize = normalize
self.lin_l=nn.Linear(in_features=in_channels, out_features=out_channels)
self.lin_r=nn.Linear(in_features=in_channels, out_features=out_channels)
def message(self, x_j):
out = None
out = self.lin_r(x_j)
return out
def aggregate(self, inputs, index, dim_size = None):
out = None
node_dim = self.node_dim
out=torch_scatter.scatter(inputs, index, dim=node_dim,reduce='mean')
return out
def forward(self, x, edge_index, size = None):
out=self.propagate(edge_index,x=(x,x))
out=self.lin_l(x)+out
if self.normalize:
out=F.normalize(out)
return out
①message
函数定义全局消息传递的内容。参数x_j
描述所有消息传递关系中源节点的特征,形状为
[
∣
E
∣
,
d
]
[|\mathcal{E}|, d]
[∣E∣,d],
(
i
,
j
)
∈
E
(i, j) \in \mathcal{E}
(i,j)∈E.
②aggregate
函数定义了中心节点接收和聚合消息的方法。参数inputs
是message
函数的返回值,index
描述了每个中心节点
v
v
v接收来自邻居节点
u
u
u的消息在inputs
的哪一行行。scatter
函数声明为
torch_scatter.scatter(input: Tensor, index: Tensor, dim: int = -1, out: Optional[Tensor] = None, dim_size: Optional[int] = None, reduce: str = 'sum')→ Tensor[source]
函数功能为用index
在dim
指定的维度索引张量input
,再根据reduce
规则计算返回值。
如图所示,中心节点0的邻居节点在input
的第0、1、3个索引。
③propagate
函数定义在MessagePassing
父类。用于启动一次消息传递过程。edge_index
为整张图的边索引信息,形状是
[
2
,
E
]
[2,\mathcal{E}]
[2,E]。参数x
存放邻居节点和中心节点的特征。因为每个节点既是中心节点又是邻居节点,且采用一样的特征描述,所以元组的两个元素是一样的。propagate
函数会自动调用message
和aggregate
完成消息传递和消息聚合。
④当GraphSage
对象被调用时,默认调用forward
来启动消息传递。forward
函数返回更新后的节点特征张量,形状为
[
∣
N
∣
,
d
]
[|N|, d]
[∣N∣,d].
N
N
N是所有节点的集合。
3. Train and Test
使用CORA dataset数据集进行节点分类任务。训练过程如下