import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
一、实现GCN节点分类
1.1 加载cora数据集
import dgl.data
# networkx
dataset = dgl.data.CoraGraphDataset()
print('Number of categories:', dataset.num_classes)
输出结果:
- 一个DGL对象是可以表示多图的,因此取出其中一个图的话就是需要索引的。
- 上图的中的train mask就表示哪一部分代表是训练集的。
- 对于qora这个数据集是只对点的特征属性用embedding表示,边的特征没有。
*
定义GCN网络架构
from dgl.nn import GraphConv
class GCN(nn.Module):
def __init__(self, in_feats, h_feats, num_classes):
super(GCN, self).__init__()
self.conv1 = GraphConv(in_feats, h_feats) #输入特征维度、输出隐藏层维度
self.conv2 = GraphConv(h_feats, num_classes)
def forward(self, g, in_feat):
h = self.conv1(g, in_feat)
h = F.relu(h)
h = self.conv2(g, h)
return h
# Create the model with given dimensions
model = GCN(g.ndata['feat'