Exphormer Sparse Transformers for Graphs源码解读1

前言

这篇文章是以graphGPS为基础作了修改,修改的是graphGPS的global layer层

graphGPS框架图如下

layer中定义了各种可能在graphGPS框架里用到的层

同时layer也定义了multi_model_layer.py,在此文件中,可以根据输入参数,选择自己定义的层或torchgeo库的层

Layer文件夹 multi_model_layer.py

class LocalModel  对应图中MPNNlayer

class GlobalModel 对应图中global attention

class MultiLayer将上面两个类加入进来,组成整个GPS Layers。

Network文件夹 multi_model.py

在class MultiModel中

MultiLayer类与FeatureEncoder类组合,构成整个graphGPS框架

torchgeo用register_network来定义网络

register_network('MultiModel', MultiModel)

__main__.py

 # Set machine learning pipeline
        loaders = create_loader()
        loggers = create_logger()
        # custom_train expects three loggers for 'train', 'valid' and 'test'.
        # GraphGym code creates one logger/loader for each of the 'train_mask' etc.
        # attributes in the dataset. As a work around it, we create one logger for each
        # of the types.
        # loaders are a const, so it is ok to just duplicate the loader. 

__main__使用的是graphgym,现在graphgym已经被集成进来torchgeometric

layer文件夹 exphomer.py

定义了exphomer的层,里面有虚拟节点的部分

 # 虚拟节点
        if self.use_virt_nodes:
            h = torch.cat([h, batch.virt_h], dim=0)
            edge_index = torch.cat([edge_index, batch.virt_edge_index], dim=1)
            edge_attr = torch.cat([edge_attr, batch.virt_edge_attr], dim=0)

注意力机制很类似transformer但又不是transformer那种Q乘K的转置,softmax,再乘V

而是Q和K逐元素点乘,沿dim-1求和,得到[num_edges, num_heads,1]的score

score与V逐元素点乘

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值