DGL官方教程--图注意力网络(GAT)

Note:
Click here to download the full example code

Graph attention network

Authors: Hao Zhang, Mufei Li
, Minjie Wang Zheng Zhang
在本教程中,您将学习图注意力网络(GAT)以及如何在PyTorch中实现它。您还可以学习可视化并了解注意力机制所学到的知识。

图卷积网络(GCN)中描述的研究表明,结合局部图结构和节点级特征可以在节点分类任务上产生良好的性能。但是,GCN聚合的方式取决于结构,这可能会损害其通用性。

一种解决方法是按研究论文GraphSAGE中所述简单地平均所有邻居节点特征。但是,Graph Attention Network提出了另一种类型的聚合。GAN以关注方式使用具有特征依赖和无结构归一化的加权邻居特征。

Introducing attention to GCN

GAT和GCN之间的主要区别在于如何汇总来自一跳社区的信息。

对于GCN,图卷积运算会生成邻居节点特征的归一化总和。
h i ( l + 1 ) = σ ( ∑ j ∈ N ( i ) 1 c i j W ( l ) h j ( l ) ) h_i^{(l+1)}=\sigma\left(\sum_{j\in \mathcal{N}(i)} {\frac{1}{c_{ij}} W^{(l)}h^{(l)}_j}\right) hi(l+1)=σjN(i)cij1W(l)hj(l)
哪里 N ( i ) \mathcal{N}(i) N(i)是其一跳邻居的集合(包括 v i v_i vi在集合中,只需向每个节点添加一个自环), c i j = ∣ N ( i ) ∣ ∣ N ( j ) ∣ c_{ij}=\sqrt{|\mathcal{N}(i)|}\sqrt{|\mathcal{N}(j)|} cij=N(i) N(j) 是基于图结构的归一化常数, σ \sigma σ是激活功能(GCN使用ReLU),并且 W ( l ) W^{(l)} W(l)是用于节点特征转换的共享权重矩阵。GraphSAGE中提出的另一种模型 采用相同的更新规则,只是它们设置了 c i j = ∣ N ( i ) ∣ c_{ij}=|\mathcal{N}(i)| cij=N(i)
GAT引入了注意力机制,以替代静态归一化卷积运算。以下是计算节点嵌入的方程式 h i ( l + 1 ) h_i^{(l+1)} hi(l+1)层数 l + 1 l+1 l+1从图层的嵌入 l l l
在这里插入图片描述
z i ( l ) = W ( l ) h i ( l ) , ( 1 ) e i j ( l ) = LeakyReLU ( a ⃗ ( l ) T ( z i ( l ) ∣ ∣ z j ( l ) ) ) , ( 2 ) α i j ( l ) = exp ⁡ ( e i j ( l ) ) ∑ k ∈ N ( i ) exp ⁡ ( e i k ( l ) ) , ( 3 ) h i ( l + 1 ) = σ ( ∑ j ∈ N ( i ) α i j ( l ) z j ( l ) ) , ( 4 ) z_i^{(l)}=W^{(l)}h_i^{(l)},(1)\\ e_{ij}^{(l)}=\text{LeakyReLU}(\vec a^{(l)^T}(z_i^{(l)}||z_j^{(l)})),(2)\\ \alpha_{ij}^{(l)}=\frac{\exp(e_{ij}^{(l)})}{\sum_{k\in \mathcal{N}(i)}^{}\exp(e_{ik}^{(l)})},(3)\\ h_i^{(l+1)}=\sigma\left(\sum_{j\in \mathcal{N}(i)} {\alpha^{(l)}_{ij} z^{(l)}_j }\right),(4) zi(l)=W(l)hi(l),(1)eij(l)=LeakyReLU(a (l)T(zi(l)zj(l))),(2)αij(l)=kN(i)exp(eik(l))exp(eij(l)),(3)hi(l+1)=σjN(i)αij(l)zj(l),(4)
说明:

  • 公式(1)是下层嵌入的线性变换 h(l)i 和 W(l) 是其可学习的权重矩阵。
  • 公式(2)计算两个邻居之间的成对非标准化注意力得分。在这里,它首先将z 两个节点的嵌入 || 表示串联,然后取其点积和可学习的权重向量 a⃗ (l),最后应用LeakyReLU。这种注意形式通常称为加性注意,与Transformer模型中的点积注意相反。
  • 公式(3)应用softmax来标准化每个节点进入边缘上的注意力得分。
  • 等式(4)类似于GCN。来自邻居的嵌入被聚集在一起,并按照注意力得分进行缩放。
    本文还有其他详细信息,例如退出和跳过连接。为了简单起见,本教程省略了这些详细信息。要查看更多详细信息,请下载完整示例。本质上,GAT只是一种不同的聚合函数,它关注邻居的特征,而不是简单的均值聚合。

GAT in DGL

首先,您可以大致了解如何GATLayer在DGL中实现模块。在本节中,上面的四个方程式一次分解为一个。

import torch
import torch.nn as nn
import torch.nn.functional as F


class GATLayer(nn.Module):
    def __init__(self, g, in_dim, out_dim):
        super(GATLayer, self).__init__()
        self.g = g
        # equation (1)
        self.fc = nn.Linear(in_dim, out_dim, bias=False)
        # equation (2)
        self.attn_fc = nn.Linear(2 * out_dim, 1, bias=False)

    def edge_attention(self, edges):
        # edge UDF for equation (2)
        z2 = torch.cat([edges.src['z'], edges.dst['z']], dim=1)
        a = self.attn_fc(z2)
        return {'e': F.leaky_relu(a)}

    def message_func(self, edges):
        # message UDF for equation (3) & (4)
        return {'z': edges.src['z'], 'e': edges.data['e']}

    def reduce_func(self, nodes):
        # reduce UDF for equation (3) & (4)
        # equation (3)
        alpha = F.softmax(nodes.mailbox['e'], dim=1)
        # equation (4)
        h = torch.sum(alpha * nodes.mailbox['z'], dim=1)
        return {'h': h}

    def forward(self, h):
        # equation (1)
        z = self.fc(h)
        self.g.ndata['z'] = z
        # equation (2)
        self.g.apply_edges(self.edge_attention)
        # equation (3) & (4)
        self.g.update_all(self.message_func, self.reduce_func)
        return self.g.ndata.pop('h')

Equation (1)

z i ( l ) = W ( l ) h i ( l ) , ( 1 ) z_i^{(l)}=W^{(l)}h_i^{(l)},(1) zi(l)=W(l)hi(l),(1)
第一个显示线性变换。这很常见,可以使用在Pytorch中轻松实现torch.nn.Linear

Equation (2)

e i j ( l ) = LeakyReLU ( a ⃗ ( l ) T ( z i ( l ) ∣ z j ( l ) ) ) , ( 2 ) e_{ij}^{(l)}=\text{LeakyReLU}(\vec a^{(l)^T}(z_i^{(l)}|z_j^{(l)})),(2) eij(l)=LeakyReLU(a (l)T(zi(l)zj(l))),(2)
非标准化注意力得分 e i j e_{ij} eij使用相邻节点的嵌入来计算 i i i j j j。这表明注意力得分可以看作是边缘数据,可以由apply_edgesAPI 计算得出 。的参数apply_edges是Edge UDF,其定义如下:

def edge_attention(self, edges):
    # edge UDF for equation (2)
    z2 = torch.cat([edges.src['z'], edges.dst['z']], dim=1)
    a = self.attn_fc(z2)
    return {'e' : F.leaky_relu(a)}

在这里,点积与可学习的权重向量 a ( l ) ⃗ \vec{a^{(l)}} a(l) 使用PyTorch的线性变换再次实现attn_fc。需要注意的是apply_edges意志批次都在同一个张量的边缘数据,所以 catattn_fc这里是平行的所有边应用。

Equation (3) & (4)

α i j ( l ) = exp ⁡ ( e i j ( l ) ) ∑ k ∈ N ( i ) exp ⁡ ( e i k ( l ) ) , ( 3 ) h i ( l + 1 ) = σ ( ∑ j ∈ N ( i ) α i j ( l ) z j ( l ) ) , ( 4 ) \alpha_{ij}^{(l)}=\frac{\exp(e_{ij}^{(l)})}{\sum_{k\in \mathcal{N}(i)}^{}\exp(e_{ik}^{(l)})},(3)\\ h_i^{(l+1)}=\sigma\left(\sum_{j\in \mathcal{N}(i)} {\alpha^{(l)}_{ij} z^{(l)}_j }\right),(4) αij(l)=kN(i)exp(eik(l))exp(eij(l)),(3)hi(l+1)=σjN(i)αij(l)zj(l),(4)
与GCN类似,update_allAPI用于触发所有节点上的消息传递。消息函数发出两个张量:z 源节点的转换嵌入和e每个边缘上的非标准化注意力得分。然后reduce函数执行两项任务:

  • 使用softmax(公式(3))对注意力得分进行归一化。
  • 通过关注分数加权的邻居嵌入的总和(等式(4))。

这两个任务都首先从邮箱中获取数据,然后在dim=1批量处理邮件的第二个维度()上对其进行操作。

def reduce_func(self, nodes):
    # reduce UDF for equation (3) & (4)
    # equation (3)
    alpha = F.softmax(nodes.mailbox['e'], dim=1)
    # equation (4)
    h = torch.sum(alpha * nodes.mailbox['z'], dim=1)
    return {'h' : h}

Multi-head attention

类似于ConvNet中的多个渠道,GAT引入了多头关注,以丰富模型功能并稳定学习过程。每个关注头都有自己的参数,它们的输出可以通过两种方式合并:
concatenation : h i ( l + 1 ) = ∣ ∣ k = 1 K σ ( ∑ j ∈ N ( i ) α i j k W k h j ( l ) ) \text{concatenation}: h^{(l+1)}_{i} =||_{k=1}^{K}\sigma\left(\sum_{j\in \mathcal{N}(i)}\alpha_{ij}^{k}W^{k}h^{(l)}_{j}\right) concatenation:hi(l+1)=k=1KσjN(i)αijkWkhj(l)
要么
average : h i ( l + 1 ) = σ ( 1 K ∑ k = 1 K ∑ j ∈ N ( i ) α i j k W k h j ( l ) ) \text{average}: h_{i}^{(l+1)}=\sigma\left(\frac{1}{K}\sum_{k=1}^{K}\sum_{j\in\mathcal{N}(i)}\alpha_{ij}^{k}W^{k}h^{(l)}_{j}\right) average:hi(l+1)=σK1k=1KjN(i)αijkWkhj(l)
哪里 K K K是头数。您可以将串联用于中间层,将平均值用于最后一层。
将上面定义的单头GATLayer用作以下内容的构建基块MultiHeadGATLayer

class MultiHeadGATLayer(nn.Module):
    def __init__(self, g, in_dim, out_dim, num_heads, merge='cat'):
        super(MultiHeadGATLayer, self).__init__()
        self.heads = nn.ModuleList()
        for i in range(num_heads):
            self.heads.append(GATLayer(g, in_dim, out_dim))
        self.merge = merge

    def forward(self, h):
        head_outs = [attn_head(h) for attn_head in self.heads]
        if self.merge == 'cat':
            # concat on the output feature dimension (dim=1)
            return torch.cat(head_outs, dim=1)
        else:
            # merge using average
            return torch.mean(torch.stack(head_outs))

Put everything together

现在,您可以定义一个两层的GAT模型。

class GAT(nn.Module):
    def __init__(self, g, in_dim, hidden_dim, out_dim, num_heads):
        super(GAT, self).__init__()
        self.layer1 = MultiHeadGATLayer(g, in_dim, hidden_dim, num_heads)
        # Be aware that the input dimension is hidden_dim*num_heads since
        # multiple head outputs are concatenated together. Also, only
        # one attention head in the output layer.
        self.layer2 = MultiHeadGATLayer(g, hidden_dim * num_heads, out_dim, 1)

    def forward(self, h):
        h = self.layer1(h)
        h = F.elu(h)
        h = self.layer2(h)
        return h

然后,我们使用DGL的内置数据模块加载Cora数据集。

from dgl import DGLGraph
from dgl.data import citation_graph as citegrh
import networkx as nx

def load_cora_data():
    data = citegrh.load_cora()
    features = torch.FloatTensor(data.features)
    labels = torch.LongTensor(data.labels)
    mask = torch.BoolTensor(data.train_mask)
    g = data.graph
    # add self loop
    g.remove_edges_from(nx.selfloop_edges(g))
    g = DGLGraph(g)
    g.add_edges(g.nodes(), g.nodes())
    return g, features, labels, mask

训练循环与GCN教程中的完全相同。

import time
import numpy as np

g, features, labels, mask = load_cora_data()

# create the model, 2 heads, each head has hidden size 8
net = GAT(g,
          in_dim=features.size()[1],
          hidden_dim=8,
          out_dim=7,
          num_heads=2)

# create optimizer
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)

# main loop
dur = []
for epoch in range(30):
    if epoch >= 3:
        t0 = time.time()

    logits = net(features)
    logp = F.log_softmax(logits, 1)
    loss = F.nll_loss(logp[mask], labels[mask])

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch >= 3:
        dur.append(time.time() - t0)

    print("Epoch {:05d} | Loss {:.4f} | Time(s) {:.4f}".format(
        epoch, loss.item(), np.mean(dur)))

out:

/home/ubuntu/.pyenv/versions/miniconda3-latest/lib/python3.7/site-packages/numpy/core/fromnumeric.py:3257: RuntimeWarning: Mean of empty slice.
  out=out, **kwargs)
/home/ubuntu/.pyenv/versions/miniconda3-latest/lib/python3.7/site-packages/numpy/core/_methods.py:161: RuntimeWarning: invalid value encountered in double_scalars
  ret = ret.dtype.type(ret / rcount)
Epoch 00000 | Loss 1.9462 | Time(s) nan
Epoch 00001 | Loss 1.9456 | Time(s) nan
Epoch 00002 | Loss 1.9449 | Time(s) nan
Epoch 00003 | Loss 1.9442 | Time(s) 0.2926
Epoch 00004 | Loss 1.9435 | Time(s) 0.2970
Epoch 00005 | Loss 1.9428 | Time(s) 0.2960
Epoch 00006 | Loss 1.9421 | Time(s) 0.2944
Epoch 00007 | Loss 1.9414 | Time(s) 0.2955
Epoch 00008 | Loss 1.9406 | Time(s) 0.2947
Epoch 00009 | Loss 1.9399 | Time(s) 0.2945
Epoch 00010 | Loss 1.9391 | Time(s) 0.2945
Epoch 00011 | Loss 1.9384 | Time(s) 0.2955
Epoch 00012 | Loss 1.9376 | Time(s) 0.2951
Epoch 00013 | Loss 1.9368 | Time(s) 0.2946
Epoch 00014 | Loss 1.9360 | Time(s) 0.2955
Epoch 00015 | Loss 1.9351 | Time(s) 0.2954
Epoch 00016 | Loss 1.9343 | Time(s) 0.2954
Epoch 00017 | Loss 1.9334 | Time(s) 0.2953
Epoch 00018 | Loss 1.9325 | Time(s) 0.2959
Epoch 00019 | Loss 1.9317 | Time(s) 0.2964
Epoch 00020 | Loss 1.9307 | Time(s) 0.2963
Epoch 00021 | Loss 1.9298 | Time(s) 0.2968
Epoch 00022 | Loss 1.9289 | Time(s) 0.2968
Epoch 00023 | Loss 1.9279 | Time(s) 0.2966
Epoch 00024 | Loss 1.9269 | Time(s) 0.2964
Epoch 00025 | Loss 1.9259 | Time(s) 0.2955
Epoch 00026 | Loss 1.9249 | Time(s) 0.2955
Epoch 00027 | Loss 1.9238 | Time(s) 0.2944
Epoch 00028 | Loss 1.9228 | Time(s) 0.2937
Epoch 00029 | Loss 1.9217 | Time(s) 0.2927

Visualizing and understanding attention learned

Cora

下表总结了GAT论文中报告并通过DGL实现获得的Cora模型性能 。

ModelAccuracy
GCN (paper)81.4±0.5
GCN (dgl)82.05±0.33
GAT (paper)83.0±0.7
GAT (dgl)83.69±0.529

我们的模型学到了什么样的注意力分布?
因为注意体重 a i j a_{ij} aij与边缘相关联,您可以通过为边缘着色来形象化它。在下面,您可以选择Cora的一个子图,并绘制最后一个的注意权重GATLayer。节点根据其标签进行着色,而边缘根据注意权重的大小进行着色,这可以通过右侧的色条来参考。
在这里插入图片描述
您可以看到该模型似乎学习了不同的注意力权重。要更全面地了解分布,请测量注意力分布的。对于任何节点 i i i { α i j } j ∈ N ( i ) \{\alpha_{ij}\}_{j\in\mathcal{N}(i)} {αij}jN(i)通过以下公式给出的熵在其所有邻居上形成离散的概率分布
H ( α i j j ∈ N ( i ) ) = − ∑ j ∈ N ( i ) α i j log ⁡ α i j H({\alpha_{ij}}_{j\in\mathcal{N}(i)})=-\sum_{j\in\mathcal{N}(i)} \alpha_{ij}\log\alpha_{ij} H(αijjN(i))=jN(i)αijlogαij
熵低意味着集中度高,反之亦然。熵为0表示所有注意力都集中在一个源节点上。均匀分布具有最高的熵 log ⁡ ( N ( i ) ) \log(\mathcal{N}(i)) log(N(i))。理想情况下,您希望看到模型学习较低熵的分布(即,一个或两个邻居比其他邻居重要得多)。

注意,由于节点可以具有不同的度数,因此最大熵也将不同。因此,您可以绘制整个图中所有节点的熵值的聚合直方图。以下是每个注意头学习到的注意直方图。
在这里插入图片描述
作为参考,以下是所有节点均具有统一注意力权重分布的直方图。
在这里插入图片描述
可以看到,学习到的注意力值非常类似于均匀分布 (即,所有邻居都同等重要)。这部分解释了为什么GAT在Cora上的性能接近GCN的性能(根据作者的报告结果,100次运行的平均准确度差异小于2%)。注意并不重要,因为它的区别不大。

*这是否意味着注意力机制没有用?*并不是!另一个不同的数据集表现出完全不同的模式,如下所示。

Protein-protein interaction (PPI) networks

此处使用的PPI数据集包括 24对应于不同人体组织的图形。节点最多可以121 标签的种类,因此节点的标签表示为大小的二进制张量 121。任务是预测节点标签。

采用 20 训练图 2 用于验证和 2 进行测试。每个图的平均节点数为2372。每个节点都有50由位置基因集,基序基因集和免疫特征组成的特征。至关重要的是,在训练过程中完全看不见测试图,这种设置称为“归纳学习”。

比较GAT和GCN的性能 10 随机运行此任务,并在验证集上使用超参数搜索来找到最佳模型。

ModelF1 Score(micro)
GAT0.975±0.006
GCN0.509±0.025
Paper0.973±0.002

上表是该实验的结果,您可以使用micro F1分数来评估模型性能。

Note:
以下是F1分数的计算过程:
p r e c i s i o n = ∑ t = 1 n T P ( t ) ∑ t = 1 n T P ( t ) + F P ( t ) ) precision=\frac{\sum_{t=1}^{n}TP(t)}{\sum_{t=1}^{n}TP(t)+FP(t))} precision=t=1nTP(t)+FP(t))t=1nTP(t)
r e c a l l = ∑ t = 1 n T P ( t ) ∑ t = 1 n T P ( t ) + F N ( t ) ) recall=\frac{\sum_{t=1}^{n}TP(t)}{\sum_{t=1}^{n}TP(t)+FN(t))} recall=t=1nTP(t)+FN(t))t=1nTP(t)
F 1 m i c r o = 2 p r e c i s i o n ∗ r e c a l l p r e c i s i o n + r e c a l l F1_{micro}=2\frac{precision*recall}{precision+recall} F1micro=2precision+recallprecisionrecall
T P t TP_t TPt 表示同时具有和预计具有标签的节点数 t t t
F P t FP_t FPt 表示没有但预计具有标签的节点数 t t t
F N t FN_t FNt 代表标记为 t t t 但与其他人一样预测。
n n n 是标签数,即 121 121 121 就我们而言。
在训练过程中,BCEWithLogitsLoss用作损失功能。GAT和GCN的学习曲线如下所示;显而易见的是,与GCN相比,GAT的显着性能优势。
在这里插入图片描述
与以前一样,您可以通过显示节点式注意熵的直方图来对学习到的注意事项进行统计理解。以下是不同注意力层学习的注意力直方图。

在第1层中学习到的注意力:
在这里插入图片描述
在第2层中学习到的注意力:
在这里插入图片描述
在最后一层学到的注意力:

在这里插入图片描述
再次,与均匀分布比较:
在这里插入图片描述
显然,GAT确实学习了敏锐的关注权重!各层上也有清晰的图案:一层越多,注意力越集中。

与Cora数据集的GAT增益极少不同,对于PPI,与GAT论文相比,GPI与其他GNN变体之间存在显着的性能差距(至少20%),并且两者之间的注意力分布明显不同。尽管这值得进一步研究,但一个直接的结论是,GAT的优势可能更多在于处理具有更复杂邻域结构的图形的能力。

What’s next?

到目前为止,您已经了解了如何使用DGL来实现GAT。缺少一些遗漏的详细信息,例如退出,跳过连接和超参数调整,这些实践不涉及DGL相关概念。有关更多信息,请查看完整示例。

  • 请参阅优化的完整示例
  • 下一个教程介绍了如何通过并行化多个关注头和SPMV优化来加速GAT模型。

脚本的总运行时间:(0分钟15.065秒)

下载脚本:9_gat.py

下载脚本:9_gat.ipynb

  • 25
    点赞
  • 86
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
"Structure-Aware Transformer for Graph Representation Learning"是一篇使用Transformer模型进行表示学习的论文。这篇论文提出了一种名为SAT(Structure-Aware Transformer)的模型,它利用了中节点之间的结构信息,以及节点自身的特征信息。SAT模型在多个数据集上都取得了非常好的结果。 以下是SAT模型的dgl实现代码,代码中使用了Cora数据集进行示例: ``` import dgl import numpy as np import torch import torch.nn as nn import torch.nn.functional as F class GraphAttentionLayer(nn.Module): def __init__(self, in_dim, out_dim, num_heads): super(GraphAttentionLayer, self).__init__() self.num_heads = num_heads self.out_dim = out_dim self.W = nn.Linear(in_dim, out_dim*num_heads, bias=False) nn.init.xavier_uniform_(self.W.weight) self.a = nn.Parameter(torch.zeros(size=(2*out_dim, 1))) nn.init.xavier_uniform_(self.a.data) def forward(self, g, h): h = self.W(h).view(-1, self.num_heads, self.out_dim) # Compute attention scores with g.local_scope(): g.ndata['h'] = h g.apply_edges(fn.u_dot_v('h', 'h', 'e')) e = F.leaky_relu(g.edata.pop('e'), negative_slope=0.2) g.edata['a'] = torch.cat([e, e], dim=1) g.edata['a'] = torch.matmul(g.edata['a'], self.a).squeeze() g.edata['a'] = F.leaky_relu(g.edata['a'], negative_slope=0.2) g.apply_edges(fn.e_softmax('a', 'w')) # Compute output features g.ndata['h'] = h g.update_all(fn.u_mul_e('h', 'w', 'm'), fn.sum('m', 'h')) h = g.ndata['h'] return h.view(-1, self.num_heads*self.out_dim) class SATLayer(nn.Module): def __init__(self, in_dim, out_dim, num_heads): super(SATLayer, self).__init__() self.attention = GraphAttentionLayer(in_dim, out_dim, num_heads) self.dropout = nn.Dropout(0.5) self.norm = nn.LayerNorm(out_dim*num_heads) def forward(self, g, h): h = self.attention(g, h) h = self.norm(h) h = F.relu(h) h = self.dropout(h) return h class SAT(nn.Module): def __init__(self, in_dim, hidden_dim, out_dim, num_heads): super(SAT, self).__init__() self.layer1 = SATLayer(in_dim, hidden_dim, num_heads) self.layer2 = SATLayer(hidden_dim*num_heads, out_dim, 1) def forward(self, g, h): h = self.layer1(g, h) h = self.layer2(g, h) return h.mean(0) # Load Cora dataset from dgl.data import citation_graph as citegrh data = citegrh.load_cora() g = data.graph features = torch.FloatTensor(data.features) labels = torch.LongTensor(data.labels) train_mask = torch.BoolTensor(data.train_mask) val_mask = torch.BoolTensor(data.val_mask) test_mask = torch.BoolTensor(data.test_mask) # Add self loop g = dgl.remove_self_loop(g) g = dgl.add_self_loop(g) # Define model and optimizer model = SAT(features.shape[1], 64, data.num_classes, 8) optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4) # Train model for epoch in range(200): model.train() logits = model(g, features) loss = F.cross_entropy(logits[train_mask], labels[train_mask]) optimizer.zero_grad() loss.backward() optimizer.step() acc = (logits[val_mask].argmax(1) == labels[val_mask]).float().mean() if epoch % 10 == 0: print('Epoch {:03d} | Loss {:.4f} | Accuracy {:.4f}'.format(epoch, loss.item(), acc.item())) # Test model model.eval() logits = model(g, features) acc = (logits[test_mask].argmax(1) == labels[test_mask]).float().mean() print('Test accuracy {:.4f}'.format(acc.item())) ``` 在这个示例中,我们首先加载了Cora数据集,并将其转换为一个DGL。然后,我们定义了一个包含两个SAT层的模型,以及Adam优化器。在训练过程中,我们使用交叉熵损失函数和验证集上的准确率来监控模型的性能。在测试阶段,我们计算测试集上的准确率。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值