机器学习第四十九周周报 GT

week49 GY

摘要

本周阅读了题为Fine-tuning Graph Neural Networks by Preserving Graph Generative Patterns的论文。该文将图的预训练性能不理想归因于预训练与下游数据集之间的结构分歧。此外,研究将这种差异的原因确定为预训练和下游图之间生成模式的差异。基于理论分析,该文提出了一种基于graphon的GNN微调策略G-TUNING,以使预训练的模型适应下游数据集。最后,实证证明了G-TUNING的有效性。

Abstract

This week’s weekly newspaper decodes the paper entitled Fine-tuning Graph Neural Networks by Preserving Graph Generative Patterns. This paper attributes the unsatisfactory pre-training performance of graphs to the structural discrepancy between pre-training and downstream datasets. Furthermore, the study identifies the cause of this disparity as the difference in generative patterns between pre-trained and downstream graphs. Based on theoretical analysis, this paper proposes a graphon-based GNN fine-tuning strategy named G-TUNING to adapt the pre-trained model to downstream datasets. Finally, empirical evidence demonstrates the effectiveness of G-TUNING.

1. 题目

标题:Fine-tuning Graph Neural Networks by Preserving Graph Generative Patterns

作者:Yifei Sun, Qi Zhu, Yang Yang, Chunping Wang, Tianyu Fan, Jiajun Zhu, Lei Chen

发布:AAAI2024

链接:https://export.arxiv.org/abs/2312.13583

代码链接:https://github.com/zjunet/G-Tuning

2. Abstract

该文认为结构差异的根本原因是预训练图下游图之间生成模式的差异。此外,提出了G-TUNING来保留下游图的生成模式。给定一个下游图G,核心思想是调整预训练的GNN,以便它可以重建G的生成模式,即图元w。然而,已知图元的精确重建在计算上是昂贵的。为了克服这一挑战,提供了一个理论分析,建立了一组称为graphon bases的替代graphon的存在。通过利用这些graphon bases的线性组合,可以有效地近似w。这一理论发现构成了提出模型的基础,因为它可以有效地学习graphon bases及其相关系数。与现有算法相比,G-TUNING在域内和域外迁移学习实验上分别平均提高了0.5%和2.6%。

3. 网络结构

image-20240727101959921

3.1 graphon

graphon是“图函数”的缩写,可以解释为具有不可数节点数的图或图生成模型的概括。

表示图生成模式 P ( G ; Θ ) P(G;Θ) P(G;Θ)。形式上,图元是一个连续对称函数 W : [ 0 , 1 ] 2 → [ 0 , 1 ] W:[0,1]2→[0,1] W:[0,1]2[0,1]。给定两个点 u i , u j ∈ [ 0 , 1 ] u_i, u_j∈[0,1] ui,uj[0,1]作为“节点”, W ( i , j ) ∈ [ 0 , 1 ] W(i, j)∈[0,1] W(i,j)[0,1]​表示它们形成一条边的概率。

graphon的主要思想是,当从观察到的图中提取子图时,随着子图大小的增加,这些子图的结构与观察到的图的结构越来越相似。然后结构在某种意义上收敛到一个极限对象,graphon。收敛性通过同态密度的收敛性来定义。同态密度 t ( F , G ) t(F, G) t(F,G)用来度量图F在图G中同态出现的相对频率: t ( F , G ) = ∣ h m ( F , G ) ∣ ∣ V G ∣ ∣ V H ∣ t(F, G) = |h_m (F,G)| |VG||VH| t(F,G)=hm(F,G)∣∣VG∣∣VH,可以看作是顶点从F到G的随机映射是同态的概率。因此,收敛性可以形式化为 l i m n → ∞ t ( F , G n ) = t ( F , W ) lim_{n→∞}t(F, G_n) = t(F,W) limnt(F,Gn)=t(F,W)。当作为图的生成模式时,从 P ( G ; W ) P(G;W) P(G;W)​中抽取N个节点的图G的邻接矩阵A如下:
v ∼ U ( 0 , 1 ) , v ∈ V ; A i j ∼ Ber ( W ( v i , v j ) ) , ∀ i , j ∈ [ N ] v\sim \mathbb U(0,1), v\in V;A_{ij}\sim \text{Ber}(W(v_i,v_j)),\forall i,j\in [N] vU(0,1),vV;AijBer(W(vi,vj)),i,j[N]
现有的研究主要采用二维阶跃函数来表示graphon,该阶跃函数可以看作矩阵

根据上述工作,采用阶跃函数 W ∈ [ 0 , 1 ] D × D W∈[0,1]^{D×D} W[0,1]D×D来表示一个graphon,其中D是一个超参数。

3.2 框架概览

G-TUNING旨在通过保留生成模式来使预训练的GNN适应微调图。在微调过程中,预训练的GNN Φ获得下游图 G t = G 1 , … , G n G_t = {G_1,…, G_n} Gt=G1Gn,并将它们馈送到任务特定层 f φ f_φ fφ中,用微调标签y进行训练。对于特定图 G i ( a , X ) G_i(a,X) Gi(a,X)​,通过预训练模型Φ获得预训练节点嵌入H:
L t a s k = L C E ( f ϕ ( H ) , Y ) H = Φ ( A , X ) \mathcal L_{task}=\mathcal L_{CE}(f_{\phi}(H),Y)\quad H=\Phi(A,X) Ltask=LCE(fϕ(H),Y)H=Φ(A,X)
原有策略可能无法提高微调性能,因为预训练和微调图之间存在较大差异,即负迁移。为了缓解这一问题,建议通过重建下游图W(图2中的整体工作流程),使预训练的GNNΦ能够保留下游图Gt的生成模式。在微调开始时,嵌入H也包含来自预训练数据的偏差。故需要下游图的H和图结构A来重建图。

具体来说,设计了一个graphon重建模块Ω来重建。因此,graphon重构模块Ω通过 L a u x \mathcal L_{aux} Laux对每个下游图逼近一个估计的oracle graphon(即 W ∈ [ 0 , 1 ] D × D W∈[0,1]D×D W[0,1]D×D), D为oracle graphon的大小。最后,在G-TUNING的框架下(图2),利用下游任务损失和重构损失来优化预训练的GNN编码器Φ、fϕ层和graphon重构模块Ω的参数,如下所示:
L = L t a s k + λ L G-TUNING ( W , W ^ ) \mathcal L=\mathcal L_{task}+\lambda \mathcal L_{\text{G-TUNING}}(W,\hat W) L=Ltask+λLG-TUNING(W,W^)
近似graphon的一种直接的方法是学习一个映射函数图结构A和节点嵌入H到目标w。

首先建立了一个graphon分解定理,并利用它进行高效的graphon近似。具体来说,提出任何graphon都可以通过graphon base B k ∈ B Bk∈B BkB​的线性组合来重构。

综上所述,设计了图形重构模块Ω作为另一个GNN,将编码节点表示H和图形结构A转换为系数 α = { α 1 , … α C } α = \{α1,…α_C\} α={α1αC}:
α = Ψ ( A , H ) \alpha =\Psi (A,H) α=Ψ(A,H)
复杂性分析。现在分析G-TUNING除了普通调优之外的额外时间复杂度。设|V|和|E|为节点和边的平均数目,d为隐藏维数,C为graphon基数。G-TUNING的总时间复杂度包括两个部分:(i)石墨解码器耗时 O ( C M 2 ∣ V ∣ d ) O(CM^2 |V|d) O(CM2Vd);(ii) oracle graphon估计成本 O ( ∣ E ∣ D ∣ V ∣ D 2 ) O(|E|D |V|D^2) O(EDVD2) ,其中D为oracle graphon的大小。因此,总体的额外时间复杂度为 O ( ∣ E ∣ D ∣ V ∣ D 2 C M 2 ∣ V ∣ D ) O(|E|D |V|D^2 CM^2 |V|D) O(EDVD2CM2VD),假设M,D≪|V|,这与普通调谐过程的 O ( ∣ E ∣ D ∣ V ∣ D ) O(|E|D |V|D) O(EDVD)是相同的数量级。

4. 文献解读

4.1 Introduction

在本文中,目标是通过提出一种微调策略G-TUNING来解决这些挑战,该策略与预训练数据和算法无关。具体来说,它在调优期间执行下游图的图重建。为了实现有效的重建,提供了一个理论结果(定理1),即给定一个图元W,有可能找到一组其他的图元,称为图元基,其线性组合可以接近W。然后,开发了一个图元解码器,将嵌入从预训练模型转换为一组系数。这些系数与结构感知的可学习基相结合,形成重构图。为了确保重建图形的保真度,引入了基于GW差异的损耗,从而最小化了近似图形与oracle graphon之间的距离(Xu et al. 2021)。此外,通过优化提出的G-TUNING,我们获得了与任务相关的判别子图的可证明结果(定理2)。

4.2 创新点

该文的主要贡献有四个方面:

  1. 确定下游图的生成模式是弥合预训练和微调之间差距的关键步骤。
  2. 基于理论结果,设计了模型架构G-TUNING,以有效地将graphon重构为具有严格泛化结果的生成模式。
  3. 从经验上看,该方法在8个域内和7个域外迁移学习数据集上比最佳基线平均提高了0.47%和2.62%。

4.3 实验过程

目标是在两种设置下实际评估G-TUNING在15个数据集上的性能。

具体来说,回答以下问题:

  • (有效性)G-TUNING是否提高了微调的性能
  • (可转移性)G-TUNING能比基线更好地实现可转移性吗
  • (完整性)G-TUNING的每个组成部分对性能的贡献是什么
  • (效率)G-TUNING能否在可接受的时间消耗下提高微调的性能

基线。有大量的GNN预训练方法,但只有少数的微调策略可用。

  • 几个最初为cnn设计的代表性基线,包括StochNorm, DELTA和固定注意力系数的版本(Feature-Map), L2_SP和BSS。
  • 一个基线致力于改善gnn的微调,这与预训练策略无关,即gtt -tuning
  • 为了验证重建图的有效性,引入了VGAE- tuning进行比较,该方法使用VGAE 作为辅助损失来重建下游图的邻接矩阵。

根据作者发布的代码重现基线,并根据他们发布的代码和他们论文中描述的设置设置超参数。

4.3.1 有效性

评估了G-TUNING在分子性质预测任务上的有效性。使用无监督上下文预测任务预训练的模型作为骨干模型。具体来说,在带有200万个未标记分子的ZINC15数据集(Sterling and Irwin 2015)上通过自监督上下文预测任务预训练GIN (Xu et al. 2019)。接下来,对从MoleculeNet (Wu et al. 2018)获得的8个二元分类数据集进行骨架模型的微调。用的是8:1:1比例的支架。由于框架对骨干gnn是不可知的,专注于评估我们的模型是否达到更好的微调结果。

对于每个数据集,运行5次,并报告具有相应标准差的平均ROCAUC。

image-20240727110155957

表1显示,G-TUNING在8个基线数据集中获得了6个最佳性能,平均排名最高。注意到,从预训练模型(如Feature-Map或DELTA)中约束嵌入有时会带来比普通调优更差的性能。从监督学习和G-TUNING w/o Pre之间的比较来看,尽管在将G-TUNING损失应用于监督学习时可能偶尔会出现性能轻微下降的情况,但大多数监督训练经验都受益于G-TUNING损失。从原型调优和表的最后两行比较可以看出,未经预训练的G-TUNING的性能低于经过预训练的G-TUNING,但有时优于原型调优。结果一般证明,在数据集来自相同域的情况下,G-TUNING可以通过保留生成模式来补偿结构差异,从而获得更好的性能。

4.3.2 可转移性

在跨域设置中评估G-TUNING,其中预训练数据集和下游数据集不是来自同一域。较大的结构差异反过来会降低性能。故采用GCC (Qiu et al. 2020)作为主干模型,其子图判别作为预训练任务。根据GCC的设置,从7个不同数据集上进行预训练,并在7个下游图分类基准上评估方法:IMDB-M, IMDB-B, MUTAG, PROTEINS, ENZYMES, MSRC_21和来自tudatasset的RDT-M12K (Morris et al. 2020)。这些数据集涵盖了广泛的领域。报告了10倍交叉验证的结果。

image-20240727110451229

从表2中,发现模型在7个数据集中的6个上优于所有基线,并且在MUTAG上呈现出具有竞争力的结果(比最好的低1.92%)。与原型调优和第二好的基线相比,G-TUNING分别提高了7.63%和4.71%的蛋白质性能。与之前的实验相比,可以观察到G-TUNING有了更实质性的改进(表1),因为我们明确地保留了生成模式。尽管GTOT也包含结构信息,但它有时甚至比普通调优(即蛋白质和酶)的性能更差。一般来说,当预训练和微调图显示出较大的结构差异时,G-TUNING清楚地表明了它的有效性。

4.3.3 消融研究

image-20240727111050629

  • 首先,通过与直接重建图元(Direct-Rec)进行比较,来检验所提出的图元分解方法的有效性。结果表明,原型调优之外的改进是有限的,在某些情况下甚至存在负迁移。原因可能是直接重建复杂的语义信息和从A和h中捕获graphon属性的困难。在四个数据集中,观察到G-TUNING总是优于“Direct-Rec”。
  • 接下来,比较了不同的GNN架构(两层MLP, GCN(Welling and Kipf 2016), GraphSAGE(Hamilton, Ying, and Leskovec 2017)和GAT (Veli ckovi等人2018))与默认主干(即GIN (Xu等人2019))。在图3中,观察到MLPencoder表现最差,这证明了结合结构信息重构graphon的有效性。
  • 最后,用KL散度、Wasserstein距离和余弦相似度代替损失。可以观察到,GW差异损失显著优于其他。认为余弦相似度可能对指边的概率的绝对值不敏感。由于KL散度不满足交换律,重构图时难以收敛。虽然Wasserstein距离也是基于最优输运,但它无法捕捉到两个graphon之间的几何形状。

image-20240727111120077

超参研究:还研究了2 ~ 512的不同graphon bases的影响。更多的基数可以表示更多的信息,并且可以更好地近似oracle graphon。图4显示,当基数从2个增加到32个时,性能有所提高。然而,当数量继续增加时,改善变得越来越小。将这种现象归因于参数数量增加带来的优化难度。此外,随着碱基数量的增加,G-TUNING的运行时间呈指数增长(绿色曲线)。因此,G-TUNING只需要少量的基数就可以提高微调性能。

4.3.4 运行时间
  1. 进行了运行时间对比(表3)。表3中给出了该文方法与基线的运行时间对比。时间复杂度主要由两部分组成:(i)预训练模型的graphon逼近和(ii) oracle graphon估计。
  2. 现在报告了在域内设置下数据集上我们的方法和基线的计算效率,而不会失去一般性(每个训练epoch的秒数)。

image-20240727111435580

从表3中,可以看到G-TUNING在大多数情况下并不是最慢的调优方法。可以观察到,该方法的时间接近于vgae-tuning。正如第一行图的数量所示,很明显,随着图数量的增加,该方法的时间消耗变得与其他基线更具可比性。这意味着该方法具有出色的可扩展性。因此,该方法的时间消耗保持在一个可接受的范围内。

5. 结论

在该文中,将图的预训练性能不理想归因于预训练与下游数据集之间的结构分歧。此外,将这种差异的原因确定为预训练和下游图之间生成模式的差异。基于理论分析,提出了一种基于graphon的GNN微调策略G-TUNING,以使预训练的模型适应下游数据集。最后,实证证明了G-TUNING的有效性。

6.代码复现

https://github.com/zjunet/G-Tuning

Mole-based model

import random

from ot_distance import sliced_fgw_distance, fgw_distance
import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree, softmax, subgraph
import torch_geometric.utils as PyG_utils
from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool, GlobalAttention, Set2Set
import torch.nn.functional as F
import torch.nn as nn
# from torch_geometric.nn.conv import GATConv
from torch_scatter import scatter_add
from torch_geometric.nn.inits import glorot, zeros

from torch_geometric.datasets import TUDataset
from abc import ABC

num_atom_type = 120  # including the extra mask tokens
num_chirality_tag = 3

num_bond_type = 6  # including aromatic and self-loop edge, and extra masked tokens
num_bond_direction = 3


class GINConv(MessagePassing):
    """
    Extension of GIN aggregation to incorporate edge information by concatenation.

    Args:
        emb_dim (int): dimensionality of embeddings for nodes and edges.
        embed_input (bool): whether to embed input or not.


    See https://arxiv.org/abs/1810.00826
    """

    def __init__(self, emb_dim, aggr="add"):
        super(GINConv, self).__init__()
        # multi-layer perceptron
        self.mlp = torch.nn.Sequential(torch.nn.Linear(emb_dim, 2 * emb_dim), torch.nn.ReLU(),
                                       torch.nn.Linear(2 * emb_dim, emb_dim))
        self.edge_embedding1 = torch.nn.Embedding(num_bond_type, emb_dim)
        self.edge_embedding2 = torch.nn.Embedding(num_bond_direction, emb_dim)

        torch.nn.init.xavier_uniform_(self.edge_embedding1.weight.data)
        torch.nn.init.xavier_uniform_(self.edge_embedding2.weight.data)
        self.aggr = aggr

    def forward(self, x, edge_index, edge_attr):
        # add self loops in the edge space
        # edge_index = add_self_loops(edge_index, num_nodes=x.size(0))

        # add features corresponding to self-loop edges.
        # self_loop_attr = torch.zeros(edge_index[0].size(0), 2)
        # self_loop_attr[:, 0] = 4  # bond type for self-loop edge
        # self_loop_attr = self_loop_attr.to(edge_attr.device).to(edge_attr.dtype)
        # edge_attr = torch.cat((edge_attr, self_loop_attr), dim=0)

        edge_embeddings = self.edge_embedding1(edge_attr[:, 0]) + self.edge_embedding2(edge_attr[:, 1])
        try:  # PyG 1.6.
            return self.propagate(edge_index, x=x, edge_attr=edge_embeddings)
        except:  # PyG 1.0.3
            return self.propagate(self.aggr, edge_index, x=x, edge_attr=edge_embeddings)

    def message(self, x_j, edge_attr):
        return x_j + edge_attr

    def update(self, aggr_out):
        return self.mlp(aggr_out)


class GCNConv(MessagePassing):

    def __init__(self, emb_dim, aggr="add"):
        super(GCNConv, self).__init__()

        self.emb_dim = emb_dim
        self.linear = torch.nn.Linear(emb_dim, emb_dim)
        self.edge_embedding1 = torch.nn.Embedding(num_bond_type, emb_dim)
        self.edge_embedding2 = torch.nn.Embedding(num_bond_direction, emb_dim)

        torch.nn.init.xavier_uniform_(self.edge_embedding1.weight.data)
        torch.nn.init.xavier_uniform_(self.edge_embedding2.weight.data)

        self.aggr = aggr

    def norm(self, edge_index, num_nodes, dtype):
        ### assuming that self-loops have been already added in edge_index
        edge_weight = torch.ones((edge_index.size(1),), dtype=dtype,
                                 device=edge_index.device)
        row, col = edge_index
        deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0

        return deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]

    def forward(self, x, edge_index, edge_attr):
        # add self loops in the edge space
        # edge_index, edge_weight = add_self_loops(edge_index, num_nodes=x.size(0)) pyg 1.6
        edge_index = add_self_loops(edge_index, num_nodes=x.size(0))

        # add features corresponding to self-loop edges.
        self_loop_attr = torch.zeros(x.size(0), 2)
        self_loop_attr[:, 0] = 4  # bond type for self-loop edge
        self_loop_attr = self_loop_attr.to(edge_attr.device).to(edge_attr.dtype)
        edge_attr = torch.cat((edge_attr, self_loop_attr), dim=0)

        edge_embeddings = self.edge_embedding1(edge_attr[:, 0]) + self.edge_embedding2(edge_attr[:, 1])
        try:
            norm = self.norm(edge_index, x.size(0), x.dtype)
        except:
            norm = self.norm(edge_index[0], x.size(0), x.dtype)

        x = self.linear(x)
        try:  # PyG 1.6.
            return self.propagate(edge_index[0], x=x, edge_attr=edge_embeddings, norm=norm)
        except:  # PyG 1.0.3
            return self.propagate(self.aggr, edge_index, x=x, edge_attr=edge_embeddings, norm=norm)
        # return self.propagate(edge_index[0], x=x, edge_attr=edge_embeddings, norm=norm)
        # return self.propagate(edge_index[0], x=x, edge_attr=edge_embeddings, norm = norm)

    def message(self, x_j, edge_attr, norm):
        return norm.view(-1, 1) * (x_j + edge_attr)


class GATConv(MessagePassing):
    def __init__(self, emb_dim, heads=2, negative_slope=0.2, aggr="add"):
        super(GATConv, self).__init__()

        self.aggr = aggr

        self.emb_dim = emb_dim
        self.heads = heads
        self.negative_slope = negative_slope

        self.weight_linear = torch.nn.Linear(emb_dim, heads * emb_dim)
        self.att = torch.nn.Parameter(torch.Tensor(1, heads, 2 * emb_dim))

        self.bias = torch.nn.Parameter(torch.Tensor(emb_dim))

        self.edge_embedding1 = torch.nn.Embedding(num_bond_type, heads * emb_dim)
        self.edge_embedding2 = torch.nn.Embedding(num_bond_direction, heads * emb_dim)

        torch.nn.init.xavier_uniform_(self.edge_embedding1.weight.data)
        torch.nn.init.xavier_uniform_(self.edge_embedding2.weight.data)

        self.reset_parameters()

    def reset_parameters(self):
        glorot(self.att)
        zeros(self.bias)

    def forward(self, x, edge_index, edge_attr):
        # add self loops in the edge space
        edge_ind = add_self_loops(edge_index, num_nodes=x.size(0))
        # add features corresponding to self-loop edges.
        self_loop_attr = torch.zeros(x.size(0), 2)
        self_loop_attr[:, 0] = 4  # bond type for self-loop edge
        self_loop_attr = self_loop_attr.to(edge_attr.device).to(edge_attr.dtype)
        edge_attr = torch.cat((edge_attr, self_loop_attr), dim=0)

        edge_embeddings = self.edge_embedding1(edge_attr[:, 0]) + self.edge_embedding2(edge_attr[:, 1])

        x = self.weight_linear(x).view(-1, self.heads, self.emb_dim)
        # edge_ind = edge_ind[0]
        print("edge_index", edge_ind)
        return self.propagate(self.aggr, edge_index=edge_ind, x=x, edge_attr=edge_embeddings)

    def message(self, edge_index, x_i, x_j, edge_attr):
        edge_attr = edge_attr.view(-1, self.heads, self.emb_dim)
        x_j += edge_attr

        alpha = (torch.cat([x_i, x_j], dim=-1) * self.att).sum(dim=-1)

        alpha = F.leaky_relu(alpha, self.negative_slope)
        alpha = softmax(alpha, edge_index[0])

        return x_j * alpha.view(-1, self.heads, 1)

    def update(self, aggr_out):
        aggr_out = aggr_out.mean(dim=1)
        aggr_out = aggr_out + self.bias

        return aggr_out


class GraphSAGEConv(MessagePassing):
    def __init__(self, emb_dim, aggr="mean"):
        super(GraphSAGEConv, self).__init__()

        self.emb_dim = emb_dim
        self.linear = torch.nn.Linear(emb_dim, emb_dim)
        self.edge_embedding1 = torch.nn.Embedding(num_bond_type, emb_dim)
        self.edge_embedding2 = torch.nn.Embedding(num_bond_direction, emb_dim)

        torch.nn.init.xavier_uniform_(self.edge_embedding1.weight.data)
        torch.nn.init.xavier_uniform_(self.edge_embedding2.weight.data)

        self.aggr = aggr

    def forward(self, x, edge_index, edge_attr):
        # add self loops in the edge space
        edge_index = add_self_loops(edge_index, num_nodes=x.size(0))

        # add features corresponding to self-loop edges.
        self_loop_attr = torch.zeros(x.size(0), 2)
        self_loop_attr[:, 0] = 4  # bond type for self-loop edge
        self_loop_attr = self_loop_attr.to(edge_attr.device).to(edge_attr.dtype)
        edge_attr = torch.cat((edge_attr, self_loop_attr), dim=0)

        edge_embeddings = self.edge_embedding1(edge_attr[:, 0]) + self.edge_embedding2(edge_attr[:, 1])

        x = self.linear(x)

        try:  # PyG 1.6.
            return self.propagate(edge_index[0], x=x, edge_attr=edge_embeddings)
        except:  # PyG 1.0.3
            return self.propagate(self.aggr, edge_index, x=x, edge_attr=edge_embeddings)

    def message(self, x_j, edge_attr):
        return x_j + edge_attr

    def update(self, aggr_out):
        return F.normalize(aggr_out, p=2, dim=-1)


class GNN(torch.nn.Module):
    """


    Args:
        num_layer (int): the number of GNN layers
        emb_dim (int): dimensionality of embeddings
        JK (str): last, concat, max or sum.
        max_pool_layer (int): the layer from which we use max pool rather than add pool for neighbor aggregation
        drop_ratio (float): dropout rate
        gnn_type: gin, gcn, graphsage, gat

    Output:
        node representations

    """

    def __init__(self, num_layer, emb_dim, JK="last", drop_ratio=0, gnn_type="gin"):
        super(GNN, self).__init__()
        self.num_layer = num_layer
        self.drop_ratio = drop_ratio
        self.gnn_type = gnn_type
        self.JK = JK

        if self.num_layer < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        self.x_embedding1 = torch.nn.Embedding(num_atom_type, emb_dim)
        self.x_embedding2 = torch.nn.Embedding(num_chirality_tag, emb_dim)

        torch.nn.init.xavier_uniform_(self.x_embedding1.weight.data)
        torch.nn.init.xavier_uniform_(self.x_embedding2.weight.data)

        ###List of MLPs
        self.gnns = torch.nn.ModuleList()
        for layer in range(num_layer):
            if gnn_type == "gin":
                self.gnns.append(GINConv(emb_dim, aggr="add"))
            elif gnn_type == "gcn":
                self.gnns.append(GCNConv(emb_dim))
            elif gnn_type == "gat":
                self.gnns.append(GATConv(emb_dim))
            elif gnn_type == "graphsage":
                self.gnns.append(GraphSAGEConv(emb_dim))

        ###List of batchnorms
        self.batch_norms = torch.nn.ModuleList()
        for layer in range(num_layer):
            self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim))

    # def forward(self, x, edge_index, edge_attr):
    def forward(self, *argv):
        batch = None
        if len(argv) == 3:
            x, edge_index, edge_attr = argv[0], argv[1], argv[2]
        elif len(argv) == 1:
            data = argv[0]
            x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
        elif len(argv) == 4:
            x, edge_index, edge_attr, batch = argv[0], argv[1], argv[2], argv[3]
        else:
            raise ValueError("unmatched number of arguments.")

        x = self.x_embedding1(x[:, 0]) + self.x_embedding2(x[:, 1])

        h_list = [x]
        for layer in range(self.num_layer):

            h = self.gnns[layer](h_list[layer], edge_index, edge_attr)
            h = self.batch_norms[layer](h)
            # h = F.dropout(F.relu(h), self.drop_ratio, training = self.training)
            if layer == self.num_layer - 1:
                # remove relu for the last layer
                h = F.dropout(h, self.drop_ratio, training=self.training)
            else:
                h = F.dropout(F.relu(h), self.drop_ratio, training=self.training)
            h_list.append(h)

        ### Different implementations of Jk-concat
        if self.JK == "concat":
            node_representation = torch.cat(h_list, dim=1)
        elif self.JK == "last":
            node_representation = h_list[-1]
        elif self.JK == "max":
            h_list = [h.unsqueeze_(0) for h in h_list]
            node_representation = torch.max(torch.cat(h_list, dim=0), dim=0)[0]
        elif self.JK == "sum":
            h_list = [h.unsqueeze_(0) for h in h_list]
            node_representation = torch.sum(torch.cat(h_list, dim=0), dim=0)[0]

        return node_representation


class GNN_graphpred(torch.nn.Module):
    """
    Extension of GIN to incorporate edge information by concatenation.

    Args:
        num_layer (int): the number of GNN layers
        emb_dim (int): dimensionality of embeddings
        num_tasks (int): number of tasks in multi-task learning scenario
        drop_ratio (float): dropout rate
        JK (str): last, concat, max or sum.
        graph_pooling (str): sum, mean, max, attention, set2set
        gnn_type: gin, gcn, graphsage, gat

    See https://arxiv.org/abs/1810.00826
    JK-net: https://arxiv.org/abs/1806.03536
    """

    def __init__(self, num_layer, emb_dim, num_tasks, JK="last", drop_ratio=0, graph_pooling="mean", gnn_type="gin",
                 backbone=None, args=None):
        '''
        backbone is gnn default
        '''
        super(GNN_graphpred, self).__init__()
        self.num_layer = num_layer
        self.drop_ratio = drop_ratio
        self.JK = JK
        self.emb_dim = emb_dim
        self.num_tasks = num_tasks
        self.emb_f = None
        self.gnn_type = gnn_type
        self.param_args = args

        if self.num_layer < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")
        if backbone is None:
            self.gnn = GNN(num_layer, emb_dim, JK, drop_ratio, gnn_type=gnn_type)
        else:
            self.gnn = backbone
        # self.backbone = self.gnn

        # Different kind of graph pooling
        if graph_pooling == "sum":
            self.pool = global_add_pool
        elif graph_pooling == "mean":
            self.pool = global_mean_pool
        elif graph_pooling == "max":
            self.pool = global_max_pool
        elif graph_pooling == "attention":
            if self.JK == "concat":
                self.pool = GlobalAttention(gate_nn=torch.nn.Linear((self.num_layer + 1) * emb_dim, 1))
            else:
                self.pool = GlobalAttention(gate_nn=torch.nn.Linear(emb_dim, 1))
        elif graph_pooling[:-1] == "set2set":
            set2set_iter = int(graph_pooling[-1])
            if self.JK == "concat":
                self.pool = Set2Set((self.num_layer + 1) * emb_dim, set2set_iter)
            else:
                self.pool = Set2Set(emb_dim, set2set_iter)
        else:
            raise ValueError("Invalid graph pooling type.")

        # For graph-level binary classification
        if graph_pooling[:-1] == "set2set":
            self.mult = 2
        else:
            self.mult = 1

        if self.JK == "concat":
            self.graph_pred_linear = torch.nn.Linear(self.mult * (self.num_layer + 1) * self.emb_dim, self.num_tasks)
        else:
            self.graph_pred_linear = torch.nn.Linear(self.mult * self.emb_dim, self.num_tasks)

    def from_pretrained(self, model_file):
        # self.gnn = GNN(self.num_layer, self.emb_dim, JK = self.JK, drop_ratio = self.drop_ratio)
        model = torch.load(model_file, map_location='cpu')
        self.gnn.load_state_dict(model)  # self.args.device))

    def forward(self, *argv):
        if len(argv) == 4:
            x, edge_index, edge_attr, batch = argv[0], argv[1], argv[2], argv[3]
        elif len(argv) == 1:
            data = argv[0]
            x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch
        else:
            raise ValueError("unmatched number of arguments.")

        node_representation = self.gnn(x, edge_index, edge_attr)
        self.emb_f = self.pool(node_representation, batch)
        return self.graph_pred_linear(self.emb_f)

    def get_bottleneck(self):
        return self.emb_f


class GraphonEncoder(torch.nn.Module):
    def __init__(self, feature_length, hidden_size, out_size):
        super(GraphonEncoder, self).__init__()
        self.feature_length, self.hidden_size, self.out_size = feature_length, hidden_size, out_size
        self.fc1 = torch.nn.Linear(feature_length, hidden_size)
        self.fc2 = torch.nn.Linear(hidden_size, out_size)

    def forward(self, x):
        x = x.view(-1, self.feature_length)
        # print(x, x.shape)
        # print(edge_index, edge_index.shape)
        x = F.dropout(x, p=0.9, training=self.training)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, p=0.9, training=self.training)
        x = self.fc2(x)
        return x


def sampling_gaussian(mu, logvar, num_sample):
    std = torch.exp(0.5 * logvar)
    samples = None
    for i in range(num_sample):
        eps = torch.randn_like(std)
        if i == 0:
            samples = mu + eps * std
        else:
            samples = torch.cat((samples, mu + eps * std), dim=0)
    return samples


def sampling_gmm(mu, logvar, num_sample):
    std = torch.exp(0.5 * logvar)
    n = int(num_sample / mu.size(0)) + 1
    samples = None
    for i in range(n):
        eps = torch.randn_like(std)
        if i == 0:
            samples = mu + eps * std
        else:
            samples = torch.cat((samples, mu + eps * std), dim=0)
    return samples[:num_sample, :]


class Prior(nn.Module, ABC):
    def __init__(self, data_size: list, prior_type: str = 'gmm'):
        super(Prior, self).__init__()
        # data_size = [num_component, z_dim]
        self.data_size = data_size
        self.number_components = data_size[0]
        self.output_size = data_size[1]
        self.prior_type = prior_type
        if self.prior_type == 'gmm':
            self.mu = nn.Parameter(torch.randn(data_size), requires_grad=True)
            self.logvar = nn.Parameter(torch.randn(data_size), requires_grad=True)
        else:
            self.mu = nn.Parameter(torch.zeros(1, self.output_size), requires_grad=False)
            self.logvar = nn.Parameter(torch.ones(1, self.output_size), requires_grad=False)

    def forward(self):
        return self.mu, self.logvar

    def sampling(self, num_sample):
        if self.prior_type == 'gmm':
            return sampling_gmm(self.mu, self.logvar, num_sample)
        else:
            return sampling_gaussian(self.mu, self.logvar, num_sample)


class GraphonNewEncoder(torch.nn.Module):
    def __init__(self, feature_length, hidden_size, out_size, encoder_type):
        super(GraphonNewEncoder, self).__init__()
        self.feature_length, self.hidden_size, self.out_size = feature_length, hidden_size, out_size
        self.encoder_type = encoder_type
        self.fc1 = torch.nn.Linear(feature_length, hidden_size)
        self.fc2 = torch.nn.Linear(hidden_size, out_size)
        # self.gnn = GNN(2, feature_length, "last", 0.2, gnn_type='gin')
        self.gnns_en = torch.nn.ModuleList()

        for layer in range(2):
            if encoder_type == "gin":
                self.gnns_en.append(GINConv(feature_length, aggr="add"))
            elif encoder_type == "gcn":
                self.gnns_en.append(GCNConv(feature_length))
            elif encoder_type == "gat":
                self.gnns_en.append(GATConv(feature_length))
            elif encoder_type == "graphsage":
                self.gnns_en.append(GraphSAGEConv(feature_length))

        self.fc3 = torch.nn.Linear(feature_length, out_size)
        self.batch_norms = torch.nn.ModuleList()
        for layer in range(2):
            self.batch_norms.append(torch.nn.BatchNorm1d(feature_length))

    def forward(self, x, edge_index, edge_attr, batch):
        # node_representation = self.gnn(x, edge_index.long(), None)
        # if self.encoder_type == 'gat':
        #     print("before", edge_index.shape, edge_index.dtype, edge_index)
        #     edge_index = edge_index.type(torch.LongTensor)
        #     print("after", edge_index.shape, edge_index.dtype, edge_index)
        # print("self.encoder_type", self.encoder_type)
        if self.encoder_type != 'mlp':
            h_list = [x, ]
            for layer in range(2):
                inp = h_list[layer]
                h = self.gnns_en[layer](inp, edge_index, edge_attr)
                h = self.batch_norms[layer](h)
                # h = F.dropout(F.relu(h), self.drop_ratio, training = self.training)
                if layer == 1:
                    # remove relu for the last layer
                    h = F.dropout(h, 0.2, training=self.training)
                else:
                    h = F.dropout(F.relu(h), 0.2, training=self.training)
                h_list.append(h)

            x = h_list[-1] \
                + h_list[0]

            x = global_mean_pool(x, batch)
            x = self.fc3(x)
        else:
            x = x.view(-1, self.feature_length)
            # print(x, x.shape)
            # print(edge_index, edge_index.shape)
            x = F.dropout(x, p=0.9, training=self.training)
            x = F.relu(self.fc1(x))
            x = F.dropout(x, p=0.9, training=self.training)

            x = global_mean_pool(x, batch)
            x = self.fc2(x)
        return x


class GraphonFactorization(torch.nn.Module, ABC):
    def __init__(self, num_factors: int, graphs: TUDataset, seed: int, param_args, node_type: str = 'categorical'):
        """
        A basic graphon model based on Fourier transformation
        Args:
            num_factors: the number of sin/cos bases for one graphon
            graphs: the graphs used as the prior of the model
            seed: random seed
            node_type: 'binary', 'categorical' and 'continuous'
        """
        super(GraphonFactorization, self).__init__()
        self.num_factors = num_factors
        self.node_type = node_type
        self.factors_graphon = nn.ParameterList()
        self.factors_signal = nn.ParameterList()
        self.num_partitions = []
        indices = list(range(len(graphs)))
        random.seed(seed)
        random.shuffle(indices)
        # indices = np.random.RandomState(seed).permutation(len(graphs))
        for c in range(self.num_factors):
            sample = graphs[indices[c]]
            adj = torch.sparse_coo_tensor(sample.edge_index,
                                          torch.ones(sample.edge_index.shape[1]),
                                          size=[sample.x.shape[0], sample.x.shape[0]])
            adj = adj.to_dense()
            # print(adj.shape)
            if len(adj.shape) > 2:
                adj = adj.sum(2)
            # attribute = sample.x
            degrees = torch.sum(adj, dim=1)
            idx = torch.argsort(degrees)
            # print(idx.shape)
            adj = adj[idx, :][:, idx]
            # attribute = attribute[idx, :]
            num_partitions = adj.shape[0]
            graphon = nn.Parameter(data=(adj - 0.5), requires_grad=True)
            # if self.node_type == "binary" or "categorical":
            #     signal = nn.Parameter(data=(attribute - 0.5), requires_grad=True)
            # else:
            #     signal = nn.Parameter(data=attribute, requires_grad=True)
            self.num_partitions.append(num_partitions)
            self.factors_graphon.append(graphon)
            # self.factors_signal.append(signal)
        # self.dim = self.factors_signal[0].shape[1]
        self.sigmoid = nn.Sigmoid()
        # self.relu = nn.ReLU()
        self.softmax0 = nn.Softmax(dim=0)
        self.softmax1 = nn.Softmax(dim=1)
        self.softmax2 = nn.Softmax(dim=2)

        # 这里三层softmax是?

        self.batch_size = param_args['batch_size']
        self.fc = torch.nn.Linear(param_args['batch_size'], 1)
        # x = (torch.arange(0, 100) + 0.5).view(1, -1) / 100
        # self.register_buffer('positions', x)
        self.num_components =param_args['n_components']
        self.prior_type = param_args['prior_type']
        self.prior = Prior(data_size=[self.num_components, self.num_factors],
                           prior_type=self.prior_type)

    def sampling_z(self, num_samples):
        return self.prior.sampling(num_samples)

    def sampling(self, vs: torch.Tensor):
        """
        Sampling graphon factors
        Args:
            vs: (n_nodes)
        Returns:
            graphons: (n_factors, n_nodes, n_nodes)
            signals: (n_factors, n_nodes, n_nodes)
        """
        n_nodes = vs.shape[0]
        graphons = torch.zeros(self.num_factors, n_nodes, n_nodes).to(vs.device)
        # signals = torch.zeros(1, self.num_factors, n_nodes, self.dim).to(vs.device)
        for c in range(self.num_factors):
            idx = torch.floor(self.num_partitions[c] * vs).long()
            graphons[c, :, :] = self.factors_graphon[c][idx, :][:, idx]
            # signals[0, c, :, :] = self.factors_signal[c][idx, :]
        graphons = self.sigmoid(graphons)
        # return graphons, signals
        return graphons

    def forward(self, zs: torch.Tensor, vs: torch.Tensor):
        """
        Given a graphon model, sample a batch of graphs from it
        Args:
            zs: (batch_size, n_factors) latent representations
            vs: (n_nodes) random variables ~ Uniform([0, 1])

        Returns:
            graphon: (batch_size, n_nodes, n_nodes)
            signal: (batch_size, n_nodes, dim)
            graph: (batch_size, n_nodes, n_nodes) adjacency matrix
            attribute: (batch_size, n_nodes, dim) node attributes
        """
        tzs = zs.t()
        tzs_pad = tzs
        if self.batch_size - tzs.shape[1] != 0:
            pad = torch.zeros(tzs.shape[0],
                              self.batch_size - tzs.shape[1],
                              device=tzs.device)
            tzs_pad = torch.cat((tzs, pad), dim=1)
        zs_hat_one = self.fc(tzs_pad)
        zs_hat = self.softmax1(zs_hat_one)
        # graphons, signals = self.sampling(vs)  # basis
        # print('zs_hat', zs_hat.shape)

        graphons_basis = self.sampling(vs)  # basis
        # graphons_basis = torch.sigmoid(graphons)  #TODO: here change to [0,1]
        assert (graphons_basis.max().item() <= 1
                and graphons_basis.min().item() >= 0)
        # print('graphons', graphons.shape)
        graphon_est = (zs_hat.view(self.num_factors, 1, 1) * graphons_basis).sum(
            0)  # ( n_nodes, n_nodes)
        # signal = (zs_hat.view(-1, self.num_factors, 1, 1) * signals).sum(1)  # (batch, n_nodes, dim)
        # if self.node_type == 'binary':
        #     signal = self.sigmoid(signal)
        # if self.node_type == 'categorical':
        #     signal = self.softmax2(signal)
        # graphs = torch.bernoulli(graphon)  # TODO: ???!?!
        # graphs += graphs.clone().permute(0, 2, 1)  # Change: add .clone()
        # graphs[graphs > 1] = 1
        # if self.node_type == "binary":
        #     attributes = torch.bernoulli(signal)
        # elif self.node_type == "categorical":
        #     distribution = torch.distributions.one_hot_categorical.OneHotCategorical(signal)
        #     attributes = distribution.sample()
        # else:
        #     distribution = torch.distributions.normal.Normal(signal, scale=2)
        #     attributes = distribution.sample()
        # return graphon, signal, graphs, attributes
        return graphon_est


class GraphonNewFactorization(torch.nn.Module, ABC):
    def __init__(self, num_factors: int, graphs_pre, graphs_down: TUDataset, seed: int, args,
                 node_type: str = 'categorical'):
        """
        A basic graphon model based on Fourier transformation
        Args:
            num_factors: the number of sin/cos bases for one graphon
            graphs: the graphs used as the prior of the model
            seed: random seed
            node_type: 'binary', 'categorical' and 'continuous'
        """
        super(GraphonNewFactorization, self).__init__()
        self.num_factors = num_factors
        self.node_type = node_type
        self.factors_graphon = nn.ParameterList()
        self.factors_signal = nn.ParameterList()
        self.num_partitions = []

        num_mul_nodes = args.nnodes * args.ngraphs
        indices_pre = list(range(len(graphs_pre)))
        random.seed(seed)
        random.shuffle(indices_pre)
        indices_pre = indices_pre[:num_mul_nodes]
        num_pre_factors = int(self.num_factors / 2)
        print('Dealing with pretrain basis')
        for c in range(num_pre_factors):
            sample = graphs_pre[indices_pre[c]]
            node_ids = list(range(sample.x.shape[0]))
            random.shuffle(node_ids)
            node_sample = node_ids[:num_mul_nodes]
            edge, _ = subgraph(node_sample, sample.edge_index, relabel_nodes=True)
            adj = torch.sparse_coo_tensor(edge,
                                          torch.ones(edge.shape[1]),
                                          size=[num_mul_nodes, num_mul_nodes])
            adj = adj.to_dense()
            # print(adj.shape)
            if len(adj.shape) > 2:
                adj = adj.sum(2)
            # attribute = sample.x
            degrees = torch.sum(adj, dim=1)
            idx = torch.argsort(degrees)
            # print(idx.shape)
            adj = adj[idx, :][:, idx]
            # attribute = attribute[idx, :]
            num_partitions = adj.shape[0]
            graphon = nn.Parameter(data=(adj - 0.5), requires_grad=True)
            self.num_partitions.append(num_partitions)
            self.factors_graphon.append(graphon)

        indices_down = list(range(len(graphs_down)))
        random.seed(seed)
        random.shuffle(indices_down)
        # indices = np.random.RandomState(seed).permutation(len(graphs))
        print('Dealing with downstream basis')
        for c in range(self.num_factors - num_pre_factors):
            sample = graphs_down[indices_down[c]]
            # adj = torch.sparse_coo_tensor(sample.edge_index,
            #                               torch.ones(sample.edge_index.shape[1]),
            #                               size=[sample.x.shape[0], sample.x.shape[0]])
            node_ids = list(range(sample.x.shape[0]))
            random.shuffle(node_ids)
            node_sample = node_ids[:num_mul_nodes]
            edge, _ = subgraph(node_sample, sample.edge_index, relabel_nodes=True)
            adj = torch.sparse_coo_tensor(edge,
                                          torch.ones(edge.shape[1]),
                                          size=[num_mul_nodes, num_mul_nodes])
            adj = adj.to_dense()
            # print(adj.shape)
            if len(adj.shape) > 2:
                adj = adj.sum(2)
            # attribute = sample.x
            degrees = torch.sum(adj, dim=1)
            idx = torch.argsort(degrees)
            # print(idx.shape)
            adj = adj[idx, :][:, idx]
            # attribute = attribute[idx, :]
            num_partitions = adj.shape[0]
            graphon = nn.Parameter(data=(adj - 0.5), requires_grad=True)
            # if self.node_type == "binary" or "categorical":
            #     signal = nn.Parameter(data=(attribute - 0.5), requires_grad=True)
            # else:
            #     signal = nn.Parameter(data=attribute, requires_grad=True)
            self.num_partitions.append(num_partitions)
            self.factors_graphon.append(graphon)
            # self.factors_signal.append(signal)
        # self.dim = self.factors_signal[0].shape[1]
        self.sigmoid = nn.Sigmoid()
        # self.relu = nn.ReLU()
        self.softmax0 = nn.Softmax(dim=0)
        self.softmax1 = nn.Softmax(dim=1)
        self.softmax2 = nn.Softmax(dim=2)

        # 这里三层softmax是?

        self.batch_size = args.batch_size
        self.fc = torch.nn.Linear(args.batch_size, 1)
        # x = (torch.arange(0, 100) + 0.5).view(1, -1) / 100
        # self.register_buffer('positions', x)
        self.num_components = args.n_components
        self.prior_type = args.prior_type
        self.prior = Prior(data_size=[self.num_components, self.num_factors],
                           prior_type=self.prior_type)

    def sampling_z(self, num_samples):
        return self.prior.sampling(num_samples)

    def sampling(self, vs: torch.Tensor):
        """
        Sampling graphon factors
        Args:
            vs: (n_nodes)
        Returns:
            graphons: (n_factors, n_nodes, n_nodes)
            signals: (n_factors, n_nodes, n_nodes)
        """
        n_nodes = vs.shape[0]
        graphons = torch.zeros(self.num_factors, n_nodes, n_nodes).to(vs.device)
        # signals = torch.zeros(1, self.num_factors, n_nodes, self.dim).to(vs.device)
        for c in range(self.num_factors):
            idx = torch.floor(self.num_partitions[c] * vs).long()
            graphons[c, :, :] = self.factors_graphon[c][idx, :][:, idx]
            # signals[0, c, :, :] = self.factors_signal[c][idx, :]
        graphons = self.sigmoid(graphons)
        # return graphons, signals
        return graphons

    def forward(self, zs: torch.Tensor, vs: torch.Tensor):
        """
        Given a graphon model, sample a batch of graphs from it
        Args:
            zs: (batch_size, n_factors) latent representations
            vs: (n_nodes) random variables ~ Uniform([0, 1])

        Returns:
            graphon: (batch_size, n_nodes, n_nodes)
            signal: (batch_size, n_nodes, dim)
            graph: (batch_size, n_nodes, n_nodes) adjacency matrix
            attribute: (batch_size, n_nodes, dim) node attributes
        """
        tzs = zs.t()
        tzs_pad = tzs
        if self.batch_size - tzs.shape[1] != 0:
            pad = torch.zeros(tzs.shape[0],
                              self.batch_size - tzs.shape[1],
                              device=tzs.device)
            tzs_pad = torch.cat((tzs, pad), dim=1)
        zs_hat_one = self.fc(tzs_pad)
        zs_hat = self.softmax1(zs_hat_one)
        # graphons, signals = self.sampling(vs)  # basis
        # print('zs_hat', zs_hat.shape)

        graphons_basis = self.sampling(vs)  # basis
        # graphons_basis = torch.sigmoid(graphons)  #TODO: here change to [0,1]
        assert (graphons_basis.max().item() <= 1
                and graphons_basis.min().item() >= 0)
        # print('graphons', graphons.shape)
        graphon_est = (zs_hat.view(self.num_factors, 1, 1) * graphons_basis).sum(
            0)  # ( n_nodes, n_nodes)
        # signal = (zs_hat.view(-1, self.num_factors, 1, 1) * signals).sum(1)  # (batch, n_nodes, dim)
        # if self.node_type == 'binary':
        #     signal = self.sigmoid(signal)
        # if self.node_type == 'categorical':
        #     signal = self.softmax2(signal)
        # graphs = torch.bernoulli(graphon)  # TODO: ???!?!
        # graphs += graphs.clone().permute(0, 2, 1)  # Change: add .clone()
        # graphs[graphs > 1] = 1
        # if self.node_type == "binary":
        #     attributes = torch.bernoulli(signal)
        # elif self.node_type == "categorical":
        #     distribution = torch.distributions.one_hot_categorical.OneHotCategorical(signal)
        #     attributes = distribution.sample()
        # else:
        #     distribution = torch.distributions.normal.Normal(signal, scale=2)
        #     attributes = distribution.sample()
        # return graphon, signal, graphs, attributes
        return graphon_est


def raml(graphons_hat, graphons_lbl, args):
    # adj = torch.sparse_coo_tensor(data.edge_index,
    #                               torch.ones(data.edge_index.shape[1]),
    #                               size=[data.x.shape[0], data.x.shape[0]])
    # adj = adj.to_dense()
    # if len(adj.shape) > 2:
    #     adj = adj.sum(2)
    # log_p_x = torch.zeros(graphons_hat.shape[0], args.n_graphs).to(graphons_hat.device)
    # for b in range(graphons_hat.shape[0]):
    # d_fgw = torch.zeros(args.n_graphs).to(graphons_hat.device)
    # adj2 = adj[data.batch == b, :][:, data.batch == b]
    # s2 = data.x[data.batch == b, :]

    # for k in range(args.n_graphs):
    #     adj0 = graphons_hat[b, k * args.n_nodes:(k + 1) * args.n_nodes, :][:,
    #            k * args.n_nodes:(k + 1) * args.n_nodes]
    #     # s0 = signals[b, k * args.n_nodes:(k + 1) * args.n_nodes, :]
    #     adj1 = graphons_lbl[b, k * args.n_nodes:(k + 1) * args.n_nodes, :][:,
    #            k * args.n_nodes:(k + 1) * args.n_nodes]

    # s1 = attributes[b, k * args.n_nodes:(k + 1) * args.n_nodes, :]
    # if node_type == 'binary':
    #     log_p_x[b, k] = F.binary_cross_entropy(input=adj0, target=adj1, reduction='mean')
    # elif node_type == 'categorical':
    #     log_p_x[b, k] = F.binary_cross_entropy(input=adj0, target=adj1, reduction='mean')
    # else:
    # log_p_x[b, k] = F.binary_cross_entropy(input=adj0, target=adj1, reduction='mean')
    # d_fgw[k] = fgw_distance(adj1, adj2, args)
    # d_fgw[k] = fgw_distance(adj1, adj0, args)
    # print(b, k, d_fgw[k])
    # print('graphons_hat, graphons_lbl')
    # print(graphons_hat.shape)
    # print(graphons_lbl.shape)
    d_fgw = fgw_distance(graphons_hat, graphons_lbl, args)
    # print('d_fgw', d_fgw)
    # q_x = F.softmax(-2 * d_fgw / torch.min(d_fgw), dim=0).detach()  # TODO: detach() ??
    # log_p_x[b, :] *= q_x
    # print(q_x.shape)
    # log_p_x[b, :] = q_x
    # return log_p_x.mean()
    return d_fgw


def distance_tensor(pts_src: torch.Tensor, pts_dst: torch.Tensor, p: int = 2):
    """
    Returns the matrix of ||x_i-y_j||_p^p.
    :param pts_src: [R, D] matrix
    :param pts_dst: [C, D] matrix
    :param p:
    :return: [R, C, D] distance matrix
    """
    x_col = pts_src.unsqueeze(1)
    y_row = pts_dst.unsqueeze(0)
    distance = torch.abs(x_col - y_row) ** p
    return distance


def sliced_fgw_distance(posterior_samples, prior_samples, num_projections=50, p=2, beta=0.1):
    # derive latent space dimension size from random samples drawn from latent prior distribution
    embedding_dim = prior_samples.size(1)
    # generate random projections in latent space
    projections = torch.randn(size=(embedding_dim, num_projections)).to(posterior_samples.device)
    projections /= (projections ** 2).sum(0).sqrt().unsqueeze(0)
    # calculate projections through the encoded samples
    posterior_projections = posterior_samples.matmul(projections)  # batch size x #projections
    prior_projections = prior_samples.matmul(projections)  # batch size x #projections
    posterior_projections = torch.sort(posterior_projections, dim=0)[0]
    prior_projections1 = torch.sort(prior_projections, dim=0)[0]
    prior_projections2 = torch.sort(prior_projections, dim=0, descending=True)[0]
    posterior_diff = distance_tensor(posterior_projections, posterior_projections, p=p)
    prior_diff1 = distance_tensor(prior_projections1, prior_projections1, p=p)
    prior_diff2 = distance_tensor(prior_projections2, prior_projections2, p=p)
    # print(posterior_projections.size(), prior_projections1.size())
    # print(posterior_diff.size(), prior_diff1.size())
    w1 = torch.sum((posterior_projections - prior_projections1) ** p, dim=0)
    w2 = torch.sum((posterior_projections - prior_projections2) ** p, dim=0)
    # print(w1.size(), torch.sum(w1))
    gw1 = torch.mean(torch.mean((posterior_diff - prior_diff1) ** p, dim=0), dim=0)
    gw2 = torch.mean(torch.mean((posterior_diff - prior_diff2) ** p, dim=0), dim=0)
    # print(gw1.size(), torch.sum(gw1))
    fgw1 = (1 - beta) * w1 + beta * gw1
    fgw2 = (1 - beta) * w2 + beta * gw2
    return torch.sum(torch.min(fgw1, fgw2))


if __name__ == "__main__":
    pass

小结

结构差异现象的核心根源在于预训练图目标应用图(或称为下游图)之间在生成模式上存在的显著差异。针对此问题,文中创新性地提出了G-TUNING方法,旨在保持并适应下游图的独特生成模式。具体而言,对于给定的下游图G,G-TUNING的核心策略是调整预训练的图神经网络(GNN),使其能够重新构造出G的生成模式,这里以图元w为代表。

然而,直接精确重建图元w在计算上极具挑战性且成本高昂。为突破这一瓶颈,文章进一步提供了深入的理论分析,证明了存在一组称为graphon bases的替代graphon,它们可以作为构建块来近似表示图元w。通过巧妙地利用这些graphon bases的线性组合,G-TUNING能够以高效的方式逼近真实的图元生成模式。

这一理论发现为G-TUNING模型的构建奠定了坚实基础,因为它允许模型有效地学习graphon bases及其相应的组合系数,从而在不牺牲性能的前提下,显著降低计算复杂度。实验结果显示,与现有技术相比,G-TUNING在域内迁移学习和跨域迁移学习场景下,分别实现了平均0.5%和2.6%的性能提升,充分验证了其有效性和优越性。

参考文献

[1] Yifei Sun, Qi Zhu, Yang Yang, Chunping Wang, Tianyu Fan, Jiajun Zhu, Lei Chen. Fine-tuning Graph Neural Networks by Preserving Graph Generative Patterns. [C] AAAI2024 https://export.arxiv.org/abs/2312.13583

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值