DW_图深度学习_Task_3

DW_图深度学习_Task_3
学习内容:基于图神经网络的节点表征学习
学习地址: github/datawhalechina/team-learning-nlp/GNN/

本节的学习内容是基于图神经网络的节点表征学习,具体是学习实现多层图神经网络的方法。

学习材料里是基于对MLP, GCN以及GAT的实现并进行效果对比,来学习比较三者在表征学习能力上的差异。在整理学习内容之前,下面先罗列图神经网络的五大分类。本次学习主要是涉及前两种图神经网络

  • 图卷积网络(Graph Convolution Networks,GCN)
  • 图注意力网络(Graph Attention Networks)
  • 图自编码器( Graph Autoencoders)
  • 图生成网络( Graph Generative Networks)
  • 图时空网络(Graph Spatial-temporal Networks)

图神经网络分类综述接下来是学习材料部分的内容:

准备

数据

获取并打印Cora数据集的一些参数:

Dataset: Cora():
======================
Number of graphs: 1
Number of features: 1433
Number of classes: 7
Data(edge_index=[2, 10556], test_mask=[2708], train_mask=[2708], val_mask=[2708], x=[2708, 1433], y=[2708])
======================
Number of nodes: 2708
Number of edges: 10556
Average node degree: 3.90
Number of training nodes: 140
Training node label rate: 0.05
Contains isolated nodes: False
Contains self-loops: False
Is undirected: True

可视化

定义可视化节点表征分布的方法

import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

def visualize(h, color):
    z = TSNE(n_components=2).fit_transform(out.detach().cpu().numpy())
    plt.figure(figsize=(10,10))
    plt.xticks([])
    plt.yticks([])

    plt.scatter(z[:, 0], z[:, 1], s=70, c=color, cmap="Set2")
    plt.show()

MLP

构造网络

MLP(
(lin1): Linear(in_features=1433, out_features=16, bias=True)
(lin2): Linear(in_features=16, out_features=7, bias=True)
)

训练

model = MLP(hidden_channels=16)
criterion = torch.nn.CrossEntropyLoss()  # Define loss criterion.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)  # Define optimizer.

def train():
    model.train()
    optimizer.zero_grad()  # Clear gradients.
    out = model(data.x)  # Perform a single forward pass.
    loss = criterion(out[data.train_mask], data.y[data.train_mask])  # Compute the loss solely based on the training nodes.
    loss.backward()  # Derive gradients.
    optimizer.step()  # Update parameters based on gradients.
    return loss

for epoch in range(1, 201):
    loss = train()
#     print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')
    if epoch % 20 == 0:
        print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')

Epoch: 020, Loss: 1.7441
Epoch: 040, Loss: 1.2543
Epoch: 060, Loss: 0.8578
Epoch: 080, Loss: 0.6368
Epoch: 100, Loss: 0.5350
Epoch: 120, Loss: 0.4745
Epoch: 140, Loss: 0.4031
Epoch: 160, Loss: 0.3782
Epoch: 180, Loss: 0.4203
Epoch: 200, Loss: 0.3810

测试

Test Accuracy: 0.5900

GCN

PyG中GCNConv模块参数

GCNConv构造函数接口:

GCNConv(in_channels: int, out_channels: int, improved: bool = False, cached: bool = False, add_self_loops: bool = True, normalize: bool = True, bias: bool = True, **kwargs)
  • in_channels:输入数据维度;
  • out_channels:输出数据维度;
  • improved:如果为true A ^ = A + 2 I \mathbf{\hat{A}} = \mathbf{A} + 2\mathbf{I} A^=A+2I,其目的在于增强中心节点自身信息;
  • cached:是否存储 D ^ − 1 / 2 A ^ D ^ − 1 / 2 \mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2} D^1/2A^D^1/2的计算结果以便后续使用,这个参数只应在归纳学习(transductive learning)的场景中设置为true(归纳学习可以简单理解为在训练、验证、测试、推理(inference)四个阶段都只使用一个数据集);
  • add_self_loops:是否在邻接矩阵中增加自环边;
  • normalize:是否添加自环边并在运行中计算对称归一化系数;
  • bias:是否包含偏置项。

基于和MLP相同的训练和测试步骤,GCN的准确率为:

Test Accuracy: 0.8140

GAT

PyG中GATConv模块参数

GATConv构造函数接口:

GATConv(in_channels: Union[int, Tuple[int, int]], out_channels: int, heads: int = 1, concat: bool = True, negative_slope: float = 0.2, dropout: float = 0.0, add_self_loops: bool = True, bias: bool = True, **kwargs)

其中:

  • in_channels:输入数据维度;
  • out_channels:输出数据维度;
  • heads:在GATConv使用多少个注意力模型(Number of multi-head-attentions);
  • concat:如为true,不同注意力模型得到的节点表征被拼接到一起(表征维度翻倍),否则对不同注意力模型得到的节点表征求均值;
    基于和MLP相同的训练和测试步骤,GCN的准确率为:

Test Accuracy: 0.7380

比较

目前看来,MLP的准确率是最低的。而GCN和GAT因为考虑了节点自身以及周围邻接节点的信息,所做的预测比MLP要跟好。关于其他三种图神经网络的内容,将作为后续的补充学习内容。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值