A Simple GCN (Graph Convolutional Network) Implementation

Graph Convolutional Network

Analyzing GCN from the message-passing perspective

    1. In a GCN, every node has its own representation $h_i$.
    2. Following the message-passing paradigm, each node receives messages (representations) sent by its neighboring nodes.
    3. Each node aggregates the messages received from its neighbors into $\hat{h_i}$.
    4. The aggregated representation is then transformed, linearly or non-linearly, by a function $f$.
    5. $\hat{h_i}$ passes through the function: $f(W_u\hat{h_i}) = h^{new}_i$.
    6. The newly computed $h^{new}_i$ replaces the old representation: $h^{new}_i \rightarrow h_i$ (see the sketch after this list).
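
A minimal dense sketch of one such message-passing step, on a hypothetical 3-node toy graph (the adjacency matrix, feature sizes, and the choice of ReLU for $f$ are illustrative assumptions, not from the original post):

import torch

# toy 3-node graph: adjacency matrix (illustrative assumption)
A = torch.tensor([[0., 1., 0.],
                  [1., 0., 1.],
                  [0., 1., 0.]])
H = torch.randn(3, 4)              # h_i for each node (3 nodes, 4 features)
W_u = torch.randn(4, 4)            # weight matrix W_u

H_hat = A @ H                      # steps 2-3: sum neighbors' messages -> h_i^
H_new = torch.relu(H_hat @ W_u)    # steps 4-5: f(W_u h_i^), with f = ReLU
H = H_new                          # step 6: h_i^new --> h_i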

The mathematical formulation of GCN:

$$H^{(l+1)} = \sigma\left(\tilde{D}^{-\frac{1}{2}}\tilde{A}\tilde{D}^{-\frac{1}{2}}H^{(l)}W^{(l)}\right)$$

  • $H^{(l)}$ : the representations of all nodes at layer $l$
  • $W^{(l)}$ : the weight matrix of layer $l$
  • $D$ : degree matrix
  • $A$ : adjacency matrix
  • $\tilde{D}$ : renormalization trick: the degree matrix after adding a self-loop to every node in the graph
  • $\tilde{A}$ : renormalization trick: the adjacency matrix with self-loops added, $\tilde{A} = A + I$
  • $H^{(0)}$ : the input, i.e. the initial features of each node, with shape $N \times F_{in}$
    • $N$ : number of nodes in the graph
    • $F_{in}$ : dimension of the input features
  • $H^{(out)}$ : the output, with shape $N \times F_{out}$
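
A quick sketch of how $\tilde{A}$ and $\tilde{D}^{-\frac{1}{2}}\tilde{A}\tilde{D}^{-\frac{1}{2}}$ can be computed densely (the 3-node adjacency matrix is an illustrative assumption):

import torch

A = torch.tensor([[0., 1., 1.],
                  [1., 0., 0.],
                  [1., 0., 0.]])               # toy adjacency matrix (assumption)
A_tilde = A + torch.eye(3)                     # renormalization trick: A~ = A + I
deg = A_tilde.sum(dim=1)                       # node degrees of A~
D_inv_sqrt = torch.diag(deg.pow(-0.5))         # D~^{-1/2}
A_hat = D_inv_sqrt @ A_tilde @ D_inv_sqrt      # D~^{-1/2} A~ D~^{-1/2}
# one layer is then: H_next = sigma(A_hat @ H @ W)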

Build a GCN using DGL

import dgl
import torch as th
import torch.nn as nn
import dgl.function as fn
import torch.nn.functional as F
from dgl import DGLGraph

# message function: copy each node's 'h' onto its out-edges as message 'm'
# (fn.copy_src was renamed fn.copy_u in newer DGL releases)
gcn_msg = fn.copy_src(src='h', out='m')
# reduce function: sum the incoming messages 'm' into the new 'h'
gcn_reduce = fn.sum(msg='m', out='h')
# NOTE: for simplicity this implementation sums neighbor messages directly
# and omits the D~^{-1/2} normalization from the formula above
class NodeApplyModule(nn.Module):
    """Applies a linear transform plus activation to each node's 'h'."""
    def __init__(self, in_feats, out_feats, activation):
        super(NodeApplyModule, self).__init__()
        self.linear = nn.Linear(in_feats, out_feats)
        self.activation = activation

    def forward(self, node):
        h = self.linear(node.data['h'])   # f(W_u * h_hat)
        h = self.activation(h)
        return {'h': h}
    
class GCN(nn.Module):
    def __init__(self, in_feats, out_feats, activation):
        super(GCN, self).__init__()
        self.apply_mod = NodeApplyModule(in_feats, out_feats, activation)

    def forward(self, g, feature):
        g.ndata['h'] = feature
        # message passing: every node sends 'h' and sums what it receives
        g.update_all(gcn_msg, gcn_reduce)
        # node-wise transform: linear layer + activation
        g.apply_nodes(func=self.apply_mod)
        return g.ndata.pop('h')
    
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # Cora: 1433 input features, 16 hidden units, 7 output classes
        self.gcn1 = GCN(1433, 16, F.relu)
        self.gcn2 = GCN(16, 7, F.relu)

    def forward(self, g, features):
        x = self.gcn1(g, features)
        x = self.gcn2(g, x)
        return x

GCnet = Net()

print(GCnet)
Net(
  (gcn1): GCN(
    (apply_mod): NodeApplyModule(
      (linear): Linear(in_features=1433, out_features=16, bias=True)
    )
  )
  (gcn2): GCN(
    (apply_mod): NodeApplyModule(
      (linear): Linear(in_features=16, out_features=7, bias=True)
    )
  )
)

Load data (DGL built-in)

from dgl.data import citation_graph as citegrh

def load_cora_data():
    data = citegrh.load_cora()
    features = th.FloatTensor(data.features)
    labels = th.LongTensor(data.labels)
    mask = th.ByteTensor(data.train_mask)
    g = data.graph                        # a networkx graph
    # add self-loops: drop any existing ones first, then add one per node
    g.remove_edges_from(g.selfloop_edges())
    g = DGLGraph(g)
    g.add_edges(g.nodes(), g.nodes())
    return g, features, labels, mask
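
The snippet above targets an older DGL release. On newer DGL versions (roughly 0.5+) the same data can be loaded along these lines; this is a sketch against the current dgl.data API, so treat it as an assumption rather than the original post's method:

from dgl.data import CoraGraphDataset

def load_cora_data_v2():
    dataset = CoraGraphDataset()
    g = dataset[0]                      # a DGLGraph; features live in g.ndata
    g = dgl.add_self_loop(g)            # add self-loops
    features = g.ndata['feat']
    labels = g.ndata['label']
    mask = g.ndata['train_mask']
    return g, features, labels, mask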

Train the model

import time
import warnings
import numpy as np

warnings.filterwarnings('ignore')   # silence np.mean warnings while dur is still empty

# graph, node features, labels, training mask
graph, features, labels, mask = load_cora_data()

optimizer = th.optim.Adam(GCnet.parameters(), lr=0.1)
dur = []
train_loss = []

for epoch in range(50):
    if epoch >= 3:          # skip the first few epochs when timing (warm-up)
        t0 = time.time()
    logits = GCnet(graph,  features)
    logp = F.log_softmax(logits, 1)
    loss = F.nll_loss(logp[mask], labels[mask])
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    if epoch >= 3:
        dur.append(time.time() - t0)
        train_loss.append(loss.item())
        
    print("Epoch %5d  |  Loss %.4f  |  Time(s) %.4f"%(epoch, loss.item(), np.mean(dur)))
Epoch     0  |  Loss 0.9992  |  Time(s) nan
Epoch     1  |  Loss 1.0033  |  Time(s) nan
Epoch     2  |  Loss 2.8829  |  Time(s) nan
Epoch     3  |  Loss 1.7264  |  Time(s) 0.2997
Epoch     4  |  Loss 1.4124  |  Time(s) 0.2961
Epoch     5  |  Loss 0.8191  |  Time(s) 0.2988
Epoch     6  |  Loss 0.7352  |  Time(s) 0.3071
Epoch     7  |  Loss 0.6177  |  Time(s) 0.3042
Epoch     8  |  Loss 0.5425  |  Time(s) 0.3030
Epoch     9  |  Loss 0.4691  |  Time(s) 0.3024
Epoch    10  |  Loss 0.3825  |  Time(s) 0.3019
Epoch    11  |  Loss 0.3116  |  Time(s) 0.3017
Epoch    12  |  Loss 0.2253  |  Time(s) 0.3036
Epoch    13  |  Loss 0.1849  |  Time(s) 0.3030
Epoch    14  |  Loss 0.2047  |  Time(s) 0.3027
Epoch    15  |  Loss 0.1770  |  Time(s) 0.3027
Epoch    16  |  Loss 0.1390  |  Time(s) 0.3023
Epoch    17  |  Loss 0.0902  |  Time(s) 0.3022
Epoch    18  |  Loss 0.0822  |  Time(s) 0.3023
Epoch    19  |  Loss 0.0842  |  Time(s) 0.3019
Epoch    20  |  Loss 0.0796  |  Time(s) 0.3027
Epoch    21  |  Loss 0.0689  |  Time(s) 0.3027
Epoch    22  |  Loss 0.0667  |  Time(s) 0.3025
Epoch    23  |  Loss 0.0524  |  Time(s) 0.3024
Epoch    24  |  Loss 0.0486  |  Time(s) 0.3025
Epoch    25  |  Loss 0.0413  |  Time(s) 0.3022
Epoch    26  |  Loss 0.0382  |  Time(s) 0.3021
Epoch    27  |  Loss 0.0314  |  Time(s) 0.3022
Epoch    28  |  Loss 0.0282  |  Time(s) 0.3019
Epoch    29  |  Loss 0.0267  |  Time(s) 0.3018
Epoch    30  |  Loss 0.0254  |  Time(s) 0.3018
Epoch    31  |  Loss 0.0267  |  Time(s) 0.3016
Epoch    32  |  Loss 0.0248  |  Time(s) 0.3016
Epoch    33  |  Loss 0.0246  |  Time(s) 0.3016
Epoch    34  |  Loss 0.0240  |  Time(s) 0.3014
Epoch    35  |  Loss 0.0229  |  Time(s) 0.3013
Epoch    36  |  Loss 0.0225  |  Time(s) 0.3014
Epoch    37  |  Loss 0.0217  |  Time(s) 0.3012
Epoch    38  |  Loss 0.0210  |  Time(s) 0.3012
Epoch    39  |  Loss 0.0209  |  Time(s) 0.3012
Epoch    40  |  Loss 0.0204  |  Time(s) 0.3011
Epoch    41  |  Loss 0.0201  |  Time(s) 0.3011
Epoch    42  |  Loss 0.0200  |  Time(s) 0.3015
Epoch    43  |  Loss 0.0196  |  Time(s) 0.3013
Epoch    44  |  Loss 0.0194  |  Time(s) 0.3013
Epoch    45  |  Loss 0.0192  |  Time(s) 0.3013
Epoch    46  |  Loss 0.0189  |  Time(s) 0.3014
Epoch    47  |  Loss 0.0187  |  Time(s) 0.3013
Epoch    48  |  Loss 0.0185  |  Time(s) 0.3013
Epoch    49  |  Loss 0.0182  |  Time(s) 0.3013
import matplotlib.pyplot as plt
plt.figure(figsize=(15, 7))
plt.plot(train_loss)
plt.title('Train Loss')
plt.grid(True)
plt.show()

[Figure: training loss curve over epochs]
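
A natural follow-up is to check accuracy on held-out nodes. This is a hypothetical sketch, not part of the original post; test_mask is assumed to come from the same dataset object (e.g. th.BoolTensor(data.test_mask)):

def evaluate(model, g, features, labels, mask):
    model.eval()
    with th.no_grad():
        logits = model(g, features)
        preds = logits[mask].argmax(dim=1)            # predicted class per node
        correct = (preds == labels[mask]).sum().item()
        return correct / int(mask.sum().item())

# e.g.: acc = evaluate(GCnet, graph, features, labels, test_mask)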
