Preface
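This article builds on GraphGPS and modifies it; the part that is modified is GraphGPS's global layer.
The GraphGPS framework diagram is shown below.
The layer folder defines the various layers that may be used within the GraphGPS framework.
It also contains multi_model_layer.py, in which, depending on the input arguments, you can choose either a self-defined layer or a layer from the PyTorch Geometric (torch_geometric) library.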
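Choosing a layer from an input argument can be sketched as a lookup table from a type string to a constructor. This is a minimal sketch under assumptions, not the contents of multi_model_layer.py: `MyCustomLayer`, `LAYER_FACTORIES` and `build_layer` are hypothetical names, and `nn.Linear` stands in for a torch_geometric layer (for example `GCNConv`) so that the example runs with plain PyTorch.

```python
import torch
from torch import nn

class MyCustomLayer(nn.Module):
    """Hypothetical self-defined layer: linear map followed by ReLU."""
    def __init__(self, dim: int):
        super().__init__()
        self.lin = nn.Linear(dim, dim)

    def forward(self, x):
        return torch.relu(self.lin(x))

# Hypothetical type-string -> factory table. In a real project the values would
# build custom layers or torch_geometric layers; nn.Linear is only a stand-in.
LAYER_FACTORIES = {
    "custom": lambda dim: MyCustomLayer(dim),
    "library": lambda dim: nn.Linear(dim, dim),
}

def build_layer(layer_type: str, dim: int) -> nn.Module:
    """Return a new layer chosen by its type string."""
    try:
        factory = LAYER_FACTORIES[layer_type]
    except KeyError:
        raise ValueError(f"Unknown layer type: {layer_type!r}") from None
    return factory(dim)

layer = build_layer("custom", dim=8)
print(layer(torch.randn(3, 8)).shape)  # torch.Size([3, 8])
```

Keeping the dispatch in one table means that supporting a new layer type only requires adding one entry.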
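layer folder: multi_model_layer.py
class LocalModel corresponds to the MPNN layer in the figure.
class GlobalModel corresponds to the global attention in the figure.
class MultiLayer combines the two classes above to form a complete GPS layer.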
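How a local block and a global block can be combined into one layer is sketched below, loosely following the GraphGPS layer design, in which the MPNN output and the global-attention output are computed in parallel, each with a residual connection, summed, and passed through a feed-forward block. All class names are hypothetical stand-ins and this is not the code of multi_model_layer.py: the MPNN is a plain mean aggregation, the global part is PyTorch's `nn.MultiheadAttention` over the nodes of a single graph, and `LayerNorm` is used for normalization as a simplifying assumption.

```python
import torch
from torch import nn

class SimpleLocalModel(nn.Module):
    """Stand-in for LocalModel: a mean-aggregation MPNN over edge_index."""
    def __init__(self, dim: int):
        super().__init__()
        self.lin = nn.Linear(2 * dim, dim)

    def forward(self, x, edge_index):
        src, dst = edge_index
        # Sum incoming neighbour features at each destination node
        agg = torch.zeros_like(x).index_add_(0, dst, x[src])
        # Divide by in-degree (at least 1) to get the mean
        deg = torch.zeros(x.size(0), dtype=x.dtype, device=x.device)
        deg.index_add_(0, dst, torch.ones_like(dst, dtype=x.dtype))
        agg = agg / deg.clamp(min=1).unsqueeze(-1)
        return self.lin(torch.cat([x, agg], dim=-1))

class SimpleGlobalModel(nn.Module):
    """Stand-in for GlobalModel: full self-attention over the nodes of one graph."""
    def __init__(self, dim: int, heads: int = 2):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)

    def forward(self, x):
        seq = x.unsqueeze(0)  # treat the graph as one sequence: [1, N, dim]
        out, _ = self.attn(seq, seq, seq)
        return out.squeeze(0)

class SimpleMultiLayer(nn.Module):
    """Stand-in for MultiLayer: local and global branches in parallel, then an FFN."""
    def __init__(self, dim: int):
        super().__init__()
        self.local = SimpleLocalModel(dim)
        self.glob = SimpleGlobalModel(dim)
        self.norm_local = nn.LayerNorm(dim)
        self.norm_global = nn.LayerNorm(dim)
        self.ffn = nn.Sequential(nn.Linear(dim, 2 * dim), nn.ReLU(), nn.Linear(2 * dim, dim))
        self.norm_out = nn.LayerNorm(dim)

    def forward(self, x, edge_index):
        h_local = self.norm_local(x + self.local(x, edge_index))  # residual + norm
        h_global = self.norm_global(x + self.glob(x))             # residual + norm
        h = h_local + h_global                                     # combine branches
        return self.norm_out(h + self.ffn(h))                      # FFN with residual
```

Because the global branch only sees `x`, it can be swapped for a different attention mechanism without touching the local branch, which is exactly the kind of modification this article makes to the global layer.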
network folder: multi_model.py
In class MultiModel, MultiLayer is combined with FeatureEncoder to build the complete GraphGPS framework.
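The composition can be pictured as feature encoder, then a stack of layers, then a prediction head. The sketch below assumes a node-level head; `SimpleFeatureEncoder`, `ToyLayer` and `SimpleMultiModel` are hypothetical names, and `ToyLayer` only mimics the `(h, edge_index) -> h` interface of MultiLayer rather than its actual computation.

```python
import torch
from torch import nn

class SimpleFeatureEncoder(nn.Module):
    """Stand-in for FeatureEncoder: projects raw node features to the hidden size."""
    def __init__(self, dim_in: int, dim_hidden: int):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_hidden)

    def forward(self, x):
        return self.proj(x)

class ToyLayer(nn.Module):
    """Stand-in for MultiLayer; keeps the (h, edge_index) -> h interface only."""
    def __init__(self, dim: int):
        super().__init__()
        self.lin = nn.Linear(dim, dim)

    def forward(self, h, edge_index):
        # edge_index is ignored here; a real layer would use it in its MPNN part.
        # The mean over all nodes is a crude placeholder for the global part.
        return h + torch.relu(self.lin(h)) + h.mean(dim=0, keepdim=True)

class SimpleMultiModel(nn.Module):
    """Encoder -> num_layers x layer -> node-level prediction head."""
    def __init__(self, dim_in, dim_hidden, dim_out, num_layers, layer_factory=ToyLayer):
        super().__init__()
        self.encoder = SimpleFeatureEncoder(dim_in, dim_hidden)
        self.layers = nn.ModuleList([layer_factory(dim_hidden) for _ in range(num_layers)])
        self.head = nn.Linear(dim_hidden, dim_out)

    def forward(self, x, edge_index):
        h = self.encoder(x)
        for layer in self.layers:
            h = layer(h, edge_index)
        return self.head(h)
```

Passing `layer_factory` in keeps the model independent of which layer is stacked, so the same wrapper works for any layer with that interface.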
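GraphGym (now part of PyTorch Geometric) registers a network under a name with register_network:
```python
from torch_geometric.graphgym.register import register_network
register_network('MultiModel', MultiModel)
```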
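The idea behind name-based registration can be illustrated with a small pure-Python registry. This is a simplified mimic, not PyG's code: `network_registry` and `register_network_sketch` are hypothetical names. As far as we understand GraphGym, the registered class is later looked up by the name given in the config (the `model.type` field); the last lines imitate that lookup with a plain string.

```python
from typing import Any, Dict, Optional

network_registry: Dict[str, Any] = {}

def register_network_sketch(key: str, module: Optional[Any] = None):
    """Store `module` under `key`; without `module`, act as a class decorator."""
    def _register(mod):
        if key in network_registry:
            raise KeyError(f"Network '{key}' is already registered")
        network_registry[key] = mod
        return mod
    if module is not None:
        return _register(module)
    return _register

class MultiModel:  # placeholder class, not the repository's MultiModel
    def __init__(self, dim_in, dim_out):
        self.dim_in, self.dim_out = dim_in, dim_out

# Direct call, as in the line above
register_network_sketch('MultiModel', MultiModel)

# Decorator form
@register_network_sketch('OtherModel')
class OtherModel(MultiModel):
    pass

# A builder can then instantiate the class from a config string
model_type = 'MultiModel'
model = network_registry[model_type](dim_in=16, dim_out=2)
```

Registering by name decouples the training script from the model definitions: switching architectures only changes a string in the config.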
__main__.py
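```python
# GraphGym helpers shipped with PyTorch Geometric (GraphGPS-based repos may
# import create_logger from their own logger module instead).
from torch_geometric.graphgym.loader import create_loader
from torch_geometric.graphgym.logger import create_logger

# Set machine learning pipeline
loaders = create_loader()
loggers = create_logger()
# custom_train expects three loggers for 'train', 'valid' and 'test'.
# GraphGym code creates one logger/loader for each of the 'train_mask' etc.
# attributes in the dataset. As a workaround, we create one logger for each
# of the types.
# loaders are a const, so it is ok to just duplicate the loader.
```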
__main__.py is built on GraphGym, which has since been integrated into PyTorch Geometric (as torch_geometric.graphgym).
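The comments in the snippet describe a workaround whose code is not shown: GraphGym may create fewer loaders/loggers than the three splits that custom_train expects (for example a single full-graph loader whose split is selected by masks such as train_mask). A minimal pure-Python sketch of that idea, assuming split names train/val/test; `SplitLogger` and `expand_to_splits` are hypothetical and are not GraphGym classes.

```python
from dataclasses import dataclass

@dataclass
class SplitLogger:
    """Hypothetical stand-in for a GraphGym logger; only records its split name."""
    name: str

def expand_to_splits(loaders, loggers, split_names=("train", "val", "test")):
    """Return exactly one loader and one logger per split."""
    n = len(split_names)
    if len(loaders) == 1:
        # The loader is only read, never mutated, so sharing one object is safe.
        loaders = [loaders[0]] * n
    if len(loggers) != n:
        # Loggers accumulate per-split statistics, so each split needs its own.
        loggers = [SplitLogger(name) for name in split_names]
    return list(loaders), list(loggers)

loaders, loggers = expand_to_splits(["full_graph_loader"], [SplitLogger("train")])
print(len(loaders), [lg.name for lg in loggers])  # 3 ['train', 'val', 'test']
```

The asymmetry mirrors the comment: duplicating the loader is harmless because it is constant, whereas loggers hold state and therefore must be separate objects.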
layer folder: Exphormer.py
Defines the Exphormer layer, including the virtual-node handling:
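```python
# Virtual nodes: appended after the real nodes
if self.use_virt_nodes:
    # Node features [num_nodes, dim]: stack virtual-node embeddings below the real ones
    h = torch.cat([h, batch.virt_h], dim=0)
    # edge_index has shape [2, num_edges], so new edges are appended along dim 1
    edge_index = torch.cat([edge_index, batch.virt_edge_index], dim=1)
    # Edge attributes [num_edges, edge_dim]: append along dim 0
    edge_attr = torch.cat([edge_attr, batch.virt_edge_attr], dim=0)
```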
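The snippet does not show how batch.virt_h, batch.virt_edge_index and batch.virt_edge_attr are built. One plausible construction for a single graph is sketched below, assuming every virtual node is connected to every real node in both directions; `build_virtual_node_edges` is a hypothetical helper, the zero tensors stand in for what would normally be learnable embeddings, and batching several graphs (which needs per-graph index offsets) is not handled.

```python
import torch

def build_virtual_node_edges(num_real: int, num_virt: int) -> torch.Tensor:
    """Connect every virtual node to every real node, in both directions.

    Virtual node k gets index num_real + k, matching the torch.cat([h, virt_h], dim=0)
    ordering, in which virtual nodes come after the real nodes.
    """
    real = torch.arange(num_real).repeat(num_virt)                        # [N*K]
    virt = torch.arange(num_virt).repeat_interleave(num_real) + num_real  # [N*K]
    to_virt = torch.stack([real, virt])    # real -> virtual
    from_virt = torch.stack([virt, real])  # virtual -> real
    return torch.cat([to_virt, from_virt], dim=1)  # [2, 2*N*K]

num_real, num_virt, dim, edge_dim = 4, 2, 8, 3
h = torch.randn(num_real, dim)
edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])
edge_attr = torch.randn(edge_index.size(1), edge_dim)

virt_h = torch.zeros(num_virt, dim)  # would typically be learnable embeddings
virt_edge_index = build_virtual_node_edges(num_real, num_virt)
virt_edge_attr = torch.zeros(virt_edge_index.size(1), edge_dim)  # likewise

h = torch.cat([h, virt_h], dim=0)
edge_index = torch.cat([edge_index, virt_edge_index], dim=1)
edge_attr = torch.cat([edge_attr, virt_edge_attr], dim=0)
```

With the virtual nodes wired in, any two real nodes are at most two hops apart, which is what lets a sparse attention layer still propagate information globally.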
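The attention mechanism resembles a Transformer's, but it is not the usual Transformer recipe of multiplying Q by the transpose of K, applying softmax, and then multiplying by V.
Instead, Q and K are multiplied element-wise for each edge and summed along the last dimension (dim=-1), giving a score of shape [num_edges, num_heads, 1].
The score is then multiplied element-wise with V (broadcasting over the last dimension).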
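The edge-wise computation described above can be sketched in plain PyTorch. Assumptions of this sketch, not taken from the text: K and V are gathered at the source node and Q at the destination node, the score is scaled by the square root of the head dimension and clamped before `exp`, and the messages are normalized by the sum of weights arriving at each node, which makes the weights a softmax over each node's incoming edges. The actual layer may differ, for example by also multiplying the score with edge features. `edge_attention` is a hypothetical name.

```python
import math
import torch

def edge_attention(Q, K, V, edge_index, num_nodes):
    """Edge-wise multi-head attention.

    Q, K, V: [num_nodes, num_heads, head_dim]; edge_index: [2, num_edges],
    row 0 = source node, row 1 = destination node.
    """
    src, dst = edge_index
    head_dim = Q.size(-1)
    # Element-wise Q*K per edge, summed over the last dim -> [num_edges, num_heads, 1]
    score = (K[src] * Q[dst]).sum(dim=-1, keepdim=True) / math.sqrt(head_dim)
    # Unnormalized weights; clamping keeps exp() numerically stable
    score = torch.exp(score.clamp(-5, 5))
    # Broadcast: [E, H, 1] * [E, H, D] -> [E, H, D]
    msg = V[src] * score
    # Sum messages at each destination node
    out = torch.zeros(num_nodes, *V.shape[1:], dtype=V.dtype, device=V.device)
    out.index_add_(0, dst, msg)
    # Normalizer: total weight arriving at each destination node
    z = torch.zeros(num_nodes, Q.size(1), 1, dtype=V.dtype, device=V.device)
    z.index_add_(0, dst, score)
    return out / (z + 1e-6)
```

Compared with dense attention, the cost grows with the number of edges rather than with the square of the number of nodes, because scores are only computed along existing (including virtual) edges.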