官方文档
message passing networks
Torch geometric GCNConv 源码分析
补充说明
这里附上以上博客中提到的度矩阵 D 和邻接矩阵 A 的定义
这里附上综合以上博客,对 GCNConv 注释后的代码
import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
class GCNConv(MessagePassing):
    """Graph convolutional layer (Kipf & Welling):
    X' = D^{-1/2} (A + I) D^{-1/2} X W, implemented via message passing."""

    def __init__(self, in_channels, out_channels):  # in/out channel sizes are required
        # 'add' aggregation sums the normalized neighbor messages (step 4).
        # BUG FIX: the original called `self._init_(...)` (single-underscore
        # typo), which raises AttributeError; it must be `super().__init__`.
        super(GCNConv, self).__init__(aggr='add')
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):  # callers must pass x and edge_index
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]
        #### Steps 1-2 are typically computed before message passing takes place.
        # Step 1: Add self-loops to the adjacency matrix (A + I), so each node
        # keeps its own feature during aggregation.
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
        # Step 2: Linearly transform the node feature matrix: X W -> [N, out_channels].
        x = self.lin(x)
        #### Steps 3-5 are handled by the MessagePassing base class:
        # propagate() calls message() per edge, aggregates ('add'), then update().
        return self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x)

    def message(self, x_j, edge_index, size):
        # x_j: source-node features gathered per edge, shape [E, out_channels]
        # Step 3: Symmetric normalization 1 / sqrt(deg_i * deg_j).
        row, col = edge_index
        deg = degree(row, size[0], dtype=x_j.dtype)   # node degrees, shape [N]
        deg_inv_sqrt = deg.pow(-0.5)                  # shape [N]
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]  # per-edge coefficient, shape [E]
        # Step 4 (aggregation itself happens inside propagate): weight each message.
        # FIXED comment: [E, 1] * [E, out_channels] -> [E, out_channels]
        # (the original comment wrongly claimed [N, E]*[E, out] = [N, out]).
        return norm.view(-1, 1) * x_j

    def update(self, aggr_out):
        # aggr_out has shape [N, out_channels]
        # Step 5: Return the new node embeddings unchanged.
        return aggr_out