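While working on graph data processing with PyG recently, I noticed that when PyG imports data from networkx, all of the original node names are converted to integers. Later, though, I still needed to match the trained node embeddings back to the original node names one by one.
After some digging, I found the cause: PyG's from_networkx function calls networkx.relabel.convert_node_labels_to_integers internally: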
```python
def from_networkx(G, group_node_attrs: Optional[Union[List[str], all]] = None,
                  group_edge_attrs: Optional[Union[List[str], all]] = None):
    r"""Converts a :obj:`networkx.Graph` or :obj:`networkx.DiGraph` to a
    :class:`torch_geometric.data.Data` instance.

    Args:
        G (networkx.Graph or networkx.DiGraph): A networkx graph.
        group_node_attrs (List[str] or all, optional): The node attributes to
            be concatenated and added to :obj:`data.x`. (default: :obj:`None`)
        group_edge_attrs (List[str] or all, optional): The edge attributes to
            be concatenated and added to :obj:`data.edge_attr`.
            (default: :obj:`None`)

    .. note::

        All :attr:`group_node_attrs` and :attr:`group_edge_attrs` values must
        be numeric.
    """
    import networkx as nx

    # Node labels are replaced by 0..n-1 (default ordering) before anything else
    G = nx.convert_node_labels_to_integers(G)
    G = G.to_directed() if not nx.is_directed(G) else G
    edge_index = torch.LongTensor(list(G.edges)).t().contiguous()

    data = defaultdict(list)
    # ... (rest omitted)
```
The source above is from: torch_geometric.utils.convert (pytorch_geometric 2.0.1 documentation).
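The sketch below isolates the effect and assumes only that networkx is installed. It applies the same call from_networkx makes to a small DiGraph with string labels. It also mimics the `torch.LongTensor(list(G.edges)).t()` step with plain Python lists so that torch is not needed. The list-of-lists `edge_index` here only stands in for the real tensor.

```python
import networkx as nx

# A small directed graph with string labels; insertion order is 来, 去, 灰, 粉
G = nx.DiGraph()
G.add_nodes_from(["来", "去", "灰", "粉"])
G.add_edge("来", "灰")
G.add_edge("粉", "灰")

# The same call from_networkx makes internally (default ordering)
H = nx.convert_node_labels_to_integers(G)

print(list(G.nodes()))  # ['来', '去', '灰', '粉']: G itself is untouched, H is a copy
print(list(H.nodes()))  # [0, 1, 2, 3]

# Plain-Python stand-in for torch.LongTensor(list(H.edges)).t():
# row 0 holds source nodes, row 1 holds target nodes
edge_index = [list(col) for col in zip(*H.edges())]
print(edge_index)  # [[0, 3], [2, 2]]
```

Once the labels have become integers, nothing in `edge_index` says which string each number came from. That is why the mapping has to be kept somewhere else.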
Next, let's look at how convert_node_labels_to_integers is used in networkx.
convert_node_labels_to_integers(G, first_label=0, ordering='default', label_attribute=None)

Returns a copy of the graph G with the nodes relabeled using consecutive integers.

Parameters
- G (graph): A NetworkX graph.
- first_label (int, optional, default=0): An integer specifying the starting offset in numbering nodes. The new integer labels are numbered first_label, …, n-1+first_label.
- ordering (string):
  - "default": inherit node ordering from G.nodes()
  - "sorted": inherit node ordering from sorted(G.nodes())
  - "increasing degree": nodes are sorted by increasing degree
  - "decreasing degree": nodes are sorted by decreasing degree
- label_attribute (string, optional, default=None): Name of node attribute to store old label. If None no attribute is created.
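The ordering parameter decides which original label ends up as which integer. The quick check below assumes only networkx. With label_attribute="nodename", each integer node remembers its old label. Switching from "default" to "sorted" changes the mapping, because Python sorts these strings by Unicode code point. The helper `names` exists only for this illustration.

```python
import networkx as nx

G = nx.DiGraph()
G.add_nodes_from(["来", "去", "灰", "粉"])  # "default" ordering = this insertion order

def names(H):
    """Illustrative helper: integer label -> original label stored in 'nodename'."""
    return {i: H.nodes[i]["nodename"] for i in range(H.number_of_nodes())}

by_default = nx.convert_node_labels_to_integers(G, label_attribute="nodename")
by_sorted = nx.convert_node_labels_to_integers(G, ordering="sorted", label_attribute="nodename")
offset = nx.convert_node_labels_to_integers(G, first_label=1)

print(names(by_default))       # {0: '来', 1: '去', 2: '灰', 3: '粉'}
print(names(by_sorted))        # {0: '去', 1: '来', 2: '灰', 3: '粉'}
print(sorted(offset.nodes()))  # [1, 2, 3, 4]
```

Two calls with different orderings therefore disagree on what integer 0 means. Any index-based lookup has to use the same ordering on both sides.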
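If the same ordering is used whenever node labels are converted to integers, the resulting numbering is identical. from_networkx uses the default ordering, that is, the order of G.nodes() (insertion order). Calling convert_node_labels_to_integers ourselves with the default ordering therefore produces exactly the same integer labels. We can stash the original node labels in a nodename attribute and later match results back to nodes directly by their integer index, which keeps things simple.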
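Here is a minimal sketch of the lookup this enables, assuming only networkx. It reads the nodename attribute back into a dictionary and confirms that it matches plain enumeration of G.nodes(). That enumeration is the numbering produced by the default ordering, and so also by from_networkx as quoted above. The names idx_to_name and name_to_idx are illustrative.

```python
import networkx as nx

G = nx.DiGraph()
G.add_nodes_from(["来", "去", "灰", "粉"])

# Integer-labelled copy that keeps each original label in "nodename"
G1 = nx.convert_node_labels_to_integers(G, label_attribute="nodename")

# Illustrative lookup tables in both directions
idx_to_name = {i: d["nodename"] for i, d in G1.nodes(data=True)}
name_to_idx = {name: i for i, name in idx_to_name.items()}

# With the default ordering, integer i is just the i-th node of G.nodes()
assert idx_to_name == dict(enumerate(G.nodes()))
print(name_to_idx["灰"])  # 2
```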
The final solution is as follows:
```python
import networkx as nx
from torch_geometric.utils import from_networkx, to_networkx

if __name__ == "__main__":
    # Build the graph
    G = nx.DiGraph()
    G.add_node("来", embedding=[1.0, 2.0])
    G.add_node("去", embedding=[1.4, 2.1])
    G.add_node("灰", embedding=[1.6, 1.1])
    G.add_node("粉", embedding=[2.1, 2.9])
    G.add_edge("来", "灰")
    G.add_edge("粉", "灰")
    G.add_edge("去", "灰")
    G.add_edge("来", "粉")
    G.add_edge("灰", "粉")

    # Relabel G to integers; this returns a copy G1 with the original name in "nodename"
    G1 = nx.relabel.convert_node_labels_to_integers(G, label_attribute="nodename")
    print("AFTER relabel:")
    print(G1.nodes(data=True))

    # Convert to PyG data, using the 'embedding' attribute as data['x'].
    # from_networkx relabels G internally with the same default ordering as G1.
    data = from_networkx(G, ["embedding"])

    # Process with PyTorch: add 10 to every value in place
    data["x"].add_(10.0)

    # Convert back to a networkx graph, carrying the processed node attribute 'x'
    G2 = to_networkx(data, node_attrs=["x"])
    # nx.info() was removed in networkx 3.0; print(G2) gives a similar summary
    print(G2)
    print(G2.nodes(data=True))

    # Node n in G2 is integer n, which is also node n in G1
    for n, d in G2.nodes(data=True):
        print(n, d)
        # Update the attribute in G1 (d["x"] is a plain list here)
        G1.nodes[n]["embedding"] = d["x"]
    print(G1.nodes(data=True))
```
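If the integer-labelled copy G1 is not needed for anything else, the same ordering argument suggests a shorter route: write each processed row straight back onto the original named graph by position. The helper write_back is hypothetical. To keep the example free of torch and PyG, it uses a plain list as a stand-in for data.x.tolist(). It assumes the rows come from from_networkx(G, ...) with the default ordering, so row i belongs to the i-th node of G.nodes().

```python
import networkx as nx

def write_back(G, rows, attr="embedding"):
    """Hypothetical helper: assign row i of `rows` to the i-th node of G.nodes().

    Assumes `rows` follows the default-ordering integer labels, e.g. the
    result of data.x.tolist() after data = from_networkx(G, ...).
    """
    nodes = list(G.nodes())
    if len(rows) != len(nodes):
        raise ValueError(f"expected {len(nodes)} rows, got {len(rows)}")
    for name, row in zip(nodes, rows):
        G.nodes[name][attr] = row

G = nx.DiGraph()
G.add_node("来", embedding=[1.0, 2.0])
G.add_node("去", embedding=[1.4, 2.1])
G.add_node("灰", embedding=[1.6, 1.1])
G.add_node("粉", embedding=[2.1, 2.9])

# Stand-in for data.x.tolist() after data['x'].add_(10.0); plain floats, no torch
rows = [[v + 10.0 for v in d["embedding"]] for _, d in G.nodes(data=True)]
write_back(G, rows)
print(G.nodes(data=True))
```

The length check is a cheap guard against a mismatch, such as nodes added to G after the conversion. It cannot detect a reordering, so this approach is only as safe as the assumption that both sides use the default ordering.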