PyG:torchgeometric HeteroData类

同构图

异构图 Hetero_graph

创建异构图

异构图创建一般是先创建一个空的异构图,然后在添加节点和边

from torch_geometric.data import HeteroData
hetero_data = HeteroData()#创建空异构图

节点

  • 添加新的节点类
    异构图的存储类似与字典,因此可以直接通过下标来添加新的节点类
    hetero_data['new_node_type'].x = torch.randn(num_nodes,num_nodes_features)
  • 添加新结点:由于节点的存储是矩阵,每一行是一个结点,因此只需要创建相同特征的tensor并且cat到特征矩阵x上即可:
    new_nodes_features = torch.randn(2, 4)
    hetero_data['existing_node_type'].x = torch.cat([hetero_data['existing_node_type'].x, new_nodes_features], dim=0)
  • 访问结点:torch geometric图的存储方式类似于字典,因此可以类似与字典的下标运算符一样访问不同类型节点的特征,属性
node_features=hetero_data['new_node_type'].x
node_attr=hetero_data['new_node_type'].attr

  • 添加边的范式:
    hetero_graph['source', 'relation', 'target'].edge_index=torch.tensor[source_list,target_list]
    例如:假设第1种农药可以防治第1种和第2种病害
    hetero_graph['pesticide', 'prevents', 'disease'].edge_index = torch.tensor([[0, 0], [0, 1]])

  • 可以通过edge_attr来构建边的特征
    hetero_graph['source', 'relation', 'target'].edge_attr = torch.randn(2, 2)

  • 访问:torch geometric图的存储方式类似于字典,因此可以类似与字典的下标运算符一样访问不同类型边的特征,属性…

edge_index = hetero_graph['source', 'relation', 'target'] # 访问边的连接信息

常用函数

  • to_homogeneous:将异构图转换为同构图,用于某些只支持同构图输入的图神经网络模型。

  • to(device):将所有的图数据移动到指定的设备(如cuda:0,cpu)。

  • clone():创建 HeteroData 对象的深拷贝

异构图例子

import torch
from torch_geometric.data import HeteroData
hetero_graph = HeteroData()#空图
# 假设3种农药,每种农药的特征维度为5
hetero_graph['pesticide'].x = torch.randn(3, 5)
# 假设有4种病害,每种病害的特征维度为5
hetero_graph['disease'].x = torch.randn(4, 5)

# 假设有2种植物,每种植物的特征维度为5
hetero_graph['plant'].x = torch.randn(2, 5)
# 定义农药到病害的"防治"关系
# 假设第1种农药可以防治第1种和第2种病害
hetero_graph['pesticide', 'prevents', 'disease'].edge_index = torch.tensor(
    [[0, 0], [0, 1]], dtype=torch.long)

# 定义病害到植物的"影响"关系
# 假设第1种病害影响第1种植物,第2种和第3种病害影响第2种植物
hetero_graph['disease', 'affects', 'plant'].edge_index = torch.tensor(
    [[0, 1, 2], [0, 1, 1]], dtype=torch.long)
# 为"防治"关系添加特征,假设特征维度为2
hetero_graph['pesticide', 'prevents', 'disease'].edge_attr = torch.randn(2, 2)
# 为"影响"关系添加特征,假设特征维度为3
hetero_graph['disease', 'affects', 'plant'].edge_attr = torch.randn(3, 2)
hetero_graph
>>HeteroData(
  pesticide={ x=[3, 5] },
  disease={ x=[4, 5] },
  plant={ x=[2, 5] },
  (pesticide, prevents, disease)={
    edge_index=[2, 2],
    edge_attr=[2, 2],
  },
  (disease, affects, plant)={
    edge_index=[2, 3],
    edge_attr=[3, 2],
  }
)
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值