CS224W: Machine Learning with Graphs - 08 GNN Augmentation and Training

GNN Augmentation and Training

0. A General GNN Framework

Idea: raw input graph ≠ computational graph

  • Graph feature augmentation
  • Graph structure manipulation
1). Why Augment Graphs?

Our assumption so far has been: raw input graph = computational graph
Reasons for breaking this assumption

  • Features
    The input graph lacks features
  • Graph structure
    The graph is too sparse → inefficient message passing
    The graph is too dense → message passing is too costly
    The graph is too large → cannot fit the computational graph into a GPU

It is unlikely that the input graph happens to be the optimal computation graph for embeddings.

2). Graph Augmentation Approaches
  • Graph feature augmentation
    The input graph lacks features → feature augmentation
  • Graph structure augmentation
    The graph is too sparse → add virtual nodes / edges
    The graph is too dense → sample neighbors when doing message passing
    The graph is too large → sample subgraphs to compute embeddings

1. Feature Augmentation on Graphs

Why Do We Need Feature Augmentation?

  • Input graph does not have node features
  • Certain structures are hard to learn by GNN
1). Input Graph Does Not Have Node Features

This is common when we only have the adjacency matrix
Standard approaches:

  • Assign constant values to nodes
  • Assign unique IDs to nodes (one-hot vectors)
| | Constant node feature | One-hot node feature |
|---|---|---|
| Expressive power | Medium. All nodes are identical, but the GNN can still learn from the graph structure | High. Each node has a unique ID, so node-specific information can be stored |
| Inductive learning (generalize to unseen nodes) | High. Simple to generalize to new nodes | Low. Cannot generalize to new nodes: new nodes introduce new IDs, and the GNN does not know how to embed unseen IDs |
| Computational cost | Low. Only a 1-dimensional feature | High. $O(\lvert V \rvert)$-dimensional feature, cannot apply to large graphs |
| Use cases | Any graph, inductive settings (generalize to new nodes) | Small graphs, transductive settings (no new nodes) |
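A minimal sketch of the two options in plain Python, for a hypothetical 4-node graph; the function names are illustrative, not from any library. The constant feature stays 1-dimensional no matter how many nodes there are, while the one-hot feature grows with $|V|$, and a fifth node would need a dimension the model never saw during training.

```python
def constant_features(num_nodes, value=1.0):
    # Every node receives the same 1-dimensional feature.
    return [[value] for _ in range(num_nodes)]


def one_hot_features(num_nodes):
    # Node i receives a num_nodes-dimensional vector with a 1.0 at position i.
    return [[1.0 if j == i else 0.0 for j in range(num_nodes)] for i in range(num_nodes)]


const = constant_features(4)
onehot = one_hot_features(4)
print(const)      # [[1.0], [1.0], [1.0], [1.0]]
print(onehot[2])  # [0.0, 0.0, 1.0, 0.0]
```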
2). Certain Structures Are Hard to Learn by GNN

Example: a GNN cannot learn the length of a cycle. Every node in a cycle has degree 2, so every node's computational graph is the same binary tree, whatever the cycle length.
Solution: use cycle counts as augmented node features
Other solutions: node degree, clustering coefficient, PageRank, centrality, …
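To make the cycle example concrete, here is a small sketch assuming adjacency-set dictionaries. `cycle_count_feature` is a hypothetical helper that enumerates simple cycles through a node by depth-first search (exponential time, so only suitable for small graphs); position $i$ of the vector counts cycles of length $i + 3$, and the cutoff `max_len=6` is an arbitrary choice. Degree alone cannot tell a node on a 3-cycle from a node on a 4-cycle, but the cycle-count vector can.

```python
from collections import defaultdict


def cycle_graph(n):
    """Adjacency sets for an n-node cycle 0-1-...-(n-1)-0."""
    adj = defaultdict(set)
    for i in range(n):
        j = (i + 1) % n
        adj[i].add(j)
        adj[j].add(i)
    return adj


def cycles_through(adj, v, k):
    """Number of simple cycles with exactly k edges that pass through node v (k >= 3)."""
    count = 0

    def dfs(node, depth, visited):
        nonlocal count
        for nxt in adj[node]:
            if nxt == v and depth + 1 == k:
                count += 1  # returned to v after exactly k edges
            elif nxt not in visited and depth + 1 < k:
                visited.add(nxt)
                dfs(nxt, depth + 1, visited)
                visited.remove(nxt)

    dfs(v, 0, {v})
    return count // 2  # each cycle is found once in each direction


def cycle_count_feature(adj, v, max_len=6):
    # Position i counts cycles of length i + 3 through v.
    return [cycles_through(adj, v, k) for k in range(3, max_len + 1)]


tri, sq = cycle_graph(3), cycle_graph(4)
print(len(tri[0]), len(sq[0]))                                   # 2 2 (same degree)
print(cycle_count_feature(tri, 0), cycle_count_feature(sq, 0))   # [1, 0, 0, 0] [0, 1, 0, 0]
```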

2. Structure Augmentation on Graphs

1). Add Virtual Nodes / Edges

Motivation: augment sparse graphs

a). Add virtual edges
  • Common approach: connect 2-hop neighbors via virtual edges
  • Intuition: instead of using the adjacency matrix $A$ for GNN computation, use $A + A^2$
  • Use cases: bipartite graphs
    Author-to-papers: 2-hop virtual edges make an author-author collaboration graph
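A plain-Python sketch of the $A + A^2$ augmentation on a tiny hypothetical author-paper bipartite graph (nodes 0-2 are authors, 3-4 are papers). Entry $(A^2)_{uv}$ counts the 2-hop paths between $u$ and $v$, so its author-author block is exactly the co-authorship graph mentioned above, and its diagonal holds each node's degree. Whether to binarize $A + A^2$ or drop the diagonal is a design choice the notes leave open, so the sketch keeps the raw sum.

```python
def adjacency(n, edges):
    # Symmetric 0/1 adjacency matrix of an undirected graph.
    a = [[0] * n for _ in range(n)]
    for u, v in edges:
        a[u][v] = a[v][u] = 1
    return a


def matmul(x, y):
    n = len(x)
    return [[sum(x[i][k] * y[k][j] for k in range(n)) for j in range(n)] for i in range(n)]


def add(x, y):
    return [[xi + yi for xi, yi in zip(rx, ry)] for rx, ry in zip(x, y)]


# Nodes 0-2 are authors, 3-4 are papers; an edge means "author wrote paper".
A = adjacency(5, [(0, 3), (1, 3), (1, 4), (2, 4)])
A2 = matmul(A, A)    # A2[u][v] = number of 2-hop paths between u and v
A_aug = add(A, A2)   # augmented matrix used for message passing

# Author-author block of A^2: shared-paper counts (diagonal = papers written).
print([row[:3] for row in A2[:3]])  # [[1, 1, 0], [1, 2, 1], [0, 1, 1]]
```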
b). Add virtual nodes

The virtual node connects to all the nodes in the graph, so any two nodes are at most two hops apart: Node A - Virtual node - Node B
Benefit: greatly improves message passing in sparse graphs
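The effect of a virtual node on distances can be checked with breadth-first search. This sketch uses hypothetical helper names and a 6-node path graph as the sparse example: the farthest node drops from 5 hops to 2 once every node is linked to the virtual node 6.

```python
from collections import deque


def add_virtual_node(num_nodes, edges):
    """Return (num_nodes + 1, edges) with a new node connected to every original node."""
    vn = num_nodes  # index of the virtual node
    return num_nodes + 1, edges + [(v, vn) for v in range(num_nodes)]


def bfs_distances(num_nodes, edges, src):
    # Hop distance from src to every node of an undirected graph.
    adj = [[] for _ in range(num_nodes)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)
    dist = [None] * num_nodes
    dist[src] = 0
    queue = deque([src])
    while queue:
        u = queue.popleft()
        for w in adj[u]:
            if dist[w] is None:
                dist[w] = dist[u] + 1
                queue.append(w)
    return dist


path = [(i, i + 1) for i in range(5)]  # sparse path graph 0-1-2-3-4-5
print(bfs_distances(6, path, 0))       # [0, 1, 2, 3, 4, 5]
n2, path_vn = add_virtual_node(6, path)
print(bfs_distances(n2, path_vn, 0))   # [0, 1, 2, 2, 2, 2, 1]
```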

2). Node Neighborhood Sampling

Idea: (randomly) sample a node’s neighborhood for message passing
Example: we can randomly choose 2 neighbors to pass messages in a given layer; in the next layer when we compute the embeddings we can sample different neighbors. In expectation, we get embeddings similar to the case where all the neighbors are used.
Benefits: greatly reduces computational cost, allowing GNNs to scale to large graphs
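A sketch of why sampling still works in expectation, assuming a mean aggregator over scalar neighbor features (an illustrative simplification of a GNN layer). Sampling 2 of 9 neighbors uniformly without replacement gives an unbiased estimate of the full-neighborhood mean, so averaging many sampled aggregations recovers it, while each individual aggregation only touches 2 neighbors.

```python
import random


def mean_aggregate(features, neighbors):
    # Full aggregation: average over every neighbor.
    return sum(features[u] for u in neighbors) / len(neighbors)


def sampled_mean_aggregate(features, neighbors, num_samples, rng):
    # Uniformly sample neighbors without replacement and average only those.
    sampled = rng.sample(neighbors, min(num_samples, len(neighbors)))
    return sum(features[u] for u in sampled) / len(sampled)


features = {u: float(u) for u in range(10)}
neighbors = list(range(1, 10))  # node 0 has 9 neighbors with features 1.0 .. 9.0
full = mean_aggregate(features, neighbors)

rng = random.Random(0)
estimates = [sampled_mean_aggregate(features, neighbors, 2, rng) for _ in range(20000)]
print(full, sum(estimates) / len(estimates))  # full mean is 5.0; the average estimate is close to it
```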

3. Training GNNs

1). GNN Prediction Heads

Idea: different task levels require different prediction heads

a). Node-level prediction

Make predictions directly from node embeddings
After GNN computation, we have $d$-dim node embeddings $\{h_v^l \in \mathbb{R}^d, \forall v \in G\}$
For a $k$-way prediction problem

  • Classification: classify among $k$ categories
  • Regression: regress on $k$ targets

$\hat{y}_v = \text{Head}_{\text{node}}(h_v^l) = W^H h_v^l$
where $W^H \in \mathbb{R}^{k \times d}$ maps node embeddings $h_v^l \in \mathbb{R}^d$ to predictions $\hat{y}_v \in \mathbb{R}^k$ so that we can compute the loss.
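The node-level head is just a matrix-vector product. A minimal sketch with hypothetical sizes $d = 3$, $k = 2$ and made-up weights (no bias term, matching the formula above):

```python
def node_head(W_H, h_v):
    """y_hat_v = W^H h_v, where W_H is a k x d list of rows and h_v has length d."""
    return [sum(w * x for w, x in zip(row, h_v)) for row in W_H]


W_H = [[1.0, 0.0, -1.0],   # k = 2 rows
       [0.5, 0.5, 0.5]]    # d = 3 columns
h_v = [2.0, 4.0, 1.0]
print(node_head(W_H, h_v))  # [1.0, 3.5]
```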

b). Edge-level prediction

Make predictions using pairs of node embeddings
For a $k$-way prediction problem
$\hat{y}_{uv} = \text{Head}_{\text{edge}}(h_u^l, h_v^l)$
Options for $\text{Head}_{\text{edge}}(h_u^l, h_v^l)$

  • Concatenation + Linear
    $\hat{y}_{uv} = \text{Linear}(\text{Concat}(h_u^l, h_v^l))$
    where $\text{Linear}(\cdot)$ maps $2d$-dim embeddings to $k$-dim outputs
  • Dot product
    For 1-way prediction (e.g., link prediction: predict the existence of an edge):
    $\hat{y}_{uv} = (h_u^l)^T h_v^l$
    For $k$-way prediction, similar to multi-head attention, use trainable weight matrices $W^1, \cdots, W^k \in \mathbb{R}^{d \times d}$:
    $\hat{y}_{uv}^1 = (h_u^l)^T W^1 h_v^l$
    $\cdots$
    $\hat{y}_{uv}^k = (h_u^l)^T W^k h_v^l$
    $\hat{y}_{uv} = \text{Concat}(\hat{y}_{uv}^1, \cdots, \hat{y}_{uv}^k) \in \mathbb{R}^k$
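The three edge-level heads above, written as a plain-Python sketch for $d = 2$. The weights are made up, the linear layer has no bias term (an assumption; the notes do not say), and each of the matrices $W^1, \cdots, W^k$ in the multi-output dot product is $d \times d$.

```python
def matvec(M, x):
    return [sum(m * xi for m, xi in zip(row, x)) for row in M]


def dot(x, y):
    return sum(a * b for a, b in zip(x, y))


def concat_linear_head(W, h_u, h_v):
    # Linear(Concat(h_u, h_v)) with W of shape k x 2d and no bias.
    return matvec(W, h_u + h_v)


def dot_product_head(h_u, h_v):
    # 1-way score: h_u^T h_v
    return dot(h_u, h_v)


def multi_dot_product_head(Ws, h_u, h_v):
    # k-way score: one bilinear form h_u^T W^i h_v per output.
    return [dot(h_u, matvec(W_i, h_v)) for W_i in Ws]


h_u, h_v = [1.0, 2.0], [3.0, -1.0]
print(concat_linear_head([[1.0, 1.0, 1.0, 1.0], [0.0, 0.0, 1.0, 0.0]], h_u, h_v))  # [5.0, 3.0]
print(dot_product_head(h_u, h_v))                                                  # 1.0
Ws = [[[1.0, 0.0], [0.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]]]
print(multi_dot_product_head(Ws, h_u, h_v))                                        # [1.0, 5.0]
```

With $W^1$ set to the identity, the first bilinear output reduces to the plain dot product.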
c). Graph-level prediction

Make predictions using all node embeddings in the graph
For a $k$-way prediction problem
$\hat{y}_G = \text{Head}_{\text{graph}}(\{h_v^l \in \mathbb{R}^d, \forall v \in G\})$
where $\text{Head}_{\text{graph}}(\cdot)$ is similar to $\text{AGG}(\cdot)$ in a GNN layer
Options for $\text{Head}_{\text{graph}}(\{h_v^l \in \mathbb{R}^d, \forall v \in G\})$ in small graphs

  • Global mean pooling: $\hat{y}_G = \text{Mean}(\{h_v^l \in \mathbb{R}^d, \forall v \in G\})$
  • Global max pooling: $\hat{y}_G = \text{Max}(\{h_v^l \in \mathbb{R}^d, \forall v \in G\})$
  • Global sum pooling: $\hat{y}_G = \text{Sum}(\{h_v^l \in \mathbb{R}^d, \forall v \in G\})$

Issue: global pooling over a (large) graph will lose information
Solution: aggregate all the node embeddings hierarchically (DiffPool)
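A sketch of the three global pooling options over a list of node embeddings (each a list of length $d$), plus a hypothetical pair of graphs illustrating the information-loss issue: their node embeddings differ by a factor of 10, yet sum and mean pooling map both to the same graph embedding. Max pooling happens to separate this particular pair. Hierarchical pooling such as DiffPool is not shown.

```python
def mean_pool(H):
    d = len(H[0])
    return [sum(h[j] for h in H) / len(H) for j in range(d)]


def max_pool(H):
    d = len(H[0])
    return [max(h[j] for h in H) for j in range(d)]


def sum_pool(H):
    d = len(H[0])
    return [sum(h[j] for h in H) for j in range(d)]


# Two hypothetical graphs with 1-dim node embeddings.
G1 = [[-1.0], [-2.0], [0.0], [1.0], [2.0]]
G2 = [[-10.0], [-20.0], [0.0], [10.0], [20.0]]
print(sum_pool(G1), sum_pool(G2))    # [0.0] [0.0]  (indistinguishable)
print(mean_pool(G1), mean_pool(G2))  # [0.0] [0.0]  (indistinguishable)
print(max_pool(G1), max_pool(G2))    # [2.0] [20.0]
```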

2). Labels
a). Supervised labels on graphs

Supervised labels come from specific use cases

  • Node labels $y_v$: in a citation network, which subject area a node belongs to
  • Edge labels $y_{uv}$: in a transaction network, whether an edge is fraudulent
  • Graph labels $y_G$: among molecular graphs, the drug-likeness of a graph

Advice: reduce your task to node / edge / graph labels since they are easy to work with

b). Unsupervised labels on graphs

Problem: sometimes we only have a graph without any external labels
Solution: “self-supervised learning”; we can find supervision labels within the graph

  • Node labels $y_v$: node statistics such as clustering coefficient, PageRank, …
  • Edge labels $y_{uv}$: link prediction, i.e., hide the edge between two nodes and predict whether there should be a link
  • Graph labels $y_G$: graph statistics, e.g., predict whether two graphs are isomorphic
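As an example of labels found within the graph itself, here is a sketch computing the local clustering coefficient of every node, which could then serve as a regression target for a GNN with no external labels. The 4-node graph is hypothetical, and returning 0.0 for nodes of degree below 2 is a common convention assumed here.

```python
from itertools import combinations


def clustering_coefficient(adj, v):
    """Fraction of pairs of v's neighbors that are themselves connected."""
    nbrs = adj[v]
    k = len(nbrs)
    if k < 2:
        return 0.0  # assumed convention: undefined for degree < 2, use 0
    links = sum(1 for a, b in combinations(nbrs, 2) if b in adj[a])
    return links / (k * (k - 1) / 2)


# A triangle 0-1-2 with a pendant node 3 attached to node 2.
edges = [(0, 1), (1, 2), (0, 2), (2, 3)]
adj = {v: set() for v in range(4)}
for u, w in edges:
    adj[u].add(w)
    adj[w].add(u)

labels = {v: clustering_coefficient(adj, v) for v in adj}  # self-supervised node targets
print(labels)  # node 2 gets 1/3: only one of its three neighbor pairs is linked
```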
3). Loss Function
  • Classification loss: cross entropy (CE) is a very common loss function in classification
  • Regression loss: we often use mean squared error (MSE), a.k.a. L2 loss
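Both losses sketched in plain Python. The cross entropy takes raw logits for one example and applies a softmax first (shifted by the max logit for numerical stability); the helper names are illustrative.

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]  # subtract the max for numerical stability
    s = sum(exps)
    return [e / s for e in exps]


def cross_entropy(logits, label):
    """CE for one example: -log of the softmax probability of the true class."""
    return -math.log(softmax(logits)[label])


def mse(preds, targets):
    # Mean of squared differences over all targets.
    return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(preds)


print(cross_entropy([0.0, 0.0], 1))           # log 2, since both classes get probability 0.5
print(mse([1.0, 2.0, 3.0], [1.0, 2.0, 5.0]))  # 4 / 3
```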
4). Evaluation Metrics
a). Regression
  • Root MSE (RMSE)
  • Mean absolute error (MAE)
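RMSE and MAE on a tiny made-up set of predictions. The single error of 3 dominates RMSE more than MAE because errors are squared before averaging.

```python
import math


def rmse(preds, targets):
    return math.sqrt(sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(preds))


def mae(preds, targets):
    return sum(abs(p - t) for p, t in zip(preds, targets)) / len(preds)


y_hat = [2.0, 3.0, 5.0, 6.0]
y = [1.0, 3.0, 5.0, 9.0]          # errors: 1, 0, 0, -3
print(rmse(y_hat, y), mae(y_hat, y))  # sqrt(2.5) ≈ 1.581 versus 1.0
```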
b). Classification
  • Multi-class classification
    $\text{Accuracy} = \frac{1}{N}\sum_{i=1}^{N} \mathbb{1}[\text{argmax}(\hat{y}^i) = y^i]$
  • Binary-class classification
    Accuracy: metric sensitive to the classification threshold
    Precision / recall: metrics sensitive to the classification threshold
    ROC AUC: metric agnostic to the classification threshold
| | Actually Positive (1) | Actually Negative (0) |
|---|---|---|
| Predicted Positive (1) | True Positives (TPs) | False Positives (FPs) |
| Predicted Negative (0) | False Negatives (FNs) | True Negatives (TNs) |

Accuracy: $\frac{\text{TP} + \text{TN}}{\text{TP} + \text{TN} + \text{FP} + \text{FN}} = \frac{\text{TP} + \text{TN}}{|\text{Dataset}|}$
Precision (P): $\frac{\text{TP}}{\text{TP} + \text{FP}}$
Recall (R): $\frac{\text{TP}}{\text{TP} + \text{FN}}$
F1-score: $\frac{2PR}{P + R}$
ROC curve: captures the tradeoff between TPR ($\frac{\text{TP}}{\text{TP} + \text{FN}}$ = Recall) and FPR ($\frac{\text{FP}}{\text{FP} + \text{TN}}$) as the classification threshold of a binary classifier is varied.
ROC AUC: area under the ROC curve
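A plain-Python sketch of the metrics above on six made-up binary examples, plus multi-class accuracy. `roc_auc` uses the rank formulation of AUC (the probability that a randomly chosen positive is scored higher than a randomly chosen negative, counting ties as 1/2), which equals the area under the ROC curve without tracing the curve explicitly. Moving the threshold from 0.5 to 0.75 changes accuracy, while AUC depends only on the scores.

```python
def binary_metrics(y_true, scores, threshold=0.5):
    """Accuracy, precision, recall and F1 after thresholding scores into 0/1 predictions."""
    y_pred = [1 if s >= threshold else 0 for s in scores]
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == 1 and p == 1)
    fp = sum(1 for t, p in pairs if t == 0 and p == 1)
    fn = sum(1 for t, p in pairs if t == 1 and p == 0)
    tn = sum(1 for t, p in pairs if t == 0 and p == 0)
    acc = (tp + tn) / len(pairs)
    prec = tp / (tp + fp) if tp + fp else 0.0  # 0.0 when nothing is predicted positive
    rec = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * prec * rec / (prec + rec) if prec + rec else 0.0
    return acc, prec, rec, f1


def roc_auc(y_true, scores):
    """AUC as P(random positive outscores random negative), ties counted as 1/2."""
    pos = [s for t, s in zip(y_true, scores) if t == 1]
    neg = [s for t, s in zip(y_true, scores) if t == 0]
    wins = sum(1.0 if sp > sn else 0.5 if sp == sn else 0.0 for sp in pos for sn in neg)
    return wins / (len(pos) * len(neg))


def multiclass_accuracy(y_hat, y):
    """Fraction of rows whose argmax class index equals the true label."""
    hits = sum(1 for row, label in zip(y_hat, y) if max(range(len(row)), key=row.__getitem__) == label)
    return hits / len(y)


y_true = [1, 1, 0, 0, 1, 0]
scores = [0.9, 0.4, 0.35, 0.8, 0.7, 0.1]
print(binary_metrics(y_true, scores, 0.5))      # accuracy, precision, recall and F1 are all 2/3
print(binary_metrics(y_true, scores, 0.75)[0])  # accuracy drops to 0.5
print(roc_auc(y_true, scores))                  # 7/9 regardless of threshold
```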

5). Dataset Split: Fixed / Random Split

Fixed split: split dataset once

  • Training set: used for optimizing GNN parameters
  • Validation set: develop model / hyperparameters
  • Test set: held out until reporting final performance

Random split: randomly split the dataset into training / validation / test sets
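A random split over item indices (nodes, edges or graphs), sketched with Python's `random` module. The 80 / 10 / 10 fractions and the seed are assumptions; calling the function once with a fixed seed corresponds to a fixed split, and changing the seed draws a different random split.

```python
import random


def random_split(num_items, frac_train=0.8, frac_val=0.1, seed=0):
    """Shuffle item indices with a seeded RNG and cut them into train / val / test."""
    idx = list(range(num_items))
    random.Random(seed).shuffle(idx)
    n_train = int(frac_train * num_items)
    n_val = int(frac_val * num_items)
    return idx[:n_train], idx[n_train:n_train + n_val], idx[n_train + n_val:]


train, val, test = random_split(100)
print(len(train), len(val), len(test))  # 80 10 10
```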

a). Why splitting graphs is special

Image classification: each data point is an image and data points are independent
Node classification: each data point is a node and data points are NOT independent

Solutions

  • Transductive setting: the entire input graph can be observed in all the dataset splits (training, validation and test sets), and only the (node) labels are split
  • Inductive setting: break the edges between splits to get multiple independent graphs
| | Transductive setting | Inductive setting |
|---|---|---|
| Training / validation / test | On the same entire graph | On different graphs |
| Applications | Node / edge tasks | Node / edge / graph tasks |
b). Example: node classification
  • Transductive node classification
    All the splits can observe the entire graph structure but can only observe the labels of their respective nodes.
  • Inductive node classification
    Each split contains an independent graph.
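The difference between the two node-classification settings comes down to which edges each split keeps. In this sketch (hypothetical 6-node graph and split assignment), the transductive version hands every split the full edge list and only changes which nodes' labels are used, while the inductive version drops every edge that crosses two splits.

```python
SPLITS = ("train", "val", "test")


def transductive_split(edges, node_splits):
    # Every split sees the full edge list; only the set of labeled nodes differs.
    return {name: (list(edges), node_splits[name]) for name in SPLITS}


def inductive_split(edges, node_splits):
    # Keep only edges with both endpoints in the same split, giving independent graphs.
    out = {}
    for name in SPLITS:
        keep = set(node_splits[name])
        out[name] = ([(u, v) for u, v in edges if u in keep and v in keep], node_splits[name])
    return out


edges = [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (0, 5), (1, 4)]
node_splits = {"train": [0, 1, 2], "val": [3, 4], "test": [5]}
trans = transductive_split(edges, node_splits)
ind = inductive_split(edges, node_splits)
print(ind["train"][0], ind["val"][0], ind["test"][0])  # [(0, 1), (1, 2)] [(3, 4)] []
```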
c). Example: graph classification

Only the inductive setting is well defined for graph classification because we have to test on unseen graphs.

d). Example: link prediction
  • Goal of link prediction: predict missing edges
  • Link prediction is an unsupervised / self-supervised task. We need to create the labels and dataset splits on our own
  • Concretely, we need to hide some edges from the GNN and let the GNN predict if the edges exist

Setting up link prediction
Split edges twice

  • Step 1: Assign 2 types of edges in the original graph: message edges (for GNN message passing) and supervision edges (for computing objectives). After Step 1, only message edges remain in the graph; supervision edges serve as labels for the edge predictions made by the model and are not fed into the GNN.
  • Step 2: split edges into training / validation / test

Option 1 for Step 2: inductive link prediction split. Each inductive split contains an independent graph and each graph has two types of edges - message edges and supervision edges

Option 2 for Step 2: transductive link prediction split (default setting). By definition of “transductive”, the entire graph can be observed in all dataset splits. But since edges are both part of the graph structure and the supervision, we need to hold out validation / test edges. For training, we further need to hold out supervision edges within the training set

  • 1). At training time: use training message edges to predict training supervision edges
  • 2). At validation time: use training message edges & training supervision edges to predict validation edges
  • 3). At test time, use training message edges & training supervision edges & validation edges to predict test edges
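A sketch of the transductive link-prediction split described above, assuming edges are `(u, v)` tuples and using made-up fractions (10% validation, 10% test, 20% of the remaining training edges as supervision edges). Each split returns the edges the GNN may use for message passing and the edges it must predict. Only positive edges are handled here; the notes do not cover how non-edges are sampled as negatives.

```python
import random


def transductive_link_split(edges, frac_val=0.1, frac_test=0.1, frac_train_sup=0.2, seed=0):
    """Split edges into train message / train supervision / validation / test edges."""
    edges = list(edges)
    random.Random(seed).shuffle(edges)
    n = len(edges)
    n_test, n_val = int(frac_test * n), int(frac_val * n)
    test = edges[:n_test]
    val = edges[n_test:n_test + n_val]
    train = edges[n_test + n_val:]
    n_sup = int(frac_train_sup * len(train))
    train_sup, train_msg = train[:n_sup], train[n_sup:]
    return {
        # split name -> (message-passing edges, edges to predict)
        "train": (train_msg, train_sup),
        "val": (train_msg + train_sup, val),
        "test": (train_msg + train_sup + val, test),
    }


edges = [(i, j) for i in range(10) for j in range(i + 1, 10)]  # 45 edges of the complete graph K10
splits = transductive_link_split(edges)
for name, (msg, sup) in splits.items():
    print(name, len(msg), len(sup))  # train 30 7 / val 37 4 / test 41 4
```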