假设我们有一个包含 3个超表面单元 的简单图(有向):
节点特征:每个单元有5个特征
边连接:单元0与单元1相连,单元1与单元2相连(链式结构)
dgl.graph((src, dst)):
src 是源节点ID列表,dst 是目标节点ID列表。
此处 edges[:,0] 是源节点 [0, 1],edges[:,1] 是目标节点 [1, 2]。
g.ndata['h'] 和 g.ndata['pos']:
为每个节点添加名为 h 和 pos 的属性,后续GNN层可通过这些字段访问数据。
import dgl
import torch
# 定义节点特征(3个节点,每个节点5维特征)
node_features = torch.tensor([
[6.0, 14.0, 0.0, 0.0, 0.0], # 单元0
[8.0, 16.0, 1.0, -1.0, 1.0], # 单元1
[10.0,18.0, -1.0, 1.0, 2.0] # 单元2
], dtype=torch.float32)
# 定义边(邻接关系)
edges = torch.tensor([[0, 1], [1, 2]], dtype=torch.long) # 0→1, 1→2
# 创建图
g = dgl.graph((edges[:,0], edges[:,1]))
g.ndata['h'] = node_features # 节点特征命名为'h'
g.ndata['pos'] = torch.tensor([[0,0], [1,0], [0,1]], dtype=torch.float32) # 假位置坐标
图结构可视化:
import networkx as nx
import matplotlib.pyplot as plt
# 创建无向图(便于可视化)
G = nx.Graph()
G.add_nodes_from([0, 1, 2])
G.add_edges_from([(0,1), (1,2)])
# 设置节点位置
pos = {0: (0,0), 1: (1,0), 2: (0,1)}
# 绘制
nx.draw(G, pos, with_labels=True, node_color='skyblue', node_size=800)
plt.title("Graph Structure with 3 Nodes and 2 Edges")
plt.show()
创建有向图
# 创建DGL图(有向图)
g = dgl.graph((edges[:,0], edges[:,1]))
# 添加节点特征和位置
g.ndata['h'] = node_features # 特征存储在'h'字段
g.ndata['pos'] = pos # 位置存储在'pos'字段