本次学习图神经网络的一般过程,直接放代码,对比MLP, GCN和GAT在cora数据集上做节点分类的效果:
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures
from torch_geometric.nn import GCNConv, GATConv
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import torch
from torch.nn import Linear
import torch.nn.functional as F
# Load the Cora citation dataset (2708 papers, 7 classes).
# NormalizeFeatures row-normalizes the bag-of-words node features.
dataset = Planetoid(root='dataset', name='Cora', transform=NormalizeFeatures())
data = dataset[0]  # Cora contains a single graph
def visualize(h, color, fig_name):
    """Project node embeddings to 2-D with t-SNE and save a scatter plot.

    h: node embedding matrix (num_nodes, dim) as a torch tensor.
    color: per-node labels used to colour the points.
    fig_name: path the figure is written to.
    """
    # BUG FIX: the original called out.detach() — `out` is a global leaked
    # from the training loop, not the `h` argument, so the function plotted
    # the wrong embeddings whenever h != out. Use the parameter instead.
    z = TSNE(n_components=2).fit_transform(h.detach().cpu().numpy())
    plt.figure(figsize=(10, 10))
    plt.xticks([])
    plt.yticks([])
    plt.scatter(z[:, 0], z[:, 1], s=70, c=color, cmap="Set2")
    plt.savefig(fig_name)
class MLP(torch.nn.Module):
    """Two-layer perceptron baseline: uses node features only, no graph."""

    def __init__(self, hidden_channels):
        super(MLP, self).__init__()
        torch.manual_seed(12345)  # fixed seed for reproducible weight init
        self.lin1 = Linear(dataset.num_features, hidden_channels)
        self.lin2 = Linear(hidden_channels, dataset.num_classes)

    def forward(self, x):
        hidden = self.lin1(x).relu()
        hidden = F.dropout(hidden, p=0.5, training=self.training)
        return self.lin2(hidden)
class GCN(torch.nn.Module):
    """Two-layer graph convolutional network for node classification."""

    def __init__(self, hidden_channels):
        super(GCN, self).__init__()
        torch.manual_seed(12345)  # fixed seed for reproducible weight init
        self.conv1 = GCNConv(dataset.num_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, dataset.num_classes)

    def forward(self, x, edge_index):
        h = self.conv1(x, edge_index).relu()
        h = F.dropout(h, p=0.5, training=self.training)
        return self.conv2(h, edge_index)
class GAT(torch.nn.Module):
    """Two-layer graph attention network for node classification."""

    def __init__(self, hidden_channels):
        super(GAT, self).__init__()
        torch.manual_seed(12345)  # fixed seed for reproducible weight init
        self.conv1 = GATConv(dataset.num_features, hidden_channels)
        self.conv2 = GATConv(hidden_channels, dataset.num_classes)

    def forward(self, x, edge_index):
        h = self.conv1(x, edge_index).relu()
        h = F.dropout(h, p=0.5, training=self.training)
        return self.conv2(h, edge_index)
# Train the GAT variant here; swap in MLP or GCN to compare the three models.
model = GAT(hidden_channels=16)
criterion = torch.nn.CrossEntropyLoss()
# weight_decay adds L2 regularization to all parameters.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
def train():
    """Perform one full-batch training step and return the loss tensor."""
    model.train()
    optimizer.zero_grad()
    logits = model(data.x, data.edge_index)
    # Transductive setting: the loss is computed on training nodes only.
    step_loss = criterion(logits[data.train_mask], data.y[data.train_mask])
    step_loss.backward()
    optimizer.step()
    return step_loss
# Full-batch training for 500 epochs, logging the loss each epoch.
for epoch in range(1, 501):
    loss = train()
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')
def test():
    """Evaluate classification accuracy on the held-out test nodes."""
    model.eval()
    # IMPROVEMENT: evaluation is inference-only — torch.no_grad() skips
    # building the autograd graph (the original tracked gradients needlessly).
    with torch.no_grad():
        out = model(data.x, data.edge_index)
        pred = out.argmax(dim=1)  # Use the class with highest score.
        # Compare predictions against ground-truth labels on the test split.
        test_correct = pred[data.test_mask] == data.y[data.test_mask]
        # Derive the ratio of correct predictions.
        test_acc = int(test_correct.sum()) / int(data.test_mask.sum())
    return test_acc
# Report final test accuracy once training has finished.
test_acc = test()
print(f'Test Accuracy: {test_acc:.4f}')
在该数据集中,MLP, GCN和GAT的准确率分别为0.5877,0.8140,0.7380。
在节点表征的学习中,MLP神经网络只考虑了节点自身属性,忽略了节点之间的连接关系,它的结果是最差的;而GCN图神经网络与GAT图神经网络,同时考虑了节点自身信息与周围邻接节点的信息,因此它们的结果都优于MLP神经网络。也就是说,对周围邻接节点信息的利用,正是图神经网络优于普通深度神经网络的原因。
GCN图神经网络与GAT图神经网络的相同点为:
- 它们都遵循消息传递范式;
- 在邻接节点信息变换阶段,它们都对邻接节点做归一化和线性变换;
- 在邻接节点信息聚合阶段,它们都将变换后的邻接节点信息做求和聚合;
- 在中心节点信息变换阶段,它们都只是简单返回邻接节点信息聚合阶段的聚合结果。
GCN图神经网络与GAT图神经网络的区别在于采取的归一化方法不同:
- 前者根据中心节点与邻接节点的度计算归一化系数,后者根据中心节点与邻接节点的相似度计算归一化系数。
- 前者的归一化方式依赖于图的拓扑结构:不同的节点会有不同的度,同时不同节点的邻接节点的度也不同,于是在一些应用中GCN图神经网络会表现出较差的泛化能力。
- 后者的归一化方式依赖于中心节点与邻接节点的相似度,相似度是训练得到的,因此不受图的拓扑结构的影响,在不同的任务中都会有较好的泛化表现。
- 后者的效果有时不如前者,可能是因为其可学习的归一化系数带来了更大的搜索空间。
作业:
本次使用tudataset,做一个图分类问题。和节点分类的区别是:图分类需要加一个pool层将所有节点信息汇总。
下面是GCN代码,GAT只需将GCNConv换成GATConv即可
# NOTE(review): requires `from torch_geometric.datasets import TUDataset`,
# which is not imported anywhere above — confirm before running.
dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES')
class GCN(torch.nn.Module):
    """GCN for graph-level classification (TUDataset / ENZYMES).

    Node embeddings from two GCN layers are mean-pooled per graph, then
    mapped to class log-probabilities by a linear readout head.
    """

    def __init__(self, hidden_channels, num_classes):
        super(GCN, self).__init__()
        torch.manual_seed(12345)  # fixed seed for reproducible weight init
        self.hidden_channels = hidden_channels
        self.num_classes = num_classes
        self.conv1 = GCNConv(dataset.num_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        # BUG FIX: the readout layer must be created once, here in __init__.
        # The original built torch.nn.Linear(...) inside forward(), so a
        # freshly initialized random classifier was used on every call and
        # its weights were never registered with the optimizer — the head
        # could never be trained.
        self.lin = torch.nn.Linear(hidden_channels, num_classes)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)
        x = x.relu()
        x = F.dropout(x, p=0.5, training=self.training)
        # Aggregate node embeddings into one vector per graph in the batch.
        # NOTE(review): `pyg_nn` is never imported in this file — needs
        # `import torch_geometric.nn as pyg_nn` (or import global_mean_pool
        # directly) to run; confirm against the full script.
        x = pyg_nn.global_mean_pool(x, batch)
        x = self.lin(x)
        return F.log_softmax(x, dim=1)