异构图 Hetero_graph
创建异构图
异构图创建一般是先创建一个空的异构图,然后在添加节点和边
from torch_geometric.data import HeteroData
hetero_data = HeteroData()#创建空异构图
节点
- 添加新的节点类
异构图的存储类似与字典,因此可以直接通过下标来添加新的节点类
hetero_data['new_node_type'].x = torch.randn(num_nodes,num_nodes_features)
- 添加新结点:由于节点的存储是矩阵,每一行是一个结点,因此只需要创建相同特征的
tensor
并且cat
到特征矩阵x上即可:
new_nodes_features = torch.randn(2, 4)
hetero_data['existing_node_type'].x = torch.cat([hetero_data['existing_node_type'].x, new_nodes_features], dim=0)
- 访问结点:torch geometric图的存储方式类似于字典,因此可以类似与字典的下标运算符一样访问不同类型节点的特征,属性
node_features=hetero_data['new_node_type'].x
node_attr=hetero_data['new_node_type'].attr
边
-
添加边的范式:
hetero_graph['source', 'relation', 'target'].edge_index=torch.tensor[source_list,target_list]
例如:假设第1种农药可以防治第1种和第2种病害
hetero_graph['pesticide', 'prevents', 'disease'].edge_index = torch.tensor([[0, 0], [0, 1]])
-
可以通过
edge_attr
来构建边的特征
hetero_graph['source', 'relation', 'target'].edge_attr = torch.randn(2, 2)
-
访问:torch geometric图的存储方式类似于字典,因此可以类似与字典的下标运算符一样访问不同类型边的特征,属性…
edge_index = hetero_graph['source', 'relation', 'target'] # 访问边的连接信息
常用函数
-
to_homogeneous
:将异构图转换为同构图,用于某些只支持同构图输入的图神经网络模型。 -
to(device)
:将所有的图数据移动到指定的设备(如cuda:0,cpu)。 -
clone()
:创建 HeteroData 对象的深拷贝
异构图例子
import torch
from torch_geometric.data import HeteroData
hetero_graph = HeteroData()#空图
# 假设3种农药,每种农药的特征维度为5
hetero_graph['pesticide'].x = torch.randn(3, 5)
# 假设有4种病害,每种病害的特征维度为5
hetero_graph['disease'].x = torch.randn(4, 5)
# 假设有2种植物,每种植物的特征维度为5
hetero_graph['plant'].x = torch.randn(2, 5)
# 定义农药到病害的"防治"关系
# 假设第1种农药可以防治第1种和第2种病害
hetero_graph['pesticide', 'prevents', 'disease'].edge_index = torch.tensor(
[[0, 0], [0, 1]], dtype=torch.long)
# 定义病害到植物的"影响"关系
# 假设第1种病害影响第1种植物,第2种和第3种病害影响第2种植物
hetero_graph['disease', 'affects', 'plant'].edge_index = torch.tensor(
[[0, 1, 2], [0, 1, 1]], dtype=torch.long)
# 为"防治"关系添加特征,假设特征维度为2
hetero_graph['pesticide', 'prevents', 'disease'].edge_attr = torch.randn(2, 2)
# 为"影响"关系添加特征,假设特征维度为3
hetero_graph['disease', 'affects', 'plant'].edge_attr = torch.randn(3, 2)
hetero_graph
>>HeteroData(
pesticide={ x=[3, 5] },
disease={ x=[4, 5] },
plant={ x=[2, 5] },
(pesticide, prevents, disease)={
edge_index=[2, 2],
edge_attr=[2, 2],
},
(disease, affects, plant)={
edge_index=[2, 3],
edge_attr=[3, 2],
}
)