dgl教程

import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F

一、实现GCN节点分类

1.1 加载cora数据集

import dgl.data
# networkx
dataset = dgl.data.CoraGraphDataset()
print('Number of categories:', dataset.num_classes)

输出结果:
在这里插入图片描述

  • 一个DGL对象是可以表示多图的,因此取出其中一个图的话就是需要索引的。
    在这里插入图片描述
  • 上图的中的train mask就表示哪一部分代表是训练集的。
  • 对于qora这个数据集是只对点的特征属性用embedding表示,边的特征没有。
    *在这里插入图片描述

定义GCN网络架构

from dgl.nn import GraphConv

class GCN(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super(GCN, self).__init__()
        self.conv1 = GraphConv(in_feats, h_feats)  #输入特征维度、输出隐藏层维度
        self.conv2 = GraphConv(h_feats, num_classes)
    
    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat)
        h = F.relu(h)
        h = self.conv2(g, h)
        return h
    
# Create the model with given dimensions
model = GCN(g.ndata['feat'

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值