DGL的Blitz ---- Blitz的节点分类Node Classification with DGL (v9.0 版DGL)

直接看代码吧:

'''对应的文档在:https://docs.dgl.ai/tutorials/blitz/1_introduction.html#sphx-glr-tutorials-blitz-1-introduction-py'''
import dgl
import dgl.data
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.data import RedditDataset
from dgl.nn.pytorch.conv import GraphConv       #教程上说的是from dgl.nn import GraphConv 会pycharm划红线警告但是运行时不会报错

"""
Kipf定义 节点分类问题 为一个semi-supervised问题 : 即,通过一小部分已知标签的节点,就可以预测出图上全部节点的分类
采取的数据集为Cora数据集,这个数据集是一个论文引用网络:论文作为图上的节点;两篇论文之间存在引用关系,那么在图上这两个论文节点上存在边
这个节点分类任务就是要求预测出给定论文的种类。
每个论文节点都包含一个字数向量作为其特征,经过归一化以使它们总和为 1

NumNodes: 2708
NumEdges: 10556
NumFeats: 1433
NumClasses: 7
NumTrainingSamples: 140
NumValidationSamples: 500
NumTestSamples: 1000
g.ndata包含了一个字典,其中的key有'feat' : (是一个大小为2708(节点数)行,1433(节点的feature的维度)列的tensor) ; 
                              'label':是一个一维的tensor,表示节点的y,即,节点属于那个类别(因为这毕竟是一个节点分类任务嘛)
                              'test_mask':是一个一维的值为True或者False的tensor,用来标识一个节点是test集合的 ;
                              'train_mask': 同理
                              'val_mask': 同理
"""

# dataset_Reddit = RedditDataset()


dataset = dgl.data.CoraGraphDataset()

#由于部分DGL中的Dataset的实例会包含多个图,但是对于Cora dataset来说,只包含一个图,故:
g = dataset[0]

# train_mask 、val_mask 、 test_mask 是三个bool类型的tensor,用来标识节点是属于 训练集、验证集 还是 测试集, 在ndata属性中的三类key对应的value长如下样子
# 'test_mask': tensor([False, False, False,  ...,  True,  True,  True]),
# 'train_mask': tensor([ True,  True,  True,  ..., False, False, False]),
# 'val_mask': tensor([False, False, False,  ..., False, False, False])}
# ndata 、 edata 是两个字典类型的数据,分别用于存储节点和边的特征




class GCN(nn.Module):
    # 本例子使用两层的GCN网络。我们可以简单地通过堆叠dgl提供的dgl.nn.GraphConv进行构建   (dgl.nn.GraphConv继承于torch.nn.Module)
    # 创建一个两层的GCN网络:
    def __init__(self, in_feats, h_feats, num_classes):
        super(GCN, self).__init__()
        self.conv1 = GraphConv(in_feats, h_feats)
        self.conv2 = GraphConv(h_feats, num_classes)

    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat)
        h = F.relu(h)
        h = self.conv2(g, h)
        return h

def train(g, model):
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    best_val_acc = 0
    best_test_acc = 0

    features = g.ndata['feat']  #因为图的ndata属性返回的是一个字典,那么,通过key值为'feat' ,就可以得到各个节点的特征tensor
    labels = g.ndata['label']   #因为图的ndata属性返回的是一个字典,那么,通过key值为'label' ,就可以得到各个节点的标签,即,属于哪一类别
    train_mask = g.ndata['train_mask']
    val_mask = g.ndata['val_mask']
    test_mask = g.ndata['test_mask']
    for e in range(100):
        # Forward
        logits = model(g, features)      #logits返回的是对于各个节点是属于哪个类别的一个预测值的tensor
        # Compute prediction
        pred = logits.argmax(1)         #pred是对应该logits中找到最大值,这个对应的就是经过本轮训练后预测的该点的类别

        # Compute loss
        # Note that you should only compute the losses of the nodes in the training set.
        loss = F.cross_entropy(logits[train_mask], labels[train_mask])  #这里注意,logits[train_mask]对应的是经过一次计算后得到的对该节点属于各个分类的一个预测值向量,与对应的y值label做交叉熵

        # Compute accuracy on training/validation/test
        train_acc = (pred[train_mask] == labels[train_mask]).float().mean()
        val_acc = (pred[val_mask] == labels[val_mask]).float().mean()
        test_acc = (pred[test_mask] == labels[test_mask]).float().mean()

        # Save the best validation accuracy and the corresponding test accuracy.
        if best_val_acc < val_acc:
            best_val_acc = val_acc
            best_test_acc = test_acc

        # Backward
        optimizer.zero_grad()   #对优化器内的梯度进行清空,只是为了防止上一轮优化器中的梯度对此次计算产生影响
        loss.backward()         #根据上面的loss,使用优化器计算grad
        optimizer.step()        #根据loss.backward()计算出来的grad对模型参数进行更新

        if e % 5 == 0:
            print('In epoch {}, loss: {:.3f}, val acc: {:.3f} (best {:.3f}), test acc: {:.3f} (best {:.3f})'.format(
                e, loss, val_acc, best_val_acc, test_acc, best_test_acc))


g = g.to('cuda')

model = GCN(g.ndata['feat'].shape[1], 16, dataset.num_classes).to('cuda')
train(g, model)

model = GCN(g.ndata['feat'].shape[1], 16, dataset.num_classes)      #g.ndata['feat'].shape[1]就是对应的每个节点的对应的feature的tensor的维度,等同于:len(g.ndata['feat'][0]
# train(g, model)

# print("shape[1]:",g.ndata['feat'].shape[1])
print("节点的特征:",len(g.ndata['feat']))
print("ndata[feat]的维度:",g.ndata['feat'].shape)

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值