import torch
from torch_geometric.data import Data
x = torch.tensor([[0, 1], [2, 3], [4, 5]], dtype=torch.float) # 节点特征矩阵(三个节点,每个节点两个特征)
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long) # 边索引矩阵(四条边,每条边包含两个节点索引)
y = torch.tensor([0, 1, 0], dtype=torch.long) # 每个节点的目标标签
train_mask = torch.tensor([True, False, True]) # 训练掩膜(三个节点)
test_mask = torch.tensor([False, True, False]) # 测试掩膜(三个节点)
data = Data(x=x, edge_index=edge_index, y=y, train_mask=train_mask, test_mask=test_mask)
print(data)
将数据包装成一个图数据结构(torch_geometric)
最新推荐文章于 2024-11-08 15:02:13 发布