简单了解DGL中的数据格式

本文介绍了图神经网络框架DGL,包括DGL数据集如Citeseer网络,以及如何使用dgl.DGLGraph和dgl.graph创建同质图和异质图。DGLGraph提供了ntypes、etypes等属性来获取图的节点和边类型,还展示了如何通过数据字典构建异质图dgl.heterograph。
摘要由CSDN通过智能技术生成

前言

PyG搭建GCN前的准备:了解PyG中的数据格式中讲解了PyG中的数据格式,DGL是与PyG齐名的另一大图神经网络框架,二者各有优缺点,建议都学习并掌握。

1. DGL数据集

本篇文章使用Citeseer网络。Citeseer网络是一个引文网络,节点为论文,一共3327篇论文。论文一共分为六类:Agents、AI(人工智能)、DB(数据库)、IR(信息检索)、ML(机器语言)和HCI。如果两篇论文间存在引用关系,那么它们之间就存在链接关系。

DGL加载Citeseer网络:

import dgl
from dgl.data.citation_graph import CiteseerGraphDataset

dataset = CiteseerGraphDataset()
print(len(dataset))

输出为1,说明只有一个网络,然后我们输出一下这个网络:

graph = dataset[0]
print(type(graph))
print(graph)

输出:

<class 'dgl.heterograph.DGLHeteroGraph'>
Graph(num_nodes=3327, num_edges=9228,
      ndata_schemes={'train_mask': Scheme(shape=(), dtype=torch.bool), 'label': Scheme(shape=(), dtype=torch.int64), 'val_mask': Scheme(shape=(), dtype=torch.bool), 'test_mask': Scheme(shape=(), dtype=torch.bool), 'feat': Scheme(shape=(3703,), dtype=torch.float32)}
      edata_schemes={'__orig__': Scheme(shape=(), dtype=torch.int64)})

在DGL中,所有图都为dgl.DGLGraph格式,为了创建图,可以有以下两种方法:dgl.graph()和dgl.heterograph(),这两个方法分别创建同质图和异质图,二者返回的都是dgl.DGLGraph。

因此,下面先了解一下dgl.DGLGraph。

2. dgl.DGLGraph

dgl.DGLGraph类有以下属性和方法:

首先是属性:

DGLGraph.ntypes

返回图中所有类型节点的名称,如上面的网络返回:

['_N']

表明只有一种类型的节点,即论文节点。

同理还有边类型:

DGLGraph.etypes

输出为:

['_E']

同样只有一种类型的边。此外,DGLGraph.srctypes和DGLGraph.dsttypes分别返回源节点和目标节点的类型。

print(graph.metagraph())   # 返回异质图的元图
MultiDiGraph with 1 nodes and 1 edges
print(graph.num_nodes())   # 返回节点数
3327
print(graph.num_edges())   # 返回边数
9228

DGLGraph.nodes()返回节点集合:

print(graph.nodes())
tensor([   0,    1,    2,  ..., 3324, 3325, 3326])

边集合:

print(graph.edges())
(tensor([   2,    3,    0,  ..., 3323, 3326, 3325]), tensor([   0,    0,    0,  ..., 3324, 3325, 3326]))

同样是两个列表,分别对应两端节点编号。

ndata返回节点上的一些信息:

print(graph.ndata)
{'train_mask': tensor([False,  True, False,  ..., False, False,  True]), 'label': tensor([1, 4, 1,  ..., 5, 3, 3]), 'val_mask': tensor([False, False, False,  ..., False, False, False]), 'test_mask': tensor([False, False, False,  ..., False, False, False]), 'feat': tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])}

feat表示节点特征。

同理有edata,这里不再详细讲解。

3. dgl.graph

利用dgl.graph()方法构建同质图:

dgl.graph(data, ntype=None, etype=None, *, num_nodes=None, idtype=None, device=None, row_sorted=False , col_sorted=False, **deprecated_kwargs )

其中data的形式为(U, V),表示边的两边节点集合;num_nodes表示节点数,如果没有给出,将使用data中的最大id+1,这可能会引发错误:

src_ids = torch.tensor([2, 3, 4])
# Destination nodes for edges (2, 1), (3, 2), (4, 3)
dst_ids = torch.tensor([1, 2, 3])
g = dgl.graph((src_ids, dst_ids))
print(g.num_nodes())

返回的节点数为5,如果给定num_nodes<=4,将引发错误。

4. dgl.heterograph

dgl.heterograph( data_dict , num_nodes_dict=None , idtype=None , device=None )

具体例子:

data_dict = {
    ('user', 'follows', 'user'): (torch.tensor([0, 1]), torch.tensor([1, 2])),
    ('user', 'follows', 'topic'): (torch.tensor([1, 1]), torch.tensor([1, 2])),
    ('user', 'plays', 'game'): (torch.tensor([0, 3]), torch.tensor([3, 4]))
}
g = dgl.heterograph(data_dict)
print(g)

上图中,一共有user、topic和game三种类型的节点,他们有三种类型的关系,右边的数据表示边两边节点的索引。

此外,可以显式指定节点个数:

num_nodes_dict = {'user': 4, 'topic': 4, 'game': 6}
g = dgl.heterograph(data_dict, num_nodes_dict=num_nodes_dict)

这里一样指定数目不能小于边集合中的最小索引+1。

  • 6
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Cyril_KI

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值