batch & print pro_番外篇:GNN训练的Batch问题

5f594d7e36eddbb201ded17cbdd94020.png
寒假假期受新冠病毒影响延长,在家实在无心学习,想到之前有知友问关于GNN模型中如何实现batch的问题,于是查阅资料,略有感悟,因作此篇。代码使用jupyter notebook编写,已上传Github。望天佑中华,天佑武汉。

前言

在本篇考虑一个很简单的图分类问题,输入一个图(graph)

,输出该图是否包含某个子图(subgraph),即0和1。该任务和之前的Cora数据集任务不同之处在于:
  • Cora数据集是面向一个graph、对节点的进行分类的任务。
  • 该任务是面向多个graph、对graph进行分类的任务。

在代码中提供人工数据集的生成、数据输入、模型构建、训练和评估等流程。

任务定义

子图匹配分类:给定一个子图(subgraph)

以及图(graph)的数据集
,对应的标签为
,对于任意一个图(graph)
及其标签
,有:

生成数据集

生成字典代码:

"""

生成数据集代码:

subgraph = (["A", "A", "B", "C"], 
            [(0, 1),
             (0, 2),
             (1, 2),
             (2, 3)])
min_nodes_num = 5
max_nodes_num = 50
graph_num = 10000

"""
Generating graph dataset. There are three steps:
Step1 : Randomly choose number of nodes(N).
Step2 : Generate random graph with edge number ranging from N-1 to N * (N - 1) / 2.
Step3 : Remove unconnected graph.
Step4 : Add subgraph to some graphs.
"""
N = 0
graphs = []
random.seed(0)
while N < graph_num:
    node_num = random.randint(min_nodes_num, max_nodes_num)
    edge_num = random.randint(node_num-1, node_num * (node_num - 1) / 2)
    G = nx.random_graphs.dense_gnm_random_graph(node_num, edge_num)
    if nx.connected.is_connected(G):
        graphs.append(G)
        N += 1
        if N % 1000 == 0:
            print("{} graphs have been generated!".format(N))

"""
Transform nx.Graph into our graph type.
"""
def transform_nx_graph(g):
    nodes = random.choices(population=node_types, k=len(g.nodes))
    edges = list(g.edges)
    
    return (nodes, edges)

"""
Merge subg
  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值