最近在学习图神经网络,进行了一个小小的demo练习,学习了一下怎么建图,怎么进行图神经网络的训练~
1、库的导入
%matplotlib inline
import torch
import networkx as nx
import matplotlib.pyplot as plt
from torch.nn import Linear
from torch_geometric.nn import GCNConv
import torch
import torch.nn as nn
这一部分需要首先安装torch和torch_geometric,不会安装的小伙伴可以看我之前的博客,需要注意一些版本之间的关系。
torch_geometric踩坑实战--安装与运行 亲测有效!!_汤汤upup的博客-CSDN博客
2、数据集的导入及查看
torch_geometric中有一些自带数据集,这次用到的是KarateClub空手道俱乐部的数据集
from torch_geometric.datasets import KarateClub
dataset=KarateClub()
data=dataset[0]
data
输出:
Data(x=[34, 34], edge_index=[2, 156], y=[34], train_mask=[34])
数据集信息如下:
总共34个节点,每个节点34个特征,156条边,节点分为4个类别,任务是对节点进行分类,判断节点属于哪个小团体。因此输出中的x=[34,34]分别代表每个节点的特征个数,以及节点个数,edge_index=[2,156],2是该属性中的不变数值,代表两个节点相连,156代表边的条数。
3、进行数据集的可视化
def visualize_graph(G,color):
plt.figure(figsize=(7,7))
plt.xticks([])
plt.yticks([])
nx.draw_networkx(G,pos=nx.spring_layout(G,seed=42),with_labels=False,node_color=color,cmap="Set2")
plt.show()
from torch_geometric.utils import to_networkx
G=to_networkx(data,to_undirected=True)
visualize_graph(G,color=data.y)
利用networkx对Data中的数据进行可视化,得到的可视化结果如下:
4、模型的训练
搭建两层的GCN模型,搭建过程非常简单,不会的同学建议看一下pytorch基础~
import torch.nn.functional as F
class GCN(torch.nn.Module):
def __init__(self, num_node_features, num_classes):
super(GCN, self).__init__()
self.conv1 = GCNConv(num_node_features, 16)
self.conv2 = GCNConv(16, num_classes)
def forward(self, data):
x, edge_index = data.x, data.edge_index
x = self.conv1(x, edge_index)
x = F.relu(x)
x = F.dropout(x, training=self.training)
x = self.conv2(x, edge_index)
x = F.relu(x)
x = F.dropout(x, training=self.training)
x = F.softmax(x, dim=1)
return x
model = GCN(dataset.num_node_features, dataset.num_classes)
print(model)
---------------------------------------------------------------
GCN(
(conv1): GCNConv(34, 16)
(conv2): GCNConv(16, 4)
)
对模型进行训练
def train(model, data):
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=1e-4)
loss_function = torch.nn.CrossEntropyLoss()
model.train()
for epoch in range(200):
out = model(data)
optimizer.zero_grad()
loss = loss_function(out[data.train_mask], data.y[data.train_mask])
loss.backward()
optimizer.step()
print('Epoch {:03d} loss {:.4f}'.format(epoch, loss.item()))
train(model,data)
训练结果如下:
一个简单的图神经网络学习的demo就完成啦,当然如果需要实用自己的数据集也可以,可以使用如下代码:
from torch_geometric.data import Data
x = torch.tensor([[2,1],[5,6],[3,7],[12,0]],dtype=torch.float)
y = torch.tensor([0,1,0,1],dtype=torch.float)
edge_index = torch.tensor([[0,1,2,0,3],
[1,0,1,3,2]],dtype=torch.long)
data = Data(x=x,y=y,edge_index=edge_index)
其中x代表节点的特征,可以看出特征的维度是2*1,y代表节点的标签,edge_index代表边的连接,是一个2*n维的矩阵,表示节点0和节点1相连,等等
建图的结果如下:
本文主要介绍的是节点分类算法,当然图分类和边分类大同小异~