图神经网络-图数据集的表示和使用
Data类-PyG的图表示和使用
首先需要自己安装PyG(pytorch geometric)库
官方文档的具体介绍Data类
- Data函数
class Data(object):
def __init__(self, x=None, edge_index=None, edge_attr=None, y=None, **kwargs):
r"""
Args:
x (Tensor, optional): 节点属性矩阵,大小为`[num_nodes, num_node_features]`
edge_index (LongTensor, optional): 边索引矩阵,大小为`[2, num_edges]`,第0行为尾节点,第1行为头节点,头指向尾
edge_attr (Tensor, optional): 边属性矩阵,大小为`[num_edges, num_edge_features]`
y (Tensor, optional): 节点或图的标签,任意大小(,其实也可以是边的标签)
"""
self.x = x
self.edge_index = edge_index
self.edge_attr = edge_attr
self.y = y
for key, item in kwargs.items():
if key == 'num_nodes':
self.__num_nodes__ = item
else:
self[key] = item
具体简单地构建一个图数据需要五个输入,分别是x(节点特征), edge_index(节点的边,大小为2*节点数量), edge_attr(边特征), y(节点标签), num_nodes(节点数量)
graph = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y, num_nodes=num_nodes, other_attr=other_attr)
这样就完成了一个图的构建!
- 查看图的属性
比如:
'''
graph['x']->查看节点特征
graph['num_nodes']->查看graph节点个数
'''
- 修改图的属性
比如:
'''
graph['x'] = new_x
graph['y'] = new_y
'''
- 获取Data对象属性的关键字
graph.keys()
- Data的其他属性
from torch_geometric.datasets import KarateClub
dataset = KarateClub()
data = dataset[0] # Get the first graph object.
print(f'Number of nodes: {data.num_nodes}') # 节点数量
print(f'Number of edges: {data.num_edges}') # 边数量
print(f'Number of node features: {data.num_node_features}') # 节点属性的维度
print(f'Number of node features: {data.num_features}') # 同样是节点属性的维度
print(f'Number of edge features: {data.num_edge_features}') # 边属性的维度
print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}') # 平均节点度
print(f'if edge indices are ordered and do not contain duplicate entries.: {data.is_coalesced()}') # 是否边是有序的同时不含有重复的边
print(f'Number of training nodes: {data.train_mask.sum()}') # 用作训练集的节点
print(f'Training node label rate: {int(data.train_mask.sum()) / data.num_nodes:.2f}') # 用作训练集的节点的数量
print(f'Contains isolated nodes: {data.contains_isolated_nodes()}') # 此图是否包含孤立的节点
print(f'Contains self-loops: {data.contains_self_loops()}') # 此图是否包含自环的边
print(f'Is undirected: {data.is_undirected()}') # 此图是否是无向图
Dataset类-PyG中图数据集的表示和使用
这里直接导入PyG中内置的数据集(后续会自己构建数据集来使用gnn-自己构建巨麻烦)
导入PyG的名为Cora的dataset,dataset一般储存大的图数据,Data一般储存小的图数据
from torch_geometric.datasets import Planetoid
dataset = Planetoid(root='/dataset/Cora', name='Cora')
#输出图的属性
data = dataset[0]
print(data)
print(f'class_num:{dataset.num_classes}')
print(f'feature_dim:{dataset.num_node_features}')
print(f'is undirected:{data.is_undirected()}')
print(f'train_size:{data.train_mask.sum().item()}')
print(f'val_size:{data.val_mask.sum().item()}')
print(f'test_size:{data.test_mask.sum().item()}')
数据集的使用-构建GCN模型
GCN的核心公式:
x
X
′
=
D
^
−
1
/
2
A
^
D
^
−
1
/
2
X
Θ
,
x \mathbf{X}^{\prime} = \mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}}\mathbf{\hat{D}}^{-1/2} \mathbf{X} \mathbf{\Theta},
xX′=D^−1/2A^D^−1/2XΘ,概括一下GCN就是-节点每次迭代的结果是他的邻居节点的特征以及邻居节点的度卷积而成的。
- H为节点的特征
- D为图的度矩阵
- A为邻接矩阵
- W是一个超参数,应该是正则或者其他的作用
GCN模型
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
class Net(torch.nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = GCNConv(dataset.num_node_features, 16)
self.conv2 = GCNConv(16, dataset.num_classes)
def forward(self, data):
x, edge_index = data.x, data.edge_index
x = self.conv1(x, edge_index)
x = F.relu(x)
x = F.dropout(x, training=self.training)
x = self.conv2(x, edge_index)
return F.log_softmax(x, dim=1)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net().to(device)
data = dataset[0].to(device)
#优化
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
#训练模型
model.train()
for epoch in range(200):
optimizer.zero_grad()
out = model(data)
loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
loss.backward()
optimizer.step()
在测试集上评估结果-Accuracy
model.eval()
_, pred = model(data).max(dim=1)
correct = int(pred[data.test_mask].eq(data.y[data.test_mask]).sum().item())
acc = correct / int(data.test_mask.sum())
print('Accuracy: {:.4f}'.format(acc))
练习
- 请通过继承
Data
类实现一个类,专门用于表示“机构-作者-论文”的网络。该网络包含“机构“、”作者“和”论文”三类节点,以及“作者-机构“和“作者-论文“两类边。对要实现的类的要求:1)用不同的属性存储不同节点的属性;2)用不同的属性存储不同的边(边没有属性);3)逐一实现获取不同节点数量的方法
class Data(object):
def __init__(self, x_og=None, x_autor=None, x_paper=None, edge_index_ao=None, edge_index_ap=None,
edge_attr_ap=None, edge_attr_ao=None, y=None, **kwargs):
self.x_og = x_og
self.x_autor = x_autor
self.x_paper = x_paper
self.edge_index_ao = edge_index_ao
self.edge_index_ap = edge_index_ap
self.edge_attr_ap = edge_attr_ap
self.edge_attr_ao = edge_attr_ao
self.y = y
def get_node_og():
return x_og.shape[0]
def get_node_autor():
return x_autor.shape[0]
def get_node_x_paper():
return x_x_paper.shape[0]