一文带你浏览Graph Transformers

4c472a803b1af23f5c54686792478d26.gif

©作者 | Dream

单位 | 浙江大学

研究方向 | 图表示学习

写在前头

为什么图上要使用 Transformer?

简要提一下 GT 带来的好处:

1. 能捕获长距离依赖

2. 减轻出现过平滑,过挤压现象

3. GT 中甚至可以结合进 GNN 以及频域信息(Laplacian PE),模型会有更强的表现力。

4. 利用好 [CLS] token,不需要额外使用 pooling function。

5. etc


7cc2d17eb5ba82794bf9b8bcd12cfc4c.png


Graph-Bert

Graph-Bert: Only Attention is Needed for Learning Graph Representations (arXiv 2020)

https://arxiv.org/abs/2001.05140

GNN 过度依赖图上的链接,此外由于大图非常耗显存,可能无法直接进行操作。该论文提出了一种新的只依赖 attention 机制的图网络 Graph-BERT,Graph-BERT 的输入不是全图,而是一堆采样得到的子图(不带边)。作者使用节点属性重构以及结构恢复两个任务进行预训练,然后在下游数据集上进行微调。该方法在节点分类和图聚类任务上达到了 SOTA。

a51b07d40f7d0b4f11618758d48ef0fd.png

▲ 图1:Graph-BERT

和之前 NLP 中的 BERT 不一杨的地方主要是 position encoding,Graph-BERT使用了三种 PE,分别是 WL absolute PE,intimacy based relative PE 和 Hop based relative PE,这里三个 PE 都是根据 complete graph 计算得到的。

为了便于操作,作者将 subgraph 根据图亲密度矩阵进行排序 [i, j, ...m],其中 S(i,j) > S(i,m),得到序列化的结果。

其中 Weisfeiler-Lehman Absolute Role Embedding 如下:

37280753ef5147c0acfc32e8c2c875bb.png

经过 WL 之后,子结构一样的节点就会得到相同的 hash code,如果是 1-WL 有点像 degree centrality(针对无向图而言)。因此,WL PE 可以捕获全局节点角色信息。

Intimacy based Relative Positional Embedding

这个 PE 捕获的是偏 local 的信息,因为输入已经根据图亲密度矩阵进行排序过了,这里只需要简单地设 ,越接近i的节点 会越小。映射公式如下:

9527021ffc19d6121a18516b714c1d51.png

Hop based Relative Distance Embedding

该 PE 捕获的是介于 global 和 local 之间的信息:

f7af20f39b874762e10fee9470774a34.png

将节点 embedding 和这些 PE 加起来,然后输入到 Transformer 中得到 subgraph 每个节点的表征 。因为 subgraph 中有很多节点,作者希望最后只得到 target node 的表征 ,因此作者还设计了一个 Fusion function,原文中是把所有节点表征做一下 average。 都会被输出,根据下游任务选择所需的进行使用。

8bcf8732bb1c5138236aa22ce8a785c2.png

节点分类效果(没有进行pre-training)

0c0c313ccd6553db90317fa5d7f04a00.png

▲ 节点分类

总的来说这篇论文比较新颖的地方在于提出了多种图上的 PE,并且在子图上的效果也可以达到之前 GNN 在全图上的效果,但是实验的数据集太少了,而且也没有使用比较大的图数据集,由于这些数据集较小也没有较好地展现预训练的效果。此外,采样得到的子图最后只是用目标节点进行 loss 计算,这样利用效率太低了,inference 效率同样也很低。

81a97f8ea6b0f55a12270eaa8c0a5dc3.png

GROVER

Self-Supervised Graph Transformer on Large-Scale Molecular Data (NeurIPS 2020)

https://papers.nips.cc/paper/2020/file/94aef38441efa3380a3bed3faf1f9d5d-Paper.pdf

GNN 在分子领域被广泛研究,但是该领域存在两个主要问题:(1)没有那么多标签数据,(2)在新合成的分子上表现出很差的泛化能力。为了解决这连个问题,该论文设计了 node-,edge-,graph-level 的自监督任务,希望可以从大量的未标注数据中捕获分子中的丰富的语义和结构信息。作者在一千万未标注的分子图上训练了一个 100M 参数量的 GNN,然后根据下游任务进行 fine-tuning,在 11 个数据集上都达到了 SOTA(平均提升超过 6 个点)。

模型结构如下:

b6c9183e26f9c6e06e47e4664f4351e4.png

▲ 结构图

因为 message passing 的过程可以捕获到图中的局部结构信息,因此将数据先通过 GNN 得到 Q,K,V,可以使得每个节点表征得到 local subgraph structure 的信息。然后,这些表征通过 self-attention 可以捕获到 global 的信息。为了防止 message passing 出现 over-smoothing 现象,该论文将残差链接变成了 long-range residual connection,直接将原始特征接到后面。

此外,由于 GNN 的层数直接影响到感受野,这将影响到 message passing model 的泛化性。由于不同类型的数据集可能需要的 GNN 层数不同,提前定义好 GNN 的层数可能会影响到模型的性能。作者在训练的时候随机选择层数,随机选择 跳的 MPNN,其中 或者 。这种 MPNN 称为 Dynamic Message Passing networks(dyMPN),实验证明这种结构会比普通的 MPNN 好。

预训练:

论文中使用的自监督任务主要有两个:

1. Contextual property prediction(node/edge level task)

2ee4acaec0d0370c28019c475561dd84.png

▲ Contextual property prediction

一个好的 node-level 自监督任务应该满足两个性质:(1)预测目标是可信的并且是容易得到的(2)预测目标应该反应节点或者边的上下文信息。基于这些性质,作者设计了基于节点和边的自监督任务(通过 subgraph 来预测目标节点/边的上下文相关的性质)

c767e4f6ec25be41796e47e65f245114.png

▲ 例子

举一个例子,给定一个 target node C 原子,我们首先获取它的 k-hop 邻居作为subgraph,当 k=1,N 和 O 原子会包含进来以及单键和双键。然后我们从这个 subgraph 中抽取统计特征(statistical properties),我们会计数出针对 center node(node-edge)pairs的共现次数,可表示为 node-edge-counts,所有的 node-edge counts terms 按字母顺序排,在这个例子中,我们可以得到 C_N-DOUBLE1_O-SINGLE1。这个步骤可以看作是一个聚类过程:根据抽取得到的特征,subgraphs 可以被聚类起来,一种特征(property)对应于一簇具有相同统计属性的子图。

通过这种具有 context-aware property 的定义,全局性质预测任务可以分为以下流程:

输入一个分子图,通过 GROVER encoder 我们可以得到原子和边的 embeddings,随机选择原子 (它的 embedding 为 )。我们不是预测原子 的类别,而是希望 能够编码 node 周围的一些上下文信息(contextual information),实现这一目的的方式是将 输入到一个非常简单的 model(一层全连接),然后使用它的输出去预测节点 的上下文属性(contextual properties),这个预测是一个 multi-class classification(一个类别对应一种contextual property)。

2. Graph-level motif prediction

e9793603b7ef65ac917a5178a5d96d6a.png

▲ Graph-level motif prediction

Graph-level 的自监督任务也需要可信和廉价的标签,motifs 是 input graph data 中频繁出现的 sub-graphs。一类重要的 motifs 是官能团,它编码了分子的丰富的领域知识,并且能够被专业的软件检测到(e.g. RDKit)。因此,我们可以将 motif prediction task 公式化成一个多分类问题,每一个 motif 对应一个 label。假设分子数据集中存在 p 个 motifs ,对于某个具体的分子,我们使用 RDKit 检测该分子中是否出现了 motif,然后把其构建为 motif prediction task 的目标。

针对下游任务进行微调:

在海量未标注数据上完成 pre-training 后,我们获得了一个高质量的分子 encoder,针对具体的下游任务(e.g. node classification, link prediction, the property prediction for molecules, etc),我们需要对模型进行微调,针对 graph-level 的下游任务,我们还需要一个额外的 readout 函数来得到全图的表征(node-level 和 edge-level 就不需要 readout 函数了),然后接一个 MLP 进行分类。

实验:

注:绿色表示进行了pre-training

性能的提升还是比较明显的。

767c3e80ae08e22a37b5692cab0bbb25.png

45308cdb83472149ca731a49a650dcec.png

Graph Transformer Architecture

A Generalization of Transformer Networks to Graphs (DLG-AAAI 2021)

https://arxiv.org/abs/2012.09699

e8a72700066a86af25bc423c344eadc4.png

▲ 模型结构

主要提出使用 Laplacian eigenvector 作为 PE,比 GraphBERT 中使用的 PE 好。

324fc231b5f8c1553380b326180ee391.png

▲ 不同 PE 的效果比较

但是该模型的效果在 self-attention 只关注 neighbors 的时候会更好,与其说是 graph transformer,不如说是带有 PE 的 GAT。Sparse graph 指的是 self-attention 只计算邻居节点,full graph 是指 self-attention 会考虑图中所有节点。

c56664d5a207de301a59382055feabbd.png

▲ 实验结果

e47ba9c9fe5bad4b2efc293ace662430.png

GraphiT

GraphiT: Encoding Graph Structure in Transformers (arXiv 2021)

https://arxiv.org/abs/2106.05667

该工作表明,将结构和位置信息合并到 transformer 中,能够优于现有的经典 GNN。GraphiT(1)利用基于图上的核函数的相对位置编码来影响 attention scores,(2)并编码出 local sub-structures 进行利用。实现发现,无论将这种方法单独使用,还是结合起来使用都取得了不错的效果。

(i) leveraging relative positional encoding strategies in self-attention scores based on positive definite kernels on graphs, and (ii) enumerating and encoding local sub-structures such as paths of short length

之前 GT 发现 self-attention 在只关注 neighboring nodes 的时候会取得比较好的效果,但是在关注到所有节点的时候,性能就不行。这篇论文发现 transformer with global communication 同样可以达到不错的效果。因此,GraphiT 通过一些策略将 local graph structure 编码进模型中,(1)基于正定核的注意力得分加权的相对位置编码策略 (2)通过利用 graph convolution kernel networks(GCKN)将 small sub-structure(e.g.,paths或者subtree patterns)编码出来作为transformer的输入。

Transformer Architectures

31ee5b7fd4fd578ee1d8889f9dbe14df.png

ec04ae5e2d681356f2254c71a587b46d.png

Encoding Node Positions

Relative Position Encoding Strategies by Using Kernels on Graphs

c82caeefdda3e0c6d75c38cfeee25b17.png

Encoding Topological Structures

Graph convolutional kernel networks(GCKN)

42a2afb92e29628a280c33a3abdeca02.png

实验结果

d63e95772f3123f4da69c14aad2cb5a1.png

▲ 实验结果

d6a0d4a40bdf45719eca7103e3920527.png

GraphTrans

Representing Long-Range Context for Graph Neural Networks with Global Attention (NeurIPS 2021)

https://arxiv.org/abs/2201.08821

该论文提出了 GraphTrans,在标准 GNN 层之上添加T ransformer。并提出了一种新的 readout 机制(其实就是 NLP 中的 [CLS] token)。对于图而言,针对 target node 的聚合最好是 permutation-invariant,但是加上 PE 的 transformer 可能就没法实现这个了,因此不使用 PE 在图上是比较自然的。

829d2e0255236e8ac08f1c788054a0b1.png

▲ pipeline

可解释性观察

[CLS]token 是 18,可以发现它和其他 node 交互很频繁。它也许能抽取到更 general 的信息。

8916e72a5af85b6efbe2dcf9fb6ac188.png

虽然这篇论文方法很简单,但做了大量的实验,效果也不错。

NCI biological datasets

5af32e936bb486ec71d44fbaf2f92bcd.png

▲ NCI biological datasets

OpenGraphBenchmark Molpcba dataset

74924b49270cf414d78b7f1ae6ede7aa.png

▲ Molpcba dataset

OpenGraphBenchmark Code2 dataset

7414902095c06d87951ec1f01644f46f.png

▲ Code2 dataset

08c6c161ebbfed57e3ac5b6f994cf7cb.png

SAN

Rethinking Graph Transformers with Spectral Attention (NeurIPS 2021)

https://arxiv.org/abs/2106.03893

这篇论文使用 learnable PE,对为什么 laplacian PE 比较有效进行了比较好的说明,

237ed4c884ffb9a969b707953a8f8350.png

ddcaa5c7d16cb09882a4605fcca2ff81.png

Graphormer


Do Transformers Really Perform Bad for Graph Representation? (NeurIPS 2021)

https://arxiv.org/abs/2106.05234

原作者自己进行了解读:

https://www.msra.cn/zh-cn/news/features/ogb-lsc

核心点在于利用结构信息对 attention score 进行修正,这样比较好地利用上了图的结构信息。

3aec03b9749ac6c5870ea4aa7cbc6717.png

SAT

Structure-Aware Transformer for Graph Representation Learning (ICML 2022)

https://arxiv.org/abs/2202.03036

这篇论文和 GraphTrans 比较类似。也是先通过 GNN 得到新的节点表征,然后再输入到 GT 中。只是这篇论文对抽取结构信息的方式进行了更抽象化的定义(但是出于便利性,还是使用了 GNN)。还有一点不同就是这篇论文还使用了PE(RWPE)。

在这篇论文中,作者展示了使用位置编码的 Transformer 生成的节点表示不一定捕获节点之间的结构相似性。为了解决这个问题,Chen et al. 提出了一种 structure-aware transformer,这是一种建立在新的 self-attention 机制上的 transformer。这种新的 self-attention 在计算 attention 之前会抽取子图的表征(rooted at each node),这样融合进了结构信息。

作者提出了若干种可以自动生成 subgraph representation 的方法,从理论上证明这些表征至少和  subgraph representations 表现力一样。该 structure-aware 框架能够利用已有的 GNN 去抽取 subgraph representation,从实验上证明了模型的性能提升和 GNN 有较大的关系。仅对 Transformer 使用绝对位置编码会表现出过于宽松的结构归纳偏差,这不能保证两个节点具有相似的局部结构的节点生成相似的节点表示。

373471f202b9bf9a22c90e5ed53d7965.png

629c4f529555aed7ed294d277aa3c860.png

GraphGPS

Recipe for a General, Powerful, Scalable Graph Transformer (NeurIPS 2022 Under Review)

https://arxiv.org/abs/2205.12454

在这篇论文中,作者对之前使用的 PE 进行了细致的归类(local,global or relative,详见下方表格)。此外,该论文还提出了构建 General,Powerful,Scalable Graph Transformer 的要素有三:(1)positional/structural encoding,(2)local message-passing mechanism,(3)global attention mechanism。针对这三要素,作者设计了一种新的 graph transformer。

ec1acb0be45c324e2819fa922f6b7fe3.png

针对 layer 的设计,该论文采用 GPSlayer = a hybrid MPNN+Transformer layer。

84cde68680146e506fb4be37d4376ca5.png

该设计与 GraphTrans 的不同在于,GraphTrans 在输入到 Transformer 之前先输入到一个包含若干层的 MPNNs 中,这可能会有 over-smoothing,over-squashing 以及 low expressivity against the WL test 的问题,也就是说这些层可能无法在早期保存一些信息 ,输入到 transfomer 的信息就会有缺失。GPS 的设计是每一层都是一层的 MPNN+transformer layer,然后反复堆叠 L 层。

e27cbd6a92fac87bd63cee6b6444f541.png

具体计算如下:

d9cb8103c773e7a6db882f861fdb4aac.png

9cd6d150a73280e3b6f5c4d4a4d5b5bb.png

bf14dcd48314ce352c9495f4a2816dce.png

利用 Linear transformer,GPS 可以将时间复杂度降到 。

实验结果

1. 使用不同的 Transformer,MPNN:可以发现不使用 MPNN 掉点很多,使用 Transformer 可以带来性能提升。

1e5e528b75e29918d43a682e200a153c.png

▲ 消融实验:使用不同的 transformer,MPNN

2. 使用不同的 PE/SE:在低计算成本下,使用 RWSE 效果最佳。如果不介意计算开销可以使用效果更好的 。

e4ac1062ba8dca7c1548b3ff48d37835.png

▲ 消融实验:使用不同的PE\SE

3. Benchmarking GPS

3.1 Benchmarking GNNs

25726df9922abf9f670f2d3496a7d9a0.png

3.2 Open Graph Benchmark

76da6ebc0cbc9468cab38db0740b0037.png

▲ Open Graph Benchmark

3.3 OGB-LSC PCQM4Mv2

abef4806de260cd2788cc87ce01a4276.png

方法汇总

注:这篇文章主要汇总的是同质图上的 graph transformers,目前也有一些异质图上 graph transformers 的工作,感兴趣的读者自行查阅哈。

  1. 图上不同的 transformers 的主要区别在于(1)如何设计 PE,(2)如何利用结构信息(结合 GNN 或者利用结构信息去修正 attention score, etc)。

  2. 现有的方法基本都针对 small graphs(最多几百个节点),Graph-BERT 虽然针对节点分类任务,但是首先会通过 sampling 得到子图,这会损害性能(比 GAT 多了很多参数,但性能是差不多的),能否设计一种针对大图的 transformer 还是一个比较难的问题。

49353af16aa439300d68e7b70b89ce2c.png

▲ 各方法的区别

outside_default.png

参考文献

outside_default.png

[1] GRAPH-BERT: Only Attention is Needed for Learning Graph Representations:https://github.com/jwzhanggy/Graph-Bert

[2] (GROVER) Self-Supervised Graph Transformer on Large-Scale Molecular Data:https://github.com/tencent-ailab/grover

[3] (GT) A Generalization of Transformer Networks to Graphs:https://github.com/graphdeeplearning/graphtransformer

[4] GraphiT: Encoding Graph Structure in Transformers [Code is unavailable]

[5] (GraphTrans) Representing Long-Range Context for Graph Neural Networks with Global Attention:https://github.com/ucbrise/graphtrans

[6] (SAN) Rethinking Graph Transformers with Spectral Attention [Code is unavailable]

[7] (Graphormer) Do Transformers Really Perform Bad for Graph Representation?:https://github.com/microsoft/Graphormer

[8] (SAT) Structure-Aware Transformer for Graph Representation Learning [Code is unavailable] 

[9] (GraphGPS) Recipe for a General, Powerful, Scalable Graph Transformer:https://github.com/rampasek/GraphGPS


其他资料

[1] Graph Transformer综述:https://arxiv.org/abs/2202.08455 [ Code]

[2] Tutorial: [Arxiv 2022,06] A Bird's-Eye Tutorial of Graph Attention Architectures:https://arxiv.org/pdf/2206.02849.pdf

[3] Dataset: [Arxiv 2022,06]Long Range Graph Benchmark [ Code]:https://arxiv.org/pdf/2206.08164.pdf

简介:GNN 一般只能捕获 k-hop 的邻居,而可能无法捕获长距离依赖信息,Transformer 可以解决这一问题。该 benmark 共包含五个数据集(PascalVOC-SP, COCO-SP, PCQM-Contact, Peptides-func and Peptides-struct),需要模型能捕获长距离依赖才能取得比较好的效果,该数据集主要用来验证模型捕获 long range interactions 的能力。

还有一些同质图上Graph Transformers的工作,感兴趣的同学自行阅读:

[1] [KDD 2022] Global Self-Attention as a Replacement for Graph Convolution:https://arxiv.org/pdf/2108.03348.pdf

[2] [ICOMV 2022] Experimental analysis of position embedding in graph transformer networks:https://www.spiedigitallibrary.org/conference-proceedings-of-spie/12173/121731O/Experimental-analysis-of-position-embedding-in-graph-transformer-networks/10.1117/12.2634427.short

[3] [ICLR Workshop MLDD] GRPE: Relative Positional Encoding for Graph Transformer [Code]:https://arxiv.org/abs/2201.12787

[4] [Arxiv 2022,05] Your Transformer May Not be as Powerful as You Expect [Code]:https://arxiv.org/pdf/2205.13401.pdf

[5] [Arxiv 2022,06] NAGphormer: Neighborhood Aggregation Graph Transformer for Node Classification in Large Graphs:https://arxiv.org/abs/2206.04910

更多阅读

ed7a75e35cef8e544b6df2eb3f532bef.png

095e066cc45132c53e0648d7380f3cac.png

b60a83e6cab8c7a3a6f5d907f790023c.png

fe58fec05bde9594c78c8ac78e6094f4.gif

#投 稿 通 道#

 让你的文字被更多人看到 

如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。

总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。 

PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析科研心得竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。

📝 稿件基本要求:

• 文章确系个人原创作品,未曾在公开渠道发表,如为其他平台已发表或待发表的文章,请明确标注 

• 稿件建议以 markdown 格式撰写,文中配图以附件形式发送,要求图片清晰,无版权问题

• PaperWeekly 尊重原作者署名权,并将为每篇被采纳的原创首发稿件,提供业内具有竞争力稿酬,具体依据文章阅读量和文章质量阶梯制结算

📬 投稿通道:

• 投稿邮箱:hr@paperweekly.site 

• 来稿请备注即时联系方式(微信),以便我们在稿件选用的第一时间联系作者

• 您也可以直接添加小编微信(pwbot02)快速投稿,备注:姓名-投稿

369d654a477436e80aa9f6238df94e0a.png

△长按添加PaperWeekly小编

🔍

现在,在「知乎」也能找到我们了

进入知乎首页搜索「PaperWeekly」

点击「关注」订阅我们的专栏吧

·

c61008cc2550b39c0798cb6c2324bd8b.jpeg

"Structure-Aware Transformer for Graph Representation Learning"是一篇使用Transformer模型进行图表示学习的论文。这篇论文提出了一种名为SAT(Structure-Aware Transformer)的模型,它利用了图中节点之间的结构信息,以及节点自身的特征信息。SAT模型在多个图数据集上都取得了非常好的结果。 以下是SAT模型的dgl实现代码,代码中使用了Cora数据集进行示例: ``` import dgl import numpy as np import torch import torch.nn as nn import torch.nn.functional as F class GraphAttentionLayer(nn.Module): def __init__(self, in_dim, out_dim, num_heads): super(GraphAttentionLayer, self).__init__() self.num_heads = num_heads self.out_dim = out_dim self.W = nn.Linear(in_dim, out_dim*num_heads, bias=False) nn.init.xavier_uniform_(self.W.weight) self.a = nn.Parameter(torch.zeros(size=(2*out_dim, 1))) nn.init.xavier_uniform_(self.a.data) def forward(self, g, h): h = self.W(h).view(-1, self.num_heads, self.out_dim) # Compute attention scores with g.local_scope(): g.ndata['h'] = h g.apply_edges(fn.u_dot_v('h', 'h', 'e')) e = F.leaky_relu(g.edata.pop('e'), negative_slope=0.2) g.edata['a'] = torch.cat([e, e], dim=1) g.edata['a'] = torch.matmul(g.edata['a'], self.a).squeeze() g.edata['a'] = F.leaky_relu(g.edata['a'], negative_slope=0.2) g.apply_edges(fn.e_softmax('a', 'w')) # Compute output features g.ndata['h'] = h g.update_all(fn.u_mul_e('h', 'w', 'm'), fn.sum('m', 'h')) h = g.ndata['h'] return h.view(-1, self.num_heads*self.out_dim) class SATLayer(nn.Module): def __init__(self, in_dim, out_dim, num_heads): super(SATLayer, self).__init__() self.attention = GraphAttentionLayer(in_dim, out_dim, num_heads) self.dropout = nn.Dropout(0.5) self.norm = nn.LayerNorm(out_dim*num_heads) def forward(self, g, h): h = self.attention(g, h) h = self.norm(h) h = F.relu(h) h = self.dropout(h) return h class SAT(nn.Module): def __init__(self, in_dim, hidden_dim, out_dim, num_heads): super(SAT, self).__init__() self.layer1 = SATLayer(in_dim, hidden_dim, num_heads) self.layer2 = SATLayer(hidden_dim*num_heads, out_dim, 1) def forward(self, g, h): h = self.layer1(g, h) h = self.layer2(g, h) return h.mean(0) # Load Cora dataset from dgl.data import citation_graph as citegrh data = citegrh.load_cora() g = data.graph features = torch.FloatTensor(data.features) labels = torch.LongTensor(data.labels) train_mask = torch.BoolTensor(data.train_mask) val_mask = torch.BoolTensor(data.val_mask) test_mask = torch.BoolTensor(data.test_mask) # Add self loop g = dgl.remove_self_loop(g) g = dgl.add_self_loop(g) # Define model and optimizer model = SAT(features.shape[1], 64, data.num_classes, 8) optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4) # Train model for epoch in range(200): model.train() logits = model(g, features) loss = F.cross_entropy(logits[train_mask], labels[train_mask]) optimizer.zero_grad() loss.backward() optimizer.step() acc = (logits[val_mask].argmax(1) == labels[val_mask]).float().mean() if epoch % 10 == 0: print('Epoch {:03d} | Loss {:.4f} | Accuracy {:.4f}'.format(epoch, loss.item(), acc.item())) # Test model model.eval() logits = model(g, features) acc = (logits[test_mask].argmax(1) == labels[test_mask]).float().mean() print('Test accuracy {:.4f}'.format(acc.item())) ``` 在这个示例中,我们首先加载了Cora数据集,并将其转换为一个DGL图。然后,我们定义了一个包含两个SAT层的模型,以及Adam优化器。在训练过程中,我们使用交叉熵损失函数和验证集上的准确率来监控模型的性能。在测试阶段,我们计算测试集上的准确率。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值