MessagePassing
Message passing is the essence of GNNs: it describes how node embeddings are learned. I talked about it in my last post, so I will just briefly run through it here, using terms that conform to the PyG documentation.
x denotes the node embeddings, e denotes the edge features, 𝜙 denotes the message function, □ denotes the aggregation function, and 𝛾 denotes the update function. If the edges in the graph have no feature other than connectivity, e is essentially the edge index of the graph. The superscript denotes the index of the layer, so when k = 1, x^(k-1) = x^(0) represents the input feature of each node. Putting the pieces together, the general message passing scheme reads

$$\mathbf{x}_i^{(k)} = \gamma^{(k)}\left(\mathbf{x}_i^{(k-1)},\ \square_{j \in \mathcal{N}(i)}\, \phi^{(k)}\left(\mathbf{x}_i^{(k-1)}, \mathbf{x}_j^{(k-1)}, \mathbf{e}_{j,i}\right)\right)$$

Below I will illustrate how each function works:
- propagate(edge_index, size=None, **kwargs): It takes in the edge index and other optional information, such as node features (embeddings). Calling this function will consequently call message and update.
- message(**kwargs): You specify how to construct the “message” for each node pair (x_i, x_j). Since it is called by propagate, it can take any argument that was passed to propagate. One thing to note is that the suffixes “_i” and “_j” define the mapping from an argument to the specific nodes, so you must be very careful when naming the arguments of this function.
- update(aggr_out, **kwargs): It takes in the aggregated message and any other argument passed to propagate, assigning a new embedding value to each node (a minimal skeleton follows this list).
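To make the control flow concrete, here is a minimal sketch of a custom layer showing how the three functions fit together. The class name MyConv and the choice of mean aggregation are illustrative assumptions, not anything prescribed by PyG:

import torch
from torch_geometric.nn import MessagePassing

class MyConv(MessagePassing):
    def __init__(self):
        super(MyConv, self).__init__(aggr='mean')  # the aggregation function □

    def forward(self, x, edge_index):
        # Kicks off message passing; x is forwarded to message and update.
        return self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x)

    def message(self, x_j):
        # 𝜙: x_j holds the embedding of the source node of each edge.
        return x_j

    def update(self, aggr_out, x):
        # 𝛾: combine the aggregated messages with the previous embeddings.
        return aggr_out + x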
Example
Let’s see how we can implement a SAGEConv layer from the paper “Inductive Representation Learning on Large Graphs” (https://arxiv.org/abs/1706.02216). The message passing formula of SAGEConv is defined as:

$$\mathbf{h}_{\mathcal{N}(i)}^{k} = \mathrm{AGGREGATE}_k\left(\left\{\mathbf{h}_j^{k-1},\ \forall j \in \mathcal{N}(i)\right\}\right)$$
$$\mathbf{h}_i^{k} = \sigma\left(\mathbf{W}^{k} \cdot \mathrm{CONCAT}\left(\mathbf{h}_i^{k-1},\ \mathbf{h}_{\mathcal{N}(i)}^{k}\right)\right)$$
Here, we use max pooling as the aggregation method. Therefore, the right-hand side of the first line can be written as:
$$\mathrm{AGGREGATE}_k^{\mathrm{pool}} = \max\left(\left\{\sigma\left(\mathbf{W}_{\mathrm{pool}}\,\mathbf{h}_j^{k-1} + \mathbf{b}\right),\ \forall j \in \mathcal{N}(i)\right\}\right)$$

which illustrates how the “message” is constructed: each neighboring node embedding is multiplied by a weight matrix, a bias is added, and the result is passed through an activation function. This can easily be done with torch.nn.Linear.
class SAGEConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(SAGEConv, self).__init__(aggr='max')
        self.lin = torch.nn.Linear(in_channels, out_channels)
        self.act = torch.nn.ReLU()

    def message(self, x_j):
        # x_j has shape [E, in_channels]
        x_j = self.lin(x_j)  # multiply by the weight matrix and add the bias
        x_j = self.act(x_j)  # pass through the activation function
        return x_j
As for the update part, the aggregated message and the current node embedding are concatenated. The result is then multiplied by another weight matrix and passed through another activation function.
class SAGEConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(SAGEConv, self).__init__(aggr='max')
        self.update_lin = torch.nn.Linear(in_channels + out_channels, in_channels, bias=False)
        self.update_act = torch.nn.ReLU()

    def update(self, aggr_out, x):
        # aggr_out has shape [N, out_channels]
        # Concatenate the aggregated message with the current embedding.
        new_embedding = torch.cat([aggr_out, x], dim=1)
        new_embedding = self.update_lin(new_embedding)
        new_embedding = self.update_act(new_embedding)
        return new_embedding
Putting it together, we have the following SAGEConv layer.
import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import remove_self_loops, add_self_loops

class SAGEConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(SAGEConv, self).__init__(aggr='max')  # "max" aggregation
        self.lin = torch.nn.Linear(in_channels, out_channels)
        self.act = torch.nn.ReLU()
        self.update_lin = torch.nn.Linear(in_channels + out_channels, in_channels, bias=False)
        self.update_act = torch.nn.ReLU()

    def forward(self, x, edge_index):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]
        edge_index, _ = remove_self_loops(edge_index)
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
        return self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x)

    def message(self, x_j):
        # x_j has shape [E, in_channels]
        x_j = self.lin(x_j)
        x_j = self.act(x_j)
        return x_j

    def update(self, aggr_out, x):
        # aggr_out has shape [N, out_channels]
        new_embedding = torch.cat([aggr_out, x], dim=1)
        new_embedding = self.update_lin(new_embedding)
        new_embedding = self.update_act(new_embedding)
        return new_embedding
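For completeness, here is a quick smoke test of the layer on a toy graph; the node count, feature sizes, and edges below are made up purely for illustration:

# Toy graph: 4 nodes with 8-dimensional features, 4 directed edges.
x = torch.rand(4, 8)
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 0, 3, 2]], dtype=torch.long)

conv = SAGEConv(in_channels=8, out_channels=16)
out = conv(x, edge_index)
print(out.shape)  # torch.Size([4, 8]); update_lin maps back to in_channels

Note that the output dimension equals in_channels, not out_channels, because update_lin projects the concatenated vector back to in_channels in this implementation.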