[DGL Learning 2] Writing Your Own GNN Model (MPNN)

Write your own GNN model using DGL's message passing API.
Reference:

DGL follows the message passing (MPNN) framework:

$$
m_{u\to v}^{(l)} = M^{(l)}\left(h_v^{(l-1)}, h_u^{(l-1)}, e_{u\to v}^{(l-1)}\right)
$$
$$
m_{v}^{(l)} = \sum_{u\in\mathcal{N}(v)} m_{u\to v}^{(l)}
$$
$$
h_v^{(l)} = U^{(l)}\left(h_v^{(l-1)}, m_v^{(l)}\right)
$$
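These three equations map directly onto DGL's API: the message function plays the role of M, the reduce function implements the aggregation over N(v), and the update U is whatever the NN module applies afterwards. A minimal sketch on a hypothetical toy graph (not the Cora graph used below), using only built-ins from dgl.function:

import dgl
import dgl.function as fn
import torch

# Hypothetical toy graph: 3 nodes on a directed cycle, 4-dimensional features
toy = dgl.graph(([0, 1, 2], [1, 2, 0]))
toy.ndata['h'] = torch.randn(3, 4)

# M: copy the source node feature as the message; sum: aggregate each node's mailbox
toy.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'm_sum'))

# U: any update, e.g. combine the old state with the aggregated messages
toy.ndata['h'] = toy.ndata['h'] + toy.ndata['m_sum']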

import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
Using backend: pytorch
import dgl.data
dataset = dgl.data.CoraGraphDataset()
print(dataset)
g = dataset[0]
print(g)
  NumNodes: 2708
  NumEdges: 10556
  NumFeats: 1433
  NumClasses: 7
  NumTrainingSamples: 140
  NumValidationSamples: 500
  NumTestSamples: 1000
Done loading data from cached files.
<dgl.data.citation_graph.CoraGraphDataset object at 0x7fc05098de90>
Graph(num_nodes=2708, num_edges=10556,
      ndata_schemes={'feat': Scheme(shape=(1433,), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.int64), 'val_mask': Scheme(shape=(), dtype=torch.bool), 'test_mask': Scheme(shape=(), dtype=torch.bool), 'train_mask': Scheme(shape=(), dtype=torch.bool)}
      edata_schemes={})
import dgl.function as fn

SAGEConv formulas

$$
h_{\mathcal{N}(v)}^{k} \leftarrow \text{Average}\left\{h_u^{k-1}, \forall u\in\mathcal{N}(v)\right\}
$$
$$
h_v^{k} \leftarrow \text{ReLU}\left(W^{k}\cdot \text{CONCAT}\left(h_v^{k-1}, h_{\mathcal{N}(v)}^{k}\right)\right)
$$

Define the message function and the reduce function

The built-in functions in dgl.function (fn) can also be used instead of custom UDFs

# The message function generates a message from each neighbor's node features (and edge features, if any)
def message_func(edges):
    return {'m' : edges.src['h']}

# The reduce function aggregates the messages from the neighbors;
# nodes.mailbox['m'] has shape (num_nodes, num_incoming_messages, feat_dim), so we average over dim=1
def reduce_func(nodes):
    return {'h_n' : nodes.mailbox['m'].mean(dim=1)}
class SAGEConv(nn.Module):
    def __init__(self, in_feats, out_feats):
        super(SAGEConv, self).__init__()
        self.weight = nn.Linear(in_feats*2, out_feats)
        
    def forward(self, g, h):
        g.ndata['h'] = h
#         g.update_all(message_func=fn.copy_u('h', 'm'), reduce_func=fn.mean('m', 'h_n'))
        # update_all is the core call: it runs the message and reduce steps over the whole graph
        g.update_all(message_func=message_func, reduce_func=reduce_func)
        h_n = g.ndata['h_n']
        # update
        x = self.weight(torch.cat([h, h_n], dim=1))
        
        return x
class GraphSAGE(nn.Module):
    def __init__(self, in_feats, h_feats, out_feats):
        super(GraphSAGE, self).__init__()
        self.conv1 = SAGEConv(in_feats, h_feats)
        self.conv2 = SAGEConv(h_feats, out_feats)
        
    def forward(self, g, in_feat):
        x = F.relu(self.conv1(g, in_feat))
        x = F.softmax(self.conv2(g, x), dim=1)
        
        return x
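
As a quick sanity check (a minimal sketch, assuming the g and dataset objects loaded above), a forward pass should produce one 7-way class distribution per node:

sage = GraphSAGE(g.ndata['feat'].shape[1], 16, dataset.num_classes)
out = sage(g, g.ndata['feat'])
print(out.shape)  # expected: torch.Size([2708, 7])

Note that nn.CrossEntropyLoss, used for training below, already applies log-softmax internally, so returning raw logits from conv2 (without the final F.softmax) is a common alternative; the version above still trains, as the logs below show.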

WeightedSAGEConv

$$
h_{\mathcal{N}(v)}^{k} \leftarrow \text{Average}\left\{w_u \cdot h_u^{k-1}, \forall u\in\mathcal{N}(v)\right\}
$$
$$
h_v^{k} \leftarrow \text{ReLU}\left(W^{k}\cdot \text{CONCAT}\left(h_v^{k-1}, h_{\mathcal{N}(v)}^{k}\right)\right)
$$
Cora itself has no edge weights, so the weights below default to 1 (and the result is the same as SAGEConv).

# The message is the element-wise product of the neighbor's node features and the edge weight
def w_mes(edges):
    return {'m': edges.src['h']*edges.data['w']}

def w_red(nodes):
    return {'h_n': nodes.mailbox['m'].mean(dim=1)}
class WeightedSAGEConv(nn.Module):
    def __init__(self, in_feats, out_feats):
        super(WeightedSAGEConv, self).__init__()
        self.w_k = nn.Linear(in_feats*2, out_feats)
        
    def forward(self, g, h, w):
        g.ndata['h'] = h
        g.edata['w'] = w  # edge weights stored as edge features
        g.update_all(message_func=w_mes, reduce_func=w_red)
        h_n = g.ndata['h_n']
        x = self.w_k(torch.cat([h, h_n],dim=1))
        
        return x
class W_GraphSAGE(nn.Module):
    def __init__(self, in_feats, h_feats, out_feats):
        super(W_GraphSAGE, self).__init__()
        self.conv1 = WeightedSAGEConv(in_feats, h_feats)
        self.conv2 = WeightedSAGEConv(h_feats, out_feats)
        
    def forward(self, g, in_feat):
        x = F.relu(self.conv1(g, in_feat, torch.ones(g.num_edges(), in_feat.shape[1]).to(g.device)))  # dim 1 of w must match the feature dimension of in_feat
        x = F.softmax(self.conv2(g, x, torch.ones(g.num_edges(), x.shape[1]).to(g.device)), dim=1)
        
        return x
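
Since the multiplication in w_mes broadcasts, a per-edge scalar weight of shape (num_edges, 1) also works and is closer to how real edge weights are usually stored; a minimal sketch (w_scalar is a hypothetical name, not from the original):

w_scalar = torch.ones(g.num_edges(), 1).to(g.device)  # one scalar weight per edge
# inside w_mes this broadcasts against the (num_edges, in_feats) source features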
model = GraphSAGE(g.ndata['feat'].shape[1], 16, dataset.num_classes)
print(model)

# model = W_GraphSAGE(g.ndata['feat'].shape[1], 16, dataset.num_classes)
# print(model)
GraphSAGE(
  (conv1): SAGEConv(
    (weight): Linear(in_features=2866, out_features=16, bias=True)
  )
  (conv2): SAGEConv(
    (weight): Linear(in_features=32, out_features=7, bias=True)
  )
)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
model = model.to(device)
g = g.to(device)
cpu
feat = g.ndata['feat']
label = g.ndata['label']
train_mask = g.ndata['train_mask']
val_mask = g.ndata['val_mask']
test_mask = g.ndata['test_mask']
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss().to(device)
def train():
    out = model(g, feat)
    loss = criterion(out[train_mask], label[train_mask])
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    pred = out.argmax(dim=1)
    
    train_acc = (pred[train_mask] == label[train_mask]).float().mean()
    val_acc = (pred[val_mask] == label[val_mask]).float().mean()
    test_acc = (pred[test_mask] == label[test_mask]).float().mean()
    
    return loss.item(), train_acc, val_acc, test_acc
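
For reference, a separate evaluation pass under torch.no_grad() is a common refinement (a minimal sketch, assuming the model, g, feat, label and mask tensors defined above; evaluate is a hypothetical helper, not part of the original):

@torch.no_grad()
def evaluate(mask):
    model.eval()
    pred = model(g, feat).argmax(dim=1)
    acc = (pred[mask] == label[mask]).float().mean().item()
    model.train()
    return acc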
def main():
    best_val_acc = 0
    best_test_acc = 0
    
    for epoch in range(100):
        loss, train_acc, val_acc, test_acc = train()
        
        if best_val_acc < val_acc:
            best_val_acc = val_acc
            best_test_acc = test_acc
            
        print('epoch:{:03d}, train_acc:{:.4f}, val_acc:{:.4f}, test_acc:{:.4f}'.format(epoch, train_acc, val_acc, test_acc))
        
    print('best_val_acc:', best_val_acc)
    print('best_test_acc:', best_test_acc)
if __name__ == '__main__':
    main()
epoch:000, train_acc:0.3021, val_acc:0.3021, test_acc:0.3021
epoch:001, train_acc:0.3021, val_acc:0.3021, test_acc:0.3021
epoch:002, train_acc:0.3065, val_acc:0.3065, test_acc:0.3065
epoch:003, train_acc:0.3261, val_acc:0.3261, test_acc:0.3261
epoch:004, train_acc:0.3460, val_acc:0.3460, test_acc:0.3460
epoch:005, train_acc:0.4032, val_acc:0.4032, test_acc:0.4032
epoch:006, train_acc:0.4734, val_acc:0.4734, test_acc:0.4734
epoch:007, train_acc:0.5583, val_acc:0.5583, test_acc:0.5583
epoch:008, train_acc:0.5982, val_acc:0.5982, test_acc:0.5982
epoch:009, train_acc:0.5517, val_acc:0.5517, test_acc:0.5517
epoch:010, train_acc:0.4919, val_acc:0.4919, test_acc:0.4919
epoch:011, train_acc:0.4450, val_acc:0.4450, test_acc:0.4450
epoch:012, train_acc:0.4243, val_acc:0.4243, test_acc:0.4243
epoch:013, train_acc:0.4147, val_acc:0.4147, test_acc:0.4147
epoch:014, train_acc:0.4103, val_acc:0.4103, test_acc:0.4103
epoch:015, train_acc:0.4088, val_acc:0.4088, test_acc:0.4088
epoch:016, train_acc:0.4077, val_acc:0.4077, test_acc:0.4077
epoch:017, train_acc:0.4095, val_acc:0.4095, test_acc:0.4095
epoch:018, train_acc:0.4114, val_acc:0.4114, test_acc:0.4114
epoch:019, train_acc:0.4158, val_acc:0.4158, test_acc:0.4158
epoch:020, train_acc:0.4236, val_acc:0.4236, test_acc:0.4236
epoch:021, train_acc:0.4335, val_acc:0.4335, test_acc:0.4335
epoch:022, train_acc:0.4402, val_acc:0.4402, test_acc:0.4402
epoch:023, train_acc:0.4516, val_acc:0.4516, test_acc:0.4516
epoch:024, train_acc:0.4653, val_acc:0.4653, test_acc:0.4653
epoch:025, train_acc:0.4823, val_acc:0.4823, test_acc:0.4823
epoch:026, train_acc:0.5026, val_acc:0.5026, test_acc:0.5026
epoch:027, train_acc:0.5284, val_acc:0.5284, test_acc:0.5284
epoch:028, train_acc:0.5602, val_acc:0.5602, test_acc:0.5602
epoch:029, train_acc:0.5905, val_acc:0.5905, test_acc:0.5905
epoch:030, train_acc:0.6226, val_acc:0.6226, test_acc:0.6226
epoch:031, train_acc:0.6496, val_acc:0.6496, test_acc:0.6496
epoch:032, train_acc:0.6713, val_acc:0.6713, test_acc:0.6713
epoch:033, train_acc:0.6939, val_acc:0.6939, test_acc:0.6939
epoch:034, train_acc:0.7105, val_acc:0.7105, test_acc:0.7105
epoch:035, train_acc:0.7216, val_acc:0.7216, test_acc:0.7216
epoch:036, train_acc:0.7356, val_acc:0.7356, test_acc:0.7356
epoch:037, train_acc:0.7470, val_acc:0.7470, test_acc:0.7470
epoch:038, train_acc:0.7530, val_acc:0.7530, test_acc:0.7530
epoch:039, train_acc:0.7552, val_acc:0.7552, test_acc:0.7552
epoch:040, train_acc:0.7640, val_acc:0.7640, test_acc:0.7640
epoch:041, train_acc:0.7670, val_acc:0.7670, test_acc:0.7670
epoch:042, train_acc:0.7681, val_acc:0.7681, test_acc:0.7681
epoch:043, train_acc:0.7707, val_acc:0.7707, test_acc:0.7707
epoch:044, train_acc:0.7740, val_acc:0.7740, test_acc:0.7740
epoch:045, train_acc:0.7762, val_acc:0.7762, test_acc:0.7762
epoch:046, train_acc:0.7755, val_acc:0.7755, test_acc:0.7755
epoch:047, train_acc:0.7788, val_acc:0.7788, test_acc:0.7788
epoch:048, train_acc:0.7773, val_acc:0.7773, test_acc:0.7773
epoch:049, train_acc:0.7770, val_acc:0.7770, test_acc:0.7770
epoch:050, train_acc:0.7777, val_acc:0.7777, test_acc:0.7777
epoch:051, train_acc:0.7784, val_acc:0.7784, test_acc:0.7784
epoch:052, train_acc:0.7781, val_acc:0.7781, test_acc:0.7781
epoch:053, train_acc:0.7788, val_acc:0.7788, test_acc:0.7788
epoch:054, train_acc:0.7792, val_acc:0.7792, test_acc:0.7792
epoch:055, train_acc:0.7784, val_acc:0.7784, test_acc:0.7784
epoch:056, train_acc:0.7773, val_acc:0.7773, test_acc:0.7773
epoch:057, train_acc:0.7781, val_acc:0.7781, test_acc:0.7781
epoch:058, train_acc:0.7799, val_acc:0.7799, test_acc:0.7799
epoch:059, train_acc:0.7795, val_acc:0.7795, test_acc:0.7795
epoch:060, train_acc:0.7803, val_acc:0.7803, test_acc:0.7803
epoch:061, train_acc:0.7803, val_acc:0.7803, test_acc:0.7803
epoch:062, train_acc:0.7803, val_acc:0.7803, test_acc:0.7803
epoch:063, train_acc:0.7806, val_acc:0.7806, test_acc:0.7806
epoch:064, train_acc:0.7806, val_acc:0.7806, test_acc:0.7806
epoch:065, train_acc:0.7810, val_acc:0.7810, test_acc:0.7810
epoch:066, train_acc:0.7814, val_acc:0.7814, test_acc:0.7814
epoch:067, train_acc:0.7818, val_acc:0.7818, test_acc:0.7818
epoch:068, train_acc:0.7829, val_acc:0.7829, test_acc:0.7829
epoch:069, train_acc:0.7821, val_acc:0.7821, test_acc:0.7821
epoch:070, train_acc:0.7810, val_acc:0.7810, test_acc:0.7810
epoch:071, train_acc:0.7803, val_acc:0.7803, test_acc:0.7803
epoch:072, train_acc:0.7788, val_acc:0.7788, test_acc:0.7788
epoch:073, train_acc:0.7792, val_acc:0.7792, test_acc:0.7792
epoch:074, train_acc:0.7795, val_acc:0.7795, test_acc:0.7795
epoch:075, train_acc:0.7792, val_acc:0.7792, test_acc:0.7792
epoch:076, train_acc:0.7803, val_acc:0.7803, test_acc:0.7803
epoch:077, train_acc:0.7795, val_acc:0.7795, test_acc:0.7795
epoch:078, train_acc:0.7777, val_acc:0.7777, test_acc:0.7777
epoch:079, train_acc:0.7762, val_acc:0.7762, test_acc:0.7762
epoch:080, train_acc:0.7755, val_acc:0.7755, test_acc:0.7755
epoch:081, train_acc:0.7755, val_acc:0.7755, test_acc:0.7755
epoch:082, train_acc:0.7751, val_acc:0.7751, test_acc:0.7751
epoch:083, train_acc:0.7751, val_acc:0.7751, test_acc:0.7751
epoch:084, train_acc:0.7747, val_acc:0.7747, test_acc:0.7747
epoch:085, train_acc:0.7758, val_acc:0.7758, test_acc:0.7758
epoch:086, train_acc:0.7762, val_acc:0.7762, test_acc:0.7762
epoch:087, train_acc:0.7766, val_acc:0.7766, test_acc:0.7766
epoch:088, train_acc:0.7766, val_acc:0.7766, test_acc:0.7766
epoch:089, train_acc:0.7758, val_acc:0.7758, test_acc:0.7758
epoch:090, train_acc:0.7758, val_acc:0.7758, test_acc:0.7758
epoch:091, train_acc:0.7766, val_acc:0.7766, test_acc:0.7766
epoch:092, train_acc:0.7766, val_acc:0.7766, test_acc:0.7766
epoch:093, train_acc:0.7766, val_acc:0.7766, test_acc:0.7766
epoch:094, train_acc:0.7755, val_acc:0.7755, test_acc:0.7755
epoch:095, train_acc:0.7755, val_acc:0.7755, test_acc:0.7755
epoch:096, train_acc:0.7751, val_acc:0.7751, test_acc:0.7751
epoch:097, train_acc:0.7755, val_acc:0.7755, test_acc:0.7755
epoch:098, train_acc:0.7755, val_acc:0.7755, test_acc:0.7755
epoch:099, train_acc:0.7758, val_acc:0.7758, test_acc:0.7758
best_val_acc: tensor(0.7829)
best_test_acc: tensor(0.7829)