DGL_图的创建、保存、加载

import dgl
import torch as th
from dgl.data.utils import save_graphs

g1 = dgl.DGLGraph()
g1.add_nodes(3)
g1.add_edges([0, 0, 0, 1, 1, 2], [0, 1, 2, 1, 2, 2])
g1.ndata["x"] = th.ones(3, 5)   # 3个节点的embedding
g1.edata['y'] = th.zeros(6, 5)  # 6条边的embedding
# 补充:添加边的方式
# g1.add_edges(th.tensor([3, 4, 5]), 1)  # three edges: 3->1, 4->1, 5->1
# g1.add_edges(4, [7, 8, 9])  # three edges: 4->7, 4->8, 4->9
# g1.add_edges([1, 2, 3], [3, 4, 5])  # three edges: 1->3, 2->4, 3->5

g2 = dgl.DGLGraph()
g2.add_nodes(3)
g2.add_edges([0, 1, 2], [1, 2, 1])
g2.edata["e"] = th.ones(3, 4)

graph_labels = {"graph_sizes": th.tensor([3, 3])}

save_graphs("data/try1.bin", [g1, g2], graph_labels)
from dgl.data.utils import load_graphs
from dgl.data.utils import load_labels

# glist, label_dict = load_graphs("data/small.bin") # glist will be [g1, g2]
glist, label_dict = load_graphs("data/try1.bin", [0]) # glist will be [g1]
graph_sizes = load_labels("data/try1.bin")

print(glist)
# [DGLGraph(num_nodes=3, num_edges=6,
#          ndata_schemes={'x': Scheme(shape=(5,), dtype=torch.float32)}
#          edata_schemes={'y': Scheme(shape=(5,), dtype=torch.float32)})]
print(label_dict)
# {'graph_sizes': tensor([3, 3])}
print(graph_sizes)
# {'graph_sizes': tensor([3, 3])}
  • 5
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值