A Generalization of ViTMLP-Mixer to Graphs源码解读2——model.py

model.py  定义文章提出的模型

def forward

下面这一段对应patch embedding

torch.scatter用法参考torch.scatter-CSDN博客

        # Patch Encoder
        x = x[data.subgraphs_nodes_mapper]
        e = edge_attr[data.subgraphs_edges_mapper]
        edge_index = data.combined_subgraphs
        batch_x = data.subgraphs_batch
        for i, gnn in enumerate(self.gnns):
            if i > 0:
                subgraph = scatter(x, batch_x, dim=0,
                                   reduce=self.pooling)[batch_x]
                x = x + self.U[i-1](subgraph)#U是MLP
                #torch.scatter(input, dim, index, src)
                x = scatter(x, data.subgraphs_nodes_mapper,
                            dim=0, reduce='mean')[data.subgraphs_nodes_mapper]
            x = gnn(x, edge_index, e)
        subgraph_x = scatter(x, batch_x, dim=0, reduce=self.pooling)

transformer.py   预处理把graph分为patch

cfg = set_cfg(CN())

train_helper.py

在这里定义了run函数

run函数的其中一个参数是cfg,

cfg  用.yaml文件设置参数

cfg定义在config.py

from yacs.config import CfgNode as CN
cfg = set_cfg(CN())

run函数的__main__中调用yacs中的merge_from_file来读取config

if __name__ == '__main__':
    # get config
    cfg.merge_from_file('train/configs/GraphMLPMixer/cifar10.yaml')
    cfg = update_cfg(cfg)
    run(cfg, create_dataset, create_model, train, test)

run的参数还包括train函数,

run(cfg, create_dataset, create_model, train, test)

def train(train_loader, model, optimizer, evaluator, device, sharp):

run函数调用torch.Dataloader制作数据集train_loader变量

Dataloader的输入参数train_dataset由get_data.py的create_dataset函数得到

    train_dataset, val_dataset, test_dataset = create_dataset(cfg)

    train_loader = DataLoader(#create_dataset制作train_dataset
        train_dataset, cfg.train.batch_size, shuffle=True, num_workers=cfg.num_workers)

Dataloader的参数包括train_dataset,train_dataset由get_data.py中的create_dataset函数制作

cifar10.py

这是数据集cifar10的训练文件

使用cifar10数据集,直接运行这个文件就行

# from ogb.utils import smiles2graph
from ogb.utils.mol import smiles2graph

transform_util文件夹 subgraph_extractors.py

把原本的graph分割成多个patch

踩到的坑1 安装metis

在安装好Cmake和metis后,不知道为什么,pip install metis一直不成功(这个相当于python中的metis wrapper)

于是从官网下载tar.gz文件,官网如下

metis · PyPI

tar.gz文件不是.whl格式,是setup.py,安装教程如下

在Anaconda的环境中安装.tar.gz包-CSDN博客

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值