model.py: defines the model proposed in the paper
def forward
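The following snippet corresponds to the patch embedding (the Patch Encoder).
Note that the `scatter(src, index, dim=0, reduce=...)` calls here are not `torch.scatter`. They are the reduce-style `scatter` from torch_scatter (presumably imported as `from torch_scatter import scatter`), which sends each row of `src` to the output row given by `index` and combines rows that share an index with `reduce` (e.g. 'mean'). `torch.scatter(input, dim, index, src)` has a different signature and semantics; the CSDN blog post on torch.scatter is still useful background on the general scatter idea.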
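# Patch Encoder
x = x[data.subgraphs_nodes_mapper]          # one row per node copy (a node can sit in several patches)
e = edge_attr[data.subgraphs_edges_mapper]  # edge features of the patch edges
edge_index = data.combined_subgraphs        # all patches treated as one disconnected graph
batch_x = data.subgraphs_batch              # patch id of each node copy
for i, gnn in enumerate(self.gnns):
    if i > 0:
        # pool each patch, then broadcast the patch vector back to its node copies
        subgraph = scatter(x, batch_x, dim=0,
                           reduce=self.pooling)[batch_x]
        x = x + self.U[i-1](subgraph)  # U is an MLP
        # torch_scatter.scatter(src, index, dim, reduce): average the copies of each
        # original node across patches, then index back so every copy gets that mean
        x = scatter(x, data.subgraphs_nodes_mapper,
                    dim=0, reduce='mean')[data.subgraphs_nodes_mapper]
    x = gnn(x, edge_index, e)
subgraph_x = scatter(x, batch_x, dim=0, reduce=self.pooling)  # one embedding per patch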
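The index bookkeeping in this loop is easier to follow on a toy example. The sketch below assumes a hypothetical 3-node graph split into two overlapping patches {0, 1} and {1, 2}, so node 1 has one copy in each patch, and it uses plain Python lists with one scalar feature per node. It replays the gather with `subgraphs_nodes_mapper`, the per-patch pooling broadcast back with `[batch_x]`, the averaging of each node's copies, and the final patch pooling. The GNN layer and the MLP `U` are left out, `self.pooling` is assumed to be 'mean', and `scatter_mean` is a hypothetical helper standing in for torch_scatter.

```python
from typing import Dict, List, Sequence


def scatter_mean(values: Sequence[float], index: Sequence[int]) -> List[float]:
    """Average `values` that share an index; output length is max(index) + 1."""
    sums: Dict[int, float] = {}
    counts: Dict[int, int] = {}
    for v, i in zip(values, index):
        sums[i] = sums.get(i, 0.0) + v
        counts[i] = counts.get(i, 0) + 1
    # Buckets that receive no value are 0.0 in this sketch.
    return [sums[i] / counts[i] if i in counts else 0.0
            for i in range(max(index) + 1)]


# Toy graph: one scalar feature per original node.
x_orig = [1.0, 2.0, 3.0]
# Patch 0 = nodes {0, 1}, patch 1 = nodes {1, 2}; node 1 appears twice.
subgraphs_nodes_mapper = [0, 1, 1, 2]  # node copy -> original node id
subgraphs_batch = [0, 0, 1, 1]         # node copy -> patch id

# x = x[data.subgraphs_nodes_mapper]: one entry per node copy.
x = [x_orig[n] for n in subgraphs_nodes_mapper]  # [1.0, 2.0, 2.0, 3.0]

# Pretend a GNN layer ran and the two copies of node 1 drifted apart.
x = [1.0, 4.0, 2.0, 3.0]

# subgraph = scatter(x, batch_x, reduce='mean')[batch_x]
patch_mean = scatter_mean(x, subgraphs_batch)         # [2.5, 2.5]
subgraph = [patch_mean[b] for b in subgraphs_batch]   # [2.5, 2.5, 2.5, 2.5]

# scatter(x, mapper, reduce='mean')[mapper]: copies of a node share one value.
node_mean = scatter_mean(x, subgraphs_nodes_mapper)         # [1.0, 3.0, 3.0]
x_synced = [node_mean[n] for n in subgraphs_nodes_mapper]   # [1.0, 3.0, 3.0, 3.0]

# subgraph_x = scatter(x, batch_x, reduce=pooling): one value per patch.
subgraph_x = scatter_mean(x_synced, subgraphs_batch)        # [2.0, 3.0]
```

The trailing `[index]` after each scatter is what turns a per-patch (or per-node) result back into a per-copy tensor with the same number of rows as `x`.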
transformer.py: preprocessing that splits the graph into patches
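As a rough picture of what this preprocessing has to hand to the Patch Encoder, here is a pure-Python sketch that takes a list of (possibly overlapping) patches and builds the index structures the forward pass reads: a node mapper (copy to original node), a batch vector (copy to patch), a combined edge list relabeled to copy ids, and an edge mapper (combined edge to original edge). The output names mirror `subgraphs_nodes_mapper`, `subgraphs_batch`, `combined_subgraphs` and `subgraphs_edges_mapper` from model.py, but how transformer.py actually computes them is an assumption here; `build_patch_indices` is hypothetical, and edges are kept as tuples rather than a 2 x E tensor.

```python
from typing import Dict, List, Sequence, Tuple

Edge = Tuple[int, int]


def build_patch_indices(edges: Sequence[Edge], patches: Sequence[Sequence[int]]):
    """Flatten patches into one disconnected graph of node copies."""
    nodes_mapper: List[int] = []  # copy id -> original node id
    batch: List[int] = []         # copy id -> patch id
    combined: List[Edge] = []     # edges between copy ids
    edges_mapper: List[int] = []  # combined edge -> original edge id
    for p, patch in enumerate(patches):
        local: Dict[int, int] = {}  # original node id -> copy id within this patch
        for n in patch:
            local[n] = len(nodes_mapper)
            nodes_mapper.append(n)
            batch.append(p)
        # Keep only edges whose two endpoints both lie inside this patch.
        for e_id, (u, v) in enumerate(edges):
            if u in local and v in local:
                combined.append((local[u], local[v]))
                edges_mapper.append(e_id)
    return nodes_mapper, batch, combined, edges_mapper


# Path graph 0-1-2 stored with both edge directions, two overlapping patches.
mapper, batch, combined, edges_mapper = build_patch_indices(
    [(0, 1), (1, 0), (1, 2), (2, 1)], [[0, 1], [1, 2]])
# mapper == [0, 1, 1, 2], batch == [0, 0, 1, 1]
# combined == [(0, 1), (1, 0), (2, 3), (3, 2)], edges_mapper == [0, 1, 2, 3]
```

With these arrays, `x[mapper]` and `edge_attr[edges_mapper]` produce exactly the per-copy node and edge features that the Patch Encoder starts from.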
train_helper.py
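This file defines the run function.
One of run's parameters is cfg. The default config values are defined in config.py, and a .yaml file sets the values for a particular experiment.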
from yacs.config import CfgNode as CN
cfg = set_cfg(CN())
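The `if __name__ == '__main__':` block of the training script (e.g. cifar10.py), which then calls run, reads the config with yacs's `merge_from_file`: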
if __name__ == '__main__':
    # get config
    cfg.merge_from_file('train/configs/GraphMLPMixer/cifar10.yaml')
    cfg = update_cfg(cfg)
    run(cfg, create_dataset, create_model, train, test)
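The defaults-then-override pattern can be sketched without yacs. Below, nested dicts stand in for `CfgNode` and an already-parsed dict stands in for the YAML file. The sketch mimics two behaviours of yacs's merge as I understand them: values from the file override the defaults recursively, and a key that does not exist in the defaults is rejected with a `KeyError`. Real yacs does more (for example it checks that a value keeps its type, and a config can be frozen). `merge_into` is a hypothetical helper, and the values for `num_workers` and `train.batch_size` are made up.

```python
import copy
from typing import Any, Dict


def merge_into(defaults: Dict[str, Any], overrides: Dict[str, Any],
               prefix: str = "") -> Dict[str, Any]:
    """Return a copy of `defaults` with `overrides` merged in recursively."""
    merged = copy.deepcopy(defaults)
    for key, value in overrides.items():
        full_key = f"{prefix}{key}"
        if key not in merged:
            # Typos in the YAML file fail loudly instead of being ignored.
            raise KeyError(f"Non-existent config key: {full_key}")
        if isinstance(merged[key], dict) and isinstance(value, dict):
            merged[key] = merge_into(merged[key], value, prefix=f"{full_key}.")
        else:
            # yacs would also check the type here; this sketch just overwrites.
            merged[key] = value
    return merged


# Plays the role of set_cfg(CN()) in config.py (values are illustrative).
defaults = {"num_workers": 8, "train": {"batch_size": 128}}
# Plays the role of the parsed cifar10.yaml (values are illustrative).
from_yaml = {"train": {"batch_size": 32}}

cfg = merge_into(defaults, from_yaml)
# cfg == {"num_workers": 8, "train": {"batch_size": 32}}
```

Keeping every key declared in config.py means each .yaml file only needs to list what differs from the defaults.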
Besides cfg, run also takes the train function as an argument (along with create_dataset, create_model and test). train has the following signature:
def train(train_loader, model, optimizer, evaluator, device, sharp):
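Inside run, a DataLoader wraps the training set to produce train_loader (there is no `torch.Dataloader`; for graph data like this it is most likely PyG's `torch_geometric.loader.DataLoader`).
The DataLoader's input train_dataset comes from the create_dataset function in get_data.py: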
train_dataset, val_dataset, test_dataset = create_dataset(cfg)
train_loader = DataLoader(  # train_dataset comes from create_dataset
    train_dataset, cfg.train.batch_size, shuffle=True, num_workers=cfg.num_workers)
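For graph data, batching is not simple stacking. If, as is usual in PyG-based code, this DataLoader comes from torch_geometric, each mini-batch is merged into one large disconnected graph: every graph's edge indices are shifted by the number of nodes that came before it, and a `batch` vector records which graph each node belongs to. The pure-Python sketch below shows that idea together with shuffling and cutting the dataset into batches of `batch_size`. It does not use PyG, ignores `num_workers`, and `collate` / `iterate_batches` are hypothetical names; a graph is modelled as `(num_nodes, edge_list)`.

```python
import random
from typing import Iterator, List, Sequence, Tuple

Graph = Tuple[int, List[Tuple[int, int]]]        # (num_nodes, edges)
Batch = Tuple[int, List[Tuple[int, int]], List[int]]


def collate(graphs: Sequence[Graph]) -> Batch:
    """Merge graphs into one disconnected graph plus a node->graph batch vector."""
    edges: List[Tuple[int, int]] = []
    batch: List[int] = []
    offset = 0
    for g_id, (num_nodes, g_edges) in enumerate(graphs):
        # Shift node ids so they don't collide with earlier graphs.
        edges.extend((u + offset, v + offset) for u, v in g_edges)
        batch.extend([g_id] * num_nodes)
        offset += num_nodes
    return offset, edges, batch


def iterate_batches(dataset: Sequence[Graph], batch_size: int,
                    shuffle: bool = True, seed: int = 0) -> Iterator[Batch]:
    """Yield collated mini-batches, optionally in a shuffled order."""
    order = list(range(len(dataset)))
    if shuffle:
        random.Random(seed).shuffle(order)
    for start in range(0, len(order), batch_size):
        yield collate([dataset[i] for i in order[start:start + batch_size]])


total, edges, batch = collate([(2, [(0, 1)]), (3, [(0, 2), (1, 2)])])
# total == 5, edges == [(0, 1), (2, 4), (3, 4)], batch == [0, 0, 1, 1, 1]
```

The `batch` vector produced here plays the same role for whole graphs that `subgraphs_batch` plays for patches inside the model.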
cifar10.py
This is the training script for the CIFAR10 dataset.
To train on CIFAR10, just run this file directly.
# from ogb.utils import smiles2graph
from ogb.utils.mol import smiles2graph
transform_util folder: subgraph_extractors.py
Splits the original graph into multiple patches.
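If I read the Graph MLP-Mixer paper correctly, patches are obtained by partitioning the graph with METIS and then adding each part's one-hop neighbours, so that adjacent patches overlap. METIS is not in the standard library, so the sketch below swaps in a deliberately naive stand-in (contiguous chunks of node ids) and only illustrates the one-hop expansion step; `naive_partition` and `expand_one_hop` are hypothetical helpers, not functions from subgraph_extractors.py.

```python
from typing import Dict, List, Sequence, Set, Tuple


def naive_partition(num_nodes: int, num_parts: int) -> List[Set[int]]:
    """Stand-in for METIS: split node ids 0..n-1 into contiguous chunks."""
    if num_nodes == 0:
        return []
    size = -(-num_nodes // num_parts)  # ceiling division
    return [set(range(s, min(s + size, num_nodes)))
            for s in range(0, num_nodes, size)]


def expand_one_hop(parts: Sequence[Set[int]],
                   edges: Sequence[Tuple[int, int]]) -> List[Set[int]]:
    """Grow every part by the direct neighbours of its nodes."""
    neighbours: Dict[int, Set[int]] = {}
    for u, v in edges:  # treat edges as undirected
        neighbours.setdefault(u, set()).add(v)
        neighbours.setdefault(v, set()).add(u)
    expanded = []
    for part in parts:
        grown = set(part)
        for n in part:
            grown |= neighbours.get(n, set())
        expanded.append(grown)
    return expanded


path = [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5)]
parts = naive_partition(6, 2)            # [{0, 1, 2}, {3, 4, 5}]
patches = expand_one_hop(parts, path)    # [{0, 1, 2, 3}, {2, 3, 4, 5}]
```

After expansion, nodes 2 and 3 belong to both patches, which is exactly the situation where the Patch Encoder's `subgraphs_nodes_mapper` averaging matters.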
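Pitfall 1: installing metis
After installing CMake and METIS, `pip install metis` (the Python wrapper around METIS) kept failing, and I couldn't figure out why.
So I downloaded the tar.gz source package from the official site, which is:
The tar.gz is not a .whl wheel but a source distribution installed through setup.py; the installation steps are as follows: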