GNN Augmentation and Training
0. A General GNN Framework
Idea: raw input graph ≠ computational graph
- Graph feature augmentation
- Graph structure manipulation
1). Why Augment Graphs?
Our assumption so far has been: raw input graph = computational graph
Reasons for breaking this assumption
- Features
The input graph lacks features
- Graph structure
The graph is too sparse → inefficient message passing
The graph is too dense → message passing is too costly
The graph is too large → cannot fit the computational graph into a GPU
It is unlikely that the input graph happens to be the optimal computation graph for embeddings.
2). Graph Augmentation Approaches
- Graph feature augmentation
The input graph lacks features → feature augmentation
- Graph structure augmentation
The graph is too sparse → add virtual nodes / edges
The graph is too dense → sample neighbors when doing message passing
The graph is too large → sample subgraphs to compute embeddings
1. Feature Augmentation on Graphs
Why Do We Need Feature Augmentation?
- Input graph does not have node features
- Certain structures are hard to learn by GNN
1). Input Graph Does Not Have Node Features
This is common when we only have the adjacency matrix
Standard approaches:
- Assign constant values to nodes
- Assign unique IDs to nodes (one-hot vectors)
| | Constant node feature | One-hot node feature |
|---|---|---|
| Expressive power | Medium. All nodes are identical, but the GNN can still learn from the graph structure | High. Each node has a unique ID, so node-specific information can be stored |
| Inductive learning (generalize to unseen nodes) | High. Simple to generalize to new nodes | Low. Cannot generalize to new nodes: new nodes introduce new IDs, and the GNN does not know how to embed unseen IDs |
| Computational cost | Low. Only a 1-dimensional feature | High. $O(\lvert V\rvert)$-dimensional feature, cannot apply to large graphs |
| Use cases | Any graph, inductive settings (generalize to new nodes) | Small graphs, transductive settings (no new nodes) |
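As a rough illustration of the two schemes, a minimal PyTorch sketch (the variable `num_nodes` and the toy sizes are hypothetical) could look like this:

```python
import torch
import torch.nn.functional as F

num_nodes = 5  # hypothetical small graph

# Constant node feature: every node gets the same 1-dimensional feature.
x_constant = torch.ones(num_nodes, 1)                    # shape [num_nodes, 1]

# One-hot node feature: each node gets a unique |V|-dimensional ID vector.
x_one_hot = F.one_hot(torch.arange(num_nodes)).float()   # shape [num_nodes, num_nodes]

print(x_constant.shape, x_one_hot.shape)
```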
2). Certain Structures Are Hard to Learn by GNN
Example: a GNN cannot learn the length of a cycle, because every node in a cycle graph has degree 2, so all nodes share the same computational graph (the same binary tree).
Solution: we can use cycle count as augmented node features
Other solutions: node degree, clustering coefficient, PageRank, Centrality, …
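A minimal sketch, assuming an undirected NetworkX graph, of how such structural features could be computed and stacked into a node feature matrix (the toy cycle graph is only for illustration):

```python
import networkx as nx
import torch

# Hypothetical example: a 6-node cycle, where every node has degree 2.
G = nx.cycle_graph(6)

degree = dict(G.degree())        # node degree
clustering = nx.clustering(G)    # clustering coefficient
pagerank = nx.pagerank(G)        # PageRank score

# Stack the structural features into a [num_nodes, 3] feature matrix.
x = torch.tensor(
    [[degree[v], clustering[v], pagerank[v]] for v in G.nodes()],
    dtype=torch.float,
)
print(x.shape)  # torch.Size([6, 3])
```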
2. Structure Augmentation on Graphs
1). Add Virtual Nodes / Edges
Motivation: augment sparse graphs
a). Add virtual edges
- Common approach: connect 2-hop neighbors via virtual edges
- Intuition: instead of using the adjacency matrix $A$ for GNN computation, use $A+A^2$ (see the sketch after this list)
- Use cases: bipartite graphs
Author-to-papers: 2-hop virtual edges make an author-author collaboration graph
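A minimal sketch, assuming a small author-to-paper bipartite graph in NetworkX (the node names are hypothetical), of how the augmented adjacency $A+A^2$ could be formed:

```python
import networkx as nx

# Hypothetical author-to-paper bipartite graph.
G = nx.Graph()
G.add_edges_from([("author1", "paperA"), ("author2", "paperA"), ("author2", "paperB")])

A = nx.adjacency_matrix(G).toarray().astype(float)

# A^2 counts 2-hop paths, so A + A^2 adds virtual edges between 2-hop
# neighbors, e.g. author1 and author2 become connected via their shared paper.
A_aug = A + A @ A
```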
b). Add virtual nodes
The virtual node connects to all the nodes in the graph, so any two nodes are at most a distance of two apart: Node A - Virtual node - Node B
Benefits: greatly improve message passing in sparse graphs
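A minimal sketch of adding such a virtual node with NetworkX (the path graph and the node name "virtual" are hypothetical):

```python
import networkx as nx

# Hypothetical sparse graph: 5 nodes in a line.
G = nx.path_graph(5)

# Connect a virtual node to every existing node; any two original nodes
# are now at most 2 hops apart (Node A - virtual node - Node B).
original_nodes = list(G.nodes())
G.add_node("virtual")
G.add_edges_from(("virtual", v) for v in original_nodes)

print(nx.diameter(G))  # 2
```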
2). Node Neighborhood Sampling
Idea: (randomly) sample a node’s neighborhood for message passing
Example: we can randomly choose 2 neighbors to pass messages in a given layer; in the next layer when we compute the embeddings we can sample different neighbors. In expectation, we get embeddings similar to the case where all the neighbors are used.
Benefits: greatly reduces computational cost, allowing scaling to large graphs
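A minimal sketch in plain Python, assuming the graph is given as an adjacency list `adj` (a hypothetical toy example), of sampling at most 2 neighbors per node:

```python
import random

# Hypothetical adjacency list of a small graph.
adj = {0: [1, 2, 3, 4], 1: [0, 2], 2: [0, 1, 3], 3: [0, 2], 4: [0]}

def sample_neighbors(node, num_samples=2):
    """Randomly sample at most num_samples neighbors for message passing."""
    neighbors = adj[node]
    if len(neighbors) <= num_samples:
        return neighbors
    return random.sample(neighbors, num_samples)

# A different neighbor set can be drawn at every layer / epoch, so in
# expectation the result resembles aggregating over all neighbors.
print(sample_neighbors(0))
```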
3. Training GNNs
1). GNN Prediction Heads
Idea: different task levels require different prediction heads
a). Node-level prediction
Directly make prediction using node embeddings
After GNN computation, we have $d$-dimensional node embeddings $\{h_v^l\in R^d, \forall v\in G\}$
For a $k$-way prediction problem
- Classification: classify among $k$ categories
- Regression: regress on $k$ targets

$$\hat{y}_v=\text{Head}_{\text{node}}(h_v^l)=W^Hh_v^l$$

where $W^H\in R^{k\times d}$ maps node embeddings from $h_v^l\in R^d$ to $\hat{y}_v\in R^k$ so that we can compute the loss.
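A minimal sketch of this linear node-level head in PyTorch (the dimensions d, k and the random embeddings are hypothetical):

```python
import torch
import torch.nn as nn

d, k, num_nodes = 64, 7, 100        # hypothetical embedding dim, #classes, #nodes

# Head_node(h_v) = W^H h_v: a single linear map from R^d to R^k.
head_node = nn.Linear(d, k, bias=False)

h = torch.randn(num_nodes, d)       # node embeddings from the last GNN layer
y_hat = head_node(h)                # shape [num_nodes, k], fed into the loss
print(y_hat.shape)
```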
b). Edge-level prediction
Make prediction using pairs of node embeddings
For a $k$-way prediction problem

$$\hat{y}_{uv}=\text{Head}_{\text{edge}}(h_u^l, h_v^l)$$
Options for $\text{Head}_{\text{edge}}(h_u^l, h_v^l)$
- Concatenation + Linear
$\hat{y}_{uv}=\text{Linear}(\text{Concat}(h_u^l, h_v^l))$
where $\text{Linear}(\cdot)$ maps the $2d$-dimensional concatenated embedding to a $k$-dimensional output
- Dot product
For 1-way prediction (e.g., link prediction: predict the existence of an edge):
$\hat{y}_{uv}=(h_u^l)^Th_v^l$
Applied to $k$-way prediction, similar to multi-head attention, with trainable weight matrices $W^1, \cdots, W^k$:
$\hat{y}_{uv}^1=(h_u^l)^TW^1h_v^l$
$\cdots$
$\hat{y}_{uv}^k=(h_u^l)^TW^kh_v^l$
$\hat{y}_{uv}=\text{Concat}(\hat{y}_{uv}^1, \cdots, \hat{y}_{uv}^k)\in R^k$
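A minimal PyTorch sketch of the edge-head options above (dimensions and embeddings are hypothetical):

```python
import torch
import torch.nn as nn

d, k = 64, 4                        # hypothetical embedding dim and #ways
h_u = torch.randn(d)                # embedding of node u
h_v = torch.randn(d)                # embedding of node v

# Option 1: concatenation + linear, maps the 2d-dim input to k dims.
linear = nn.Linear(2 * d, k)
y_uv_concat = linear(torch.cat([h_u, h_v]))                      # shape [k]

# Option 2: dot product for 1-way prediction (e.g. link prediction).
y_uv_dot = h_u @ h_v                                             # scalar

# Option 2 for k-way prediction: one trainable matrix W^i per output.
W = nn.Parameter(torch.randn(k, d, d))
y_uv_multi = torch.stack([h_u @ W[i] @ h_v for i in range(k)])   # shape [k]
```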
c). Graph-level prediction
Make prediction using all node embeddings in the graph
For a $k$-way prediction problem

$$\hat{y}_G=\text{Head}_{\text{graph}}(\{h_v^l\in R^d, \forall v\in G\})$$

where $\text{Head}_{\text{graph}}(\cdot)$ is similar to $\text{AGG}(\cdot)$ in a GNN layer
Options for $\text{Head}_{\text{graph}}(\{h_v^l\in R^d, \forall v\in G\})$ in small graphs
- Global mean pooling: $\hat{y}_G=\text{Mean}(\{h_v^l\in R^d, \forall v\in G\})$
- Global max pooling: $\hat{y}_G=\text{Max}(\{h_v^l\in R^d, \forall v\in G\})$
- Global sum pooling: $\hat{y}_G=\text{Sum}(\{h_v^l\in R^d, \forall v\in G\})$
Issue: global pooling over a (large) graph will lose information
Solution: aggregate all the node embeddings hierarchically (DiffPool)
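A minimal PyTorch sketch of the three flat global pooling options (not DiffPool), over a hypothetical embedding matrix of one graph:

```python
import torch

num_nodes, d = 10, 64
h = torch.randn(num_nodes, d)    # node embeddings of one graph

y_mean = h.mean(dim=0)           # global mean pooling -> shape [d]
y_max = h.max(dim=0).values      # global max pooling  -> shape [d]
y_sum = h.sum(dim=0)             # global sum pooling  -> shape [d]
```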
2). Labels
a). Supervised labels on graphs
Supervised labels come from specific use cases
- Node labels $y_v$: in a citation network, which subject area a node belongs to
- Edge labels $y_{uv}$: in a transaction network, whether an edge is fraudulent
- Graph labels $y_G$: among molecular graphs, the drug-likeness of graphs
Advice: reduce your task to node / edge / graph labels since they are easy to work with
b). Unsupervised labels on graphs
Problem: sometimes we only have a graph without any external labels
Solution: “self-supervised learning”; we can find supervision labels within the graph
- Node labels $y_v$: node statistics, such as clustering coefficient, PageRank, …
- Edge labels $y_{uv}$: link prediction; hide the edge between two nodes and predict whether there should be a link
- Graph labels $y_G$: graph statistics, e.g., predict whether two graphs are isomorphic
3). Loss Function
- Classification loss: cross entropy (CE) is a very common loss function in classification
- Regression loss: we often use mean squared error (MSE), a.k.a. L2 loss
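For example, with PyTorch the two standard losses could be computed as follows (the predictions and labels here are hypothetical random tensors):

```python
import torch
import torch.nn.functional as F

# Classification: cross entropy over k-way logits.
logits = torch.randn(8, 5)              # 8 examples, 5 classes
labels = torch.randint(0, 5, (8,))      # ground-truth class indices
ce_loss = F.cross_entropy(logits, labels)

# Regression: mean squared error (L2 loss) over k targets.
preds = torch.randn(8, 3)
targets = torch.randn(8, 3)
mse_loss = F.mse_loss(preds, targets)
```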
4). Evaluation Metrics
a). Regression
- Root MSE (RMSE)
- Mean absolute error (MAE)
b). Classification
- Multi-class classification
Accuracy: $\frac{1}{N}\sum_{i=1}^{N} 1[\text{argmax}(\hat{y}^i)=y^i]$
- Binary classification
Accuracy: metric sensitive to the classification threshold
Precision / recall: metrics sensitive to the classification threshold
ROC AUC: metric agnostic to the classification threshold
| | Actually Positive (1) | Actually Negative (0) |
|---|---|---|
| Predicted Positive (1) | True Positives (TP) | False Positives (FP) |
| Predicted Negative (0) | False Negatives (FN) | True Negatives (TN) |
Accuracy: $\frac{\text{TP}+\text{TN}}{\text{TP}+\text{TN}+\text{FP}+\text{FN}}=\frac{\text{TP}+\text{TN}}{|\text{Dataset}|}$
Precision (P): $\frac{\text{TP}}{\text{TP}+\text{FP}}$
Recall (R): $\frac{\text{TP}}{\text{TP}+\text{FN}}$
F1-score: $\frac{2PR}{P+R}$
ROC curve: captures the tradeoff between TPR ($\frac{\text{TP}}{\text{TP}+\text{FN}}$, i.e., recall) and FPR ($\frac{\text{FP}}{\text{FP}+\text{TN}}$) as the classification threshold is varied for a binary classifier.
ROC AUC: area under the ROC curve
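A quick sketch with scikit-learn of the binary metrics above (labels and scores are hypothetical, and the 0.5 threshold is an arbitrary choice):

```python
from sklearn.metrics import (accuracy_score, precision_score, recall_score,
                             f1_score, roc_auc_score)

y_true = [1, 0, 1, 1, 0, 0, 1, 0]
y_score = [0.9, 0.4, 0.7, 0.2, 0.1, 0.6, 0.8, 0.3]    # classifier scores
y_pred = [1 if s > 0.5 else 0 for s in y_score]        # threshold at 0.5

print(accuracy_score(y_true, y_pred))    # sensitive to the threshold
print(precision_score(y_true, y_pred))
print(recall_score(y_true, y_pred))
print(f1_score(y_true, y_pred))
print(roc_auc_score(y_true, y_score))    # agnostic to the threshold
```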
5). Dataset Split: Fixed / Random Split
Fixed split: split dataset once
- Training set: used for optimizing GNN parameters
- Validation set: develop model / hyperparameters
- Test set: held out until reporting final performance
Random split: randomly split the dataset into training / validation / test sets
a). Why splitting graphs is special
Image classification: each data point is an image and data points are independent
Node classification: each data point is a node and data points are NOT independent
Solutions
- Transductive setting: the entire input graph can be observed in all dataset splits (training, validation, and test sets); only the (node) labels are split
- Inductive setting: break the edges between splits to get multiple independent graphs
| | Transductive setting | Inductive setting |
|---|---|---|
| Training / validation / test | On the same entire graph | On different graphs |
| Applications | Node / edge tasks | Node / edge / graph tasks |
b). Example: node classification
- Transductive node classification
All splits can observe the entire graph structure, but each split only observes the labels of its own nodes.
- Inductive node classification
Each split contains an independent graph.
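As an illustration of the transductive setting, the label split is often expressed as boolean masks over one fixed graph; a minimal PyTorch sketch with a hypothetical 60/20/20 split:

```python
import torch

num_nodes = 100
perm = torch.randperm(num_nodes)          # random node order

# Transductive split: the whole graph is visible everywhere,
# only the node labels are partitioned into train / val / test.
train_mask = torch.zeros(num_nodes, dtype=torch.bool)
val_mask = torch.zeros(num_nodes, dtype=torch.bool)
test_mask = torch.zeros(num_nodes, dtype=torch.bool)
train_mask[perm[:60]] = True
val_mask[perm[60:80]] = True
test_mask[perm[80:]] = True

# The loss is then computed only on the training nodes, e.g.
# loss = F.cross_entropy(out[train_mask], y[train_mask])
```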
c). Example: graph classification
Only the inductive setting is well defined for graph classification because we have to test on unseen graphs.
d). Example: link prediction
- Goal of link prediction: predict missing edges
- Link prediction is an unsupervised / self-supervised task. We need to create the labels and dataset splits on our own
- Concretely, we need to hide some edges from the GNN and let the GNN predict if the edges exist
Setting up link prediction
Split edges twice
- Step 1: assign 2 types of edges in the original graph - message edges (used for GNN message passing) and supervision edges (used for computing the objective). After Step 1, only message edges remain in the graph; supervision edges act as supervision for the edge predictions made by the model and are not fed into the GNN.
- Step 2: split edges into training / validation / test
Option 1 for Step 2: inductive link prediction split. Each inductive split contains an independent graph and each graph has two types of edges - message edges and supervision edges
Option 2 for Step 2: transductive link prediction split (the default setting). The entire graph can be observed in all dataset splits, by definition of "transductive". But since edges are both part of the graph structure and the supervision, we need to hold out validation / test edges, and within the training set we further split the training edges into training message edges and training supervision edges
- 1). At training time: use training message edges to predict training supervision edges
- 2). At validation time: use training message edges & training supervision edges to predict validation edges
- 3). At test time, use training message edges & training supervision edges & validation edges to predict test edges
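A rough sketch in plain Python of the transductive link prediction split described above (the edge list and split sizes are hypothetical):

```python
import random

edges = [(0, 1), (1, 2), (2, 3), (3, 4), (4, 0), (1, 3), (0, 2), (2, 4)]
random.shuffle(edges)

# Hold out edges for validation / test, then split the remaining training
# edges into message edges (fed into the GNN) and supervision edges.
test_edges = edges[:2]
val_edges = edges[2:4]
train_edges = edges[4:]
train_supervision = train_edges[:2]
train_message = train_edges[2:]

# 1) training:   message = train_message,                                 predict train_supervision
# 2) validation: message = train_message + train_supervision,             predict val_edges
# 3) test:       message = train_message + train_supervision + val_edges, predict test_edges
```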