![5f594d7e36eddbb201ded17cbdd94020.png](https://img-blog.csdnimg.cn/img_convert/5f594d7e36eddbb201ded17cbdd94020.png)
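The winter break was extended because of the novel coronavirus outbreak, and I could not concentrate on studying at home. Some Zhihu readers had asked how batching is implemented in GNN models, so I read up on it, learned a few things, and wrote this post. The code was written in a Jupyter notebook and has been uploaded to GitHub. May heaven bless China, and may heaven bless Wuhan.
Preface
This post considers a very simple graph classification problem, where the input is a graph.
- The Cora dataset is a task on a single graph: classifying its nodes.
- The task here is a task over multiple graphs: classifying each graph as a whole.
The code covers the full pipeline: generating a synthetic dataset, feeding in the data, building the model, training, and evaluation.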
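Because each sample in this task is a whole graph, a batch has to combine several graphs of different sizes. One common way to do this is to merge them into a single large disconnected graph and remember which graph each node came from. The sketch below illustrates that idea only; it is not necessarily the scheme used in the notebook. It assumes each graph is a `(node labels, edge list)` pair, and the names `batch_graphs` and `graph_ids` are hypothetical.

```python
def batch_graphs(graphs):
    """Merge a list of (nodes, edges) graphs into one disconnected graph.

    Returns (nodes, edges, graph_ids), where graph_ids[i] is the index of
    the original graph that node i came from.
    """
    nodes, edges, graph_ids = [], [], []
    for gid, (g_nodes, g_edges) in enumerate(graphs):
        offset = len(nodes)  # index shift applied to this graph's nodes
        nodes.extend(g_nodes)
        edges.extend((u + offset, v + offset) for u, v in g_edges)
        graph_ids.extend([gid] * len(g_nodes))
    return nodes, edges, graph_ids


# Two small labelled graphs (made-up data for illustration).
g1 = (["A", "B"], [(0, 1)])
g2 = (["A", "A", "C"], [(0, 1), (1, 2)])
batch_nodes, batch_edges, batch_ids = batch_graphs([g1, g2])
```

Since no edge crosses from one graph to another, message passing on the merged graph never mixes information between samples, and `batch_ids` can later be used to pool node representations per graph.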
Task definition
Subgraph matching classification: given a subgraph
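As a rough illustration of what such a label could mean, the sketch below checks whether a labelled graph contains a labelled pattern. It assumes that a graph counts as positive when the pattern can be mapped onto distinct nodes with the same labels such that every pattern edge is present (extra edges in the graph are allowed). That reading, the `(node labels, edge list)` format, and the helper name `contains_subgraph` are assumptions made for this example, not the author's stated definition.

```python
def contains_subgraph(graph, pattern):
    """Return True if `pattern` can be embedded in `graph` with matching labels."""
    g_nodes, g_edges = graph
    p_nodes, p_edges = pattern
    # Undirected adjacency sets for the larger graph.
    adj = [set() for _ in g_nodes]
    for u, v in g_edges:
        adj[u].add(v)
        adj[v].add(u)

    def extend(mapping):
        # mapping[i] is the graph node assigned to pattern node i.
        k = len(mapping)
        if k == len(p_nodes):
            return True
        for cand in range(len(g_nodes)):
            if cand in mapping or g_nodes[cand] != p_nodes[k]:
                continue
            # Pattern neighbours of node k that have already been placed.
            placed = [a + b - k for a, b in p_edges
                      if k in (a, b) and a + b - k < k]
            # Each of those pattern edges must also exist in the graph.
            if all(mapping[j] in adj[cand] for j in placed):
                if extend(mapping + [cand]):
                    return True
        return False

    return extend([])


pattern = (["A", "A", "B", "C"], [(0, 1), (0, 2), (1, 2), (2, 3)])
positive = (["C", "B", "A", "A", "B"], [(2, 3), (2, 1), (3, 1), (1, 0), (0, 4)])
negative = (["C", "B", "A", "A", "B"], [(2, 1), (3, 1), (1, 0), (0, 4)])
has_pattern = contains_subgraph(positive, pattern)
lacks_pattern = contains_subgraph(negative, pattern)
```

This is a plain backtracking search, which is fine for a four-node pattern but is not meant as an efficient matcher.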
Generating the dataset
Code for generating the dictionary:
"""
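Code for generating the dataset:
import random

import networkx as nx

# Pattern to look for: (node labels, edges given as pairs of node indices).
subgraph = (["A", "A", "B", "C"],
            [(0, 1),
             (0, 2),
             (1, 2),
             (2, 3)])
min_nodes_num = 5    # smallest number of nodes in a generated graph
max_nodes_num = 50   # largest number of nodes in a generated graph
graph_num = 10000    # number of connected graphs to generate
"""
Generate the graph dataset in four steps:
Step 1: Randomly choose the number of nodes (N).
Step 2: Generate a random graph whose edge count ranges from N - 1 to N * (N - 1) / 2.
Step 3: Discard graphs that are not connected.
Step 4: Add the subgraph to some of the graphs.
"""
N = 0  # counter: number of connected graphs kept so far
graphs = []
random.seed(0)
while N < graph_num:
    node_num = random.randint(min_nodes_num, max_nodes_num)
    # Integer division: random.randint requires integer bounds.
    edge_num = random.randint(node_num - 1, node_num * (node_num - 1) // 2)
    G = nx.dense_gnm_random_graph(node_num, edge_num)
    # Keep only connected graphs; disconnected ones are drawn again.
    if nx.is_connected(G):
        graphs.append(G)
        N += 1
        if N % 1000 == 0:
            print("{} graphs have been generated!".format(N))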
"""
Transform nx.Graph into our graph type.
"""
def transform_nx_graph(g):
nodes = random.choices(population=node_types, k=len(g.nodes))
edges = list(g.edges)
return (nodes, edges)
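A `(nodes, edges)` pair is convenient to generate, but a model usually consumes numeric arrays. As a minimal sketch of one possible conversion, the code below turns such a pair into a one-hot node feature matrix and a symmetric adjacency matrix using plain lists. The label vocabulary `["A", "B", "C"]` and the helper name `to_matrices` are assumptions for illustration; the notebook's actual input pipeline may differ.

```python
node_types = ["A", "B", "C"]  # assumed label vocabulary for this sketch


def to_matrices(graph, node_types):
    """Convert (nodes, edges) into a one-hot feature matrix and an adjacency matrix."""
    nodes, edges = graph
    type_index = {t: i for i, t in enumerate(node_types)}
    # One row per node, with a single 1 in the column of its label.
    features = [[0] * len(node_types) for _ in nodes]
    for i, label in enumerate(nodes):
        features[i][type_index[label]] = 1
    n = len(nodes)
    adjacency = [[0] * n for _ in range(n)]
    for u, v in edges:
        adjacency[u][v] = 1
        adjacency[v][u] = 1  # undirected graph: keep the matrix symmetric
    return features, adjacency


# The four-node pattern from the dataset code, used here as sample input.
subgraph = (["A", "A", "B", "C"], [(0, 1), (0, 2), (1, 2), (2, 3)])
features, adjacency = to_matrices(subgraph, node_types)
```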
"""
Merge subg