图神经网络入门代码(1)-简单图构建

假设我们有一个包含 ​​3个超表面单元​​ 的简单图(有向):

​​        节点特征​​:每个单元有5个特征
​​        边连接​​:单元0与单元1相连,单元1与单元2相连(链式结构)

        dgl.graph((src, dst))​​:
                src 是源节点ID列表,dst 是目标节点ID列表。
                此处 edges[:,0] 是源节点 [0, 1],edges[:,1] 是目标节点 [1, 2]。
        ​​g.ndata['h'] 和 g.ndata['pos']​​:
                为每个节点添加名为 h 和 pos 的属性,后续GNN层可通过这些字段访问数据。

import dgl
import torch

# 定义节点特征(3个节点,每个节点5维特征)
node_features = torch.tensor([
    [6.0, 14.0, 0.0, 0.0, 0.0],  # 单元0
    [8.0, 16.0, 1.0, -1.0, 1.0], # 单元1
    [10.0,18.0, -1.0, 1.0, 2.0]  # 单元2
], dtype=torch.float32)

# 定义边(邻接关系)
edges = torch.tensor([[0, 1], [1, 2]], dtype=torch.long)  # 0→1, 1→2

# 创建图
g = dgl.graph((edges[:,0], edges[:,1]))
g.ndata['h'] = node_features  # 节点特征命名为'h'
g.ndata['pos'] = torch.tensor([[0,0], [1,0], [0,1]], dtype=torch.float32) # 假位置坐标

图结构可视化:

import networkx as nx
import matplotlib.pyplot as plt

# 创建无向图(便于可视化)
G = nx.Graph()
G.add_nodes_from([0, 1, 2])
G.add_edges_from([(0,1), (1,2)])

# 设置节点位置
pos = {0: (0,0), 1: (1,0), 2: (0,1)}

# 绘制
nx.draw(G, pos, with_labels=True, node_color='skyblue', node_size=800)
plt.title("Graph Structure with 3 Nodes and 2 Edges")
plt.show()

创建有向图

# 创建DGL图(有向图)
g = dgl.graph((edges[:,0], edges[:,1]))

# 添加节点特征和位置
g.ndata['h'] = node_features  # 特征存储在'h'字段
g.ndata['pos'] = pos          # 位置存储在'pos'字段

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值