Task 03 基于图神经网络的节点表征学习

本文通过代码展示了在Cora数据集上,MLP、GCN和GAT进行节点分类的效果。GCN和GAT优于只考虑节点特征的MLP,GCN基于节点度进行归一化,而GAT使用注意力机制,其归一化依赖于节点相似度,具有更好的泛化能力。此外,还提到图分类任务中需要加入pool层来汇总节点信息。
摘要由CSDN通过智能技术生成

本次学习图神经网络的一般过程,直接放代码,对比MLP, GCN和GAT在cora数据集上做节点分类的效果:

from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures
from torch_geometric.nn import GCNConv, GATConv
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import torch
from torch.nn import Linear
import torch.nn.functional as F

dataset = Planetoid(root='dataset', name='Cora', transform=NormalizeFeatures())
data = dataset[0]

def visualize(h, color, fig_name):
    z = TSNE(n_components=2).fit_transform(out.detach().cpu().numpy())
    plt.figure(figsize=(10,10))
    plt.xticks([])
    plt.yticks([])

    plt.scatter(z[:, 0], z[:, 1], s=70, c=color, cmap="Set2")
    plt.savefig(fig_name)


class MLP(torch.nn.Module):
    def __init__(self, hidden_channels):
        super(MLP, self).__init__()
        torch.manual_seed(12345)
        self.lin1 = Linear(dataset.num_features, hidden_channels)
        self.lin2 = Linear(hidden_channels, dataset.num_classes)

    def forward(self, x):
        x = self.lin1(x)
        x = x.relu()
        x = F.dropout(x, p =0.5, training=self.training)
        x = self.lin2(x)

        return x

class GCN(torch.nn.Module):
    def __init__(self, hidden_channels):
        super(GCN, self).__init__()
        torch.manual_seed(12345)
        self.conv1 = GCNConv(dataset.num_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, dataset.num_classes)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = F.dropout(x, p =0.5, training=self.training)
        x = self.conv2(x, edge_index)

        return x

class GAT(torch.nn.Module):
    def __init__(self, hidden_channels):
        super(GAT, self).__init__()
        torch.manual_seed(12345)
        self.conv1 = GATConv(dataset.num_features, hidden_channels)
        self.conv2 = GATConv(hidden_channels, dataset.num_classes)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = F.dropout(x, p =0.5, training=self.training)
        x = self.conv2(x, edge_index)

        return x

model = GAT(hidden_channels=16)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

def train():
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = criterion(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()

    return loss

for epoch in range(1, 501):
    loss = train()
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')


def test():
    model.eval()
    out = model(data.x, data.edge_index)
    pred = out.argmax(dim=1)  # Use the class with highest probability.
    test_correct = pred[data.test_mask] == data.y[data.test_mask]  # Check against ground-truth labels.
    test_acc = int(test_correct.sum()) / int(data.test_mask.sum())  # Derive ratio of correct predictions.
    return test_acc

test_acc = test()
print(f'Test Accuracy: {test_acc:.4f}')

在该数据集中,MLP, GCN和GAT的准确率分别为0.5877,0.8140,0.7380。

在节点表征的学习中,MLP神经网络只考虑了节点自身属性,忽略了节点之间的连接关系,它的结果是最差的;而GCN图神经网络与GAT图神经网络,同时考虑了节点自身信息与周围邻接节点的信息,因此它们的结果都优于MLP神经网络。也就是说,对周围邻接节点的信息的考虑,是图神经网络由于普通深度神经网络的原因。

GCN图神经网络与GAT图神经网络的相同点为:

  • 它们都遵循消息传递范式;
  • 在邻接节点信息变换阶段,它们都对邻接节点做归一化和线性变换;
  • 在邻接节点信息聚合阶段,它们都将变换后的邻接节点信息做求和聚合;
  • 在中心节点信息变换阶段,它们都只是简单返回邻接节点信息聚合阶段的聚合结果。

GCN图神经网络与GAT图神经网络的区别在于采取的归一化方法不同:

  • 前者根据中心节点与邻接节点的度计算归一化系数,后者根据中心节点与邻接节点的相似度计算归一化系数。
  • 前者的归一化方式依赖于图的拓扑结构:不同的节点会有不同的度,同时不同节点的邻接节点的度也不同,于是在一些应用中GCN图神经网络会表现出较差的泛化能力。
  • 后者的归一化方式依赖于中心节点与邻接节点的相似度,相似度是训练得到的,因此不受图的拓扑结构的影响,在不同的任务中都会有较好的泛化表现
  • 后者效果有时候不如前者可能是因为搜索空间更大

作业:
本次使用tudataset,做一个图分类问题。和节点分类的区别是:图分类需要加一个pool层将所有节点信息汇总。

下面是GCN代码,GAT只需将GCNConv换成GATConv即可

dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES')

class GCN(torch.nn.Module):
    def __init__(self, hidden_channels,num_classes):
        super(GCN, self).__init__()
        torch.manual_seed(12345)
        self.hidden_channels = hidden_channels
        self.num_classes = num_classes
        self.conv1 = GCNConv(dataset.num_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = F.dropout(x, p =0.5, training=self.training)
        x = self.conv2(x, edge_index)
        x = x.relu()
        x = F.dropout(x, p =0.5, training=self.training)
        x = pyg_nn.global_mean_pool(x, batch)
        x = torch.nn.Linear(self.hidden_channels, self.num_classes)(x)
        
        return F.log_softmax(x, dim=1)

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值