Graph Convolutional Network
Analyzing GCN from the message-passing perspective:

In a GCN, every node has its own representation $h_i$.

Following the message-passing paradigm, each node receives the messages (representations) sent by its neighboring nodes.

Each node aggregates the messages received from its neighbors into $\hat{h_i}$.

The aggregated representation $\hat{h_i}$ is then transformed, linearly or non-linearly, by a function $f$:

$$f(W_u \hat{h_i}) = h^{new}_i$$

The new representation $h^{new}_i$ computed above then replaces the old one: $h^{new}_i \rightarrow h_i$.
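As a concrete illustration, here is a minimal sketch of one such message-passing step on a toy graph, in plain PyTorch (the graph, the 4-dimensional features, and the ReLU choice for $f$ are all made up for illustration):

import torch as th

# Toy graph: directed edges (sender, receiver): 0->1, 0->2, 1->2.
edges = [(0, 1), (0, 2), (1, 2)]
h = th.randn(3, 4)      # h_i: a 4-dim representation per node (made-up size)
W_u = th.randn(4, 4)    # the weight matrix W_u from the formula above

# Step 1: aggregate -- each node sums the h_j of its in-neighbors into h_hat.
h_hat = th.zeros_like(h)
for src, dst in edges:
    h_hat[dst] += h[src]

# Step 2: transform -- f(W_u h_hat) = h_new, here with f = ReLU.
h_new = th.relu(h_hat @ W_u)

# Step 3: update -- h_new -> h replaces the old representations.
h = h_new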
The mathematical formulation of GCN:

$$H^{(l+1)} = \sigma\left(\tilde{D}^{-\frac{1}{2}}\tilde{A}\tilde{D}^{-\frac{1}{2}}H^{(l)}W^{(l)}\right)$$
$H^{(l)}$: representations of all nodes at the $l^{th}$ layer
$W^{(l)}$: weight matrix of the $l^{th}$ layer
$D$: degree matrix
$A$: adjacency matrix
$\tilde{D}$: renormalization trick: the degree matrix after adding a self-connection to every node in the graph
$\tilde{A}$: renormalization trick: the adjacency matrix after adding self-connections, i.e. $\tilde{A} = A + I$
$H^{(0)}$: the input, i.e. the initial features of each node; shape $N \times F_{in}$, where $N$ is the number of nodes in the graph and $F_{in}$ is the input feature dimension
$H^{(out)}$: the output; shape $N \times F_{out}$
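To make the formula concrete, here is a minimal sketch that builds $\tilde{A}$ and $\tilde{D}$ for a toy 3-node graph and applies one GCN layer in plain PyTorch (the graph, sizes, and random weights are assumptions for illustration):

import torch as th

# Toy undirected graph on 3 nodes with edges (0,1) and (1,2).
A = th.tensor([[0., 1., 0.],
               [1., 0., 1.],
               [0., 1., 0.]])

A_tilde = A + th.eye(3)                     # renormalization trick: add self-loops
deg = A_tilde.sum(dim=1)                    # node degrees of A_tilde
D_inv_sqrt = th.diag(deg ** -0.5)           # D~^{-1/2}
A_hat = D_inv_sqrt @ A_tilde @ D_inv_sqrt   # D~^{-1/2} A~ D~^{-1/2}

H0 = th.randn(3, 5)                         # H^{(0)}: N x F_in initial features
W0 = th.randn(5, 2)                         # W^{(0)}: layer-0 weight matrix
H1 = th.relu(A_hat @ H0 @ W0)               # H^{(1)} = sigma(A_hat H^{(0)} W^{(0)})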
Build a GCN using DGL
import dgl
import torch as th
import torch.nn as nn
import dgl.function as fn
import torch.nn.functional as F
from dgl import DGLGraph

# Message function: copy each source node's 'h' out as the message 'm'.
gcn_msg = fn.copy_src(src='h', out='m')
# Reduce function: sum the incoming messages 'm' into the node's 'h'.
gcn_reduce = fn.sum(msg='m', out='h')
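# Note (assumption about newer versions): in DGL 0.5+, fn.copy_src is
# deprecated in favor of fn.copy_u; the equivalent message function would be
# gcn_msg = fn.copy_u('h', 'm')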
class NodeApplyModule(nn.Module):
    """Node-wise update: apply a linear transform and an activation to 'h'."""
    def __init__(self, in_feats, out_feats, activation):
        super(NodeApplyModule, self).__init__()
        self.linear = nn.Linear(in_feats, out_feats)
        self.activation = activation

    def forward(self, node):
        h = self.linear(node.data['h'])
        h = self.activation(h)
        return {'h': h}
class GCN(nn.Module):
    def __init__(self, in_feats, out_feats, activation):
        super(GCN, self).__init__()
        self.apply_mod = NodeApplyModule(in_feats, out_feats, activation)

    def forward(self, g, feature):
        g.ndata['h'] = feature
        # Message passing: every node sums its neighbors' 'h' into its own 'h'.
        g.update_all(gcn_msg, gcn_reduce)
        # Node-wise transform of the aggregated representation.
        g.apply_nodes(func=self.apply_mod)
        return g.ndata.pop('h')
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # Cora: 1433 input features, 16 hidden units, 7 output classes.
        self.gcn1 = GCN(1433, 16, F.relu)
        self.gcn2 = GCN(16, 7, F.relu)

    def forward(self, g, features):
        x = self.gcn1(g, features)
        x = self.gcn2(g, x)
        return x

GCnet = Net()
print(GCnet)
Net(
(gcn1): GCN(
(apply_mod): NodeApplyModule(
(linear): Linear(in_features=1433, out_features=16, bias=True)
)
)
(gcn2): GCN(
(apply_mod): NodeApplyModule(
(linear): Linear(in_features=16, out_features=7, bias=True)
)
)
)
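Note that the sum-reduce above drops the symmetric normalization $\tilde{D}^{-\frac{1}{2}}\tilde{A}\tilde{D}^{-\frac{1}{2}}$ from the formula. One hedged way to restore it inside GCN.forward is to scale by inverse square-root degrees before and after aggregation (my sketch, not part of the original tutorial code; it assumes in- and out-degrees coincide, as they do on this symmetrized graph):

# Replacement body for GCN.forward, sketching the symmetric normalization:
norm = th.pow(g.in_degrees().float().clamp(min=1), -0.5).unsqueeze(1)
g.ndata['h'] = feature * norm          # scale by deg^{-1/2} before aggregation
g.update_all(gcn_msg, gcn_reduce)
g.ndata['h'] = g.ndata['h'] * norm     # scale by deg^{-1/2} after aggregation
g.apply_nodes(func=self.apply_mod)
return g.ndata.pop('h')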
Load data (DGL built-in)
from dgl.data import citation_graph as citegrh

def load_cora_data():
    data = citegrh.load_cora()
    features = th.FloatTensor(data.features)
    labels = th.LongTensor(data.labels)
    mask = th.ByteTensor(data.train_mask)
    g = data.graph
    # Normalize self-loops: remove any existing ones, then add one per node.
    g.remove_edges_from(g.selfloop_edges())
    g = DGLGraph(g)
    g.add_edges(g.nodes(), g.nodes())
    return g, features, labels, mask
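The citation_graph interface above comes from older DGL releases; in DGL 0.5+ the same dataset loads through CoraGraphDataset, with features, labels, and masks stored on the graph itself. A hedged equivalent sketch:

def load_cora_data_new():
    dataset = dgl.data.CoraGraphDataset()
    g = dataset[0]                    # the single Cora graph
    g = dgl.add_self_loop(g)          # same self-loop trick as above
    features = g.ndata['feat']
    labels = g.ndata['label']
    mask = g.ndata['train_mask']
    return g, features, labels, mask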
Train the model
import time
import warnings
import numpy as np

warnings.filterwarnings('ignore')

graph, features, labels, mask = load_cora_data()
optimizer = th.optim.Adam(GCnet.parameters(), lr=0.1)
dur = []
train_loss = []
for epoch in range(50):
    if epoch >= 3:  # start timing after a few warm-up epochs
        t0 = time.time()
    logits = GCnet(graph, features)
    logp = F.log_softmax(logits, 1)
    # Cross-entropy on the training nodes only (log-softmax + NLL).
    loss = F.nll_loss(logp[mask], labels[mask])
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if epoch >= 3:
        dur.append(time.time() - t0)
    train_loss.append(loss.item())
    print("Epoch %5d | Loss %.4f | Time(s) %.4f" % (epoch, loss.item(), np.mean(dur)))
Epoch 0 | Loss 0.9992 | Time(s) nan
Epoch 1 | Loss 1.0033 | Time(s) nan
Epoch 2 | Loss 2.8829 | Time(s) nan
Epoch 3 | Loss 1.7264 | Time(s) 0.2997
Epoch 4 | Loss 1.4124 | Time(s) 0.2961
Epoch 5 | Loss 0.8191 | Time(s) 0.2988
Epoch 6 | Loss 0.7352 | Time(s) 0.3071
Epoch 7 | Loss 0.6177 | Time(s) 0.3042
Epoch 8 | Loss 0.5425 | Time(s) 0.3030
Epoch 9 | Loss 0.4691 | Time(s) 0.3024
Epoch 10 | Loss 0.3825 | Time(s) 0.3019
Epoch 11 | Loss 0.3116 | Time(s) 0.3017
Epoch 12 | Loss 0.2253 | Time(s) 0.3036
Epoch 13 | Loss 0.1849 | Time(s) 0.3030
Epoch 14 | Loss 0.2047 | Time(s) 0.3027
Epoch 15 | Loss 0.1770 | Time(s) 0.3027
Epoch 16 | Loss 0.1390 | Time(s) 0.3023
Epoch 17 | Loss 0.0902 | Time(s) 0.3022
Epoch 18 | Loss 0.0822 | Time(s) 0.3023
Epoch 19 | Loss 0.0842 | Time(s) 0.3019
Epoch 20 | Loss 0.0796 | Time(s) 0.3027
Epoch 21 | Loss 0.0689 | Time(s) 0.3027
Epoch 22 | Loss 0.0667 | Time(s) 0.3025
Epoch 23 | Loss 0.0524 | Time(s) 0.3024
Epoch 24 | Loss 0.0486 | Time(s) 0.3025
Epoch 25 | Loss 0.0413 | Time(s) 0.3022
Epoch 26 | Loss 0.0382 | Time(s) 0.3021
Epoch 27 | Loss 0.0314 | Time(s) 0.3022
Epoch 28 | Loss 0.0282 | Time(s) 0.3019
Epoch 29 | Loss 0.0267 | Time(s) 0.3018
Epoch 30 | Loss 0.0254 | Time(s) 0.3018
Epoch 31 | Loss 0.0267 | Time(s) 0.3016
Epoch 32 | Loss 0.0248 | Time(s) 0.3016
Epoch 33 | Loss 0.0246 | Time(s) 0.3016
Epoch 34 | Loss 0.0240 | Time(s) 0.3014
Epoch 35 | Loss 0.0229 | Time(s) 0.3013
Epoch 36 | Loss 0.0225 | Time(s) 0.3014
Epoch 37 | Loss 0.0217 | Time(s) 0.3012
Epoch 38 | Loss 0.0210 | Time(s) 0.3012
Epoch 39 | Loss 0.0209 | Time(s) 0.3012
Epoch 40 | Loss 0.0204 | Time(s) 0.3011
Epoch 41 | Loss 0.0201 | Time(s) 0.3011
Epoch 42 | Loss 0.0200 | Time(s) 0.3015
Epoch 43 | Loss 0.0196 | Time(s) 0.3013
Epoch 44 | Loss 0.0194 | Time(s) 0.3013
Epoch 45 | Loss 0.0192 | Time(s) 0.3013
Epoch 46 | Loss 0.0189 | Time(s) 0.3014
Epoch 47 | Loss 0.0187 | Time(s) 0.3013
Epoch 48 | Loss 0.0185 | Time(s) 0.3013
Epoch 49 | Loss 0.0182 | Time(s) 0.3013
import matplotlib.pyplot as plt

plt.figure(figsize=(15, 7))
plt.plot(train_loss)
plt.title('Train Loss')
plt.grid(True)
plt.show()
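The loop above only tracks training loss; to judge the model, a hedged sketch of an accuracy check on held-out nodes (it assumes a test mask is available, e.g. data.test_mask from the same Cora loader):

def evaluate(model, g, features, labels, eval_mask):
    # Accuracy of argmax predictions on the nodes selected by eval_mask.
    model.eval()
    with th.no_grad():
        logits = model(g, features)
        preds = logits[eval_mask].argmax(dim=1)
        return (preds == labels[eval_mask]).float().mean().item()

# Example usage, assuming the loader also exposes a test mask:
# test_mask = th.BoolTensor(data.test_mask)
# print("Test accuracy: %.4f" % evaluate(GCnet, graph, features, labels, test_mask))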