GNN: From Getting Started to Giving Up, Task 06

Introduction to the Dataset Base Class

By subclassing the torch_geometric.data.Dataset base class, we can define a dataset class that loads samples into memory on demand.

import os.path as osp

import torch
from torch_geometric.data import Dataset, Data, download_url

class MyOwnDataset(Dataset):
    def __init__(self, root, transform=None, pre_transform=None):
        super(MyOwnDataset, self).__init__(root, transform, pre_transform)

    @property
    def raw_file_names(self):
        return ['some_file_1', 'some_file_2', ...]

    @property
    def processed_file_names(self):
        return ['data_1.pt', 'data_2.pt', ...]

    def download(self):
        # Download to `self.raw_dir`.
        path = download_url(url, self.raw_dir)  # `url` is a placeholder for the raw data source
        ...

    def process(self):
        i = 0
        for raw_path in self.raw_paths:
            # Read data from `raw_path`.
            data = Data(...)

            if self.pre_filter is not None and not self.pre_filter(data):
                continue

            if self.pre_transform is not None:
                data = self.pre_transform(data)

            torch.save(data, osp.join(self.processed_dir, 'data_{}.pt'.format(i)))
            i += 1

    def len(self):
        return len(self.processed_file_names)

    def get(self, idx):
        data = torch.load(osp.join(self.processed_dir, 'data_{}.pt'.format(idx)))
        return data
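
A minimal usage sketch, assuming the placeholders above (the URL, the raw/processed file names, and the Data construction) have been filled in for a real dataset:

dataset = MyOwnDataset(root='data/my_dataset')  # triggers download() / process() on first use
print(len(dataset))   # calls len()
print(dataset[0])     # calls get(0), loading a single sample from disk on demand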

Batching Graph Samples and the DataLoader Class

Merging Small Graphs into One Large Graph

Graphs can have arbitrary numbers of nodes and edges, so batching graph data works differently from batching images or sequences. The approach PyTorch Geometric takes is to merge the small graphs, as connected components, into one large graph; the adjacency matrices of the small graphs then sit on the diagonal of the large graph's adjacency matrix.
The advantages of this approach are (a small sketch follows the list):

  • GNN operators that rely on message passing need no modification, because messages still cannot be exchanged between two nodes that belong to different graphs.
  • There is no extra computation or memory overhead. For example, this batching procedure works without any padding of node or edge features. Note that the adjacency matrices also incur no extra memory, because they are stored sparsely, keeping only the non-zero entries, i.e. the edges.
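
A minimal sketch of this behavior with two tiny made-up graphs: after batching, the edge_index of each graph is shifted so the merged adjacency matrix is block-diagonal, and the batch vector records which graph each node came from.

import torch
from torch_geometric.data import Data, DataLoader

# Two toy graphs with 3 and 2 nodes respectively.
data1 = Data(x=torch.randn(3, 4), edge_index=torch.tensor([[0, 1, 2], [1, 2, 0]]))
data2 = Data(x=torch.randn(2, 4), edge_index=torch.tensor([[0, 1], [1, 0]]))

loader = DataLoader([data1, data2], batch_size=2)
batch = next(iter(loader))

print(batch.num_graphs)   # 2
print(batch.x.shape)      # torch.Size([5, 4]): node features of both graphs stacked
print(batch.edge_index)   # data2's edges are shifted by 3, the node count of data1
print(batch.batch)        # tensor([0, 0, 0, 1, 1]): which graph each node belongs to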

Incrementing and Concatenating Attributes of Small Graphs

  • When storing small graphs inside the large graph, some of their attributes need to be adjusted; most importantly, node indices must be incremented. In the most general form, PyTorch Geometric's DataLoader automatically increments the edge_index tensor by the cumulative number of nodes of all graphs that precede the current one in the batch.
  • Graph pairing: sometimes we want to store several graphs in a single Data object, e.g. for applications such as graph matching, and all of them must be batched correctly. For example, we may store a source graph G_s and a target graph G_t in one Data object. This is where the DataLoader's follow_batch argument comes in: it lets us specify the attributes for which batch assignment information should be maintained (see the sketch after this list).
  • The adjacency matrix of a bipartite graph defines connections between two types of nodes. To batch bipartite graphs correctly, we must tell PyTorch Geometric to increment the source-node and target-node entries of edge_index independently.
  • Concatenation along a new dimension: sometimes attributes of a Data object need to be concatenated along a new dimension (as in classical batching), e.g. graph-level attributes or prediction targets. Concretely, a list of attributes of shape [num_features] should be batched into a tensor of shape [num_examples, num_features] rather than [num_examples * num_features]. PyTorch Geometric achieves this by returning a concatenation dimension of None from __cat_dim__().
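
A hedged sketch of the graph-pairing case: PairData is a hypothetical class name, and x_s / edge_index_s belong to the source graph while x_t / edge_index_t belong to the target graph. Overriding __inc__() makes the two edge_index tensors increment independently, and follow_batch asks the loader to keep batch vectors for the chosen attributes.

import torch
from torch_geometric.data import Data, DataLoader

class PairData(Data):
    def __inc__(self, key, value, *args, **kwargs):
        # Shift source edges by the source node count and target edges by the
        # target node count, so the two graphs are incremented independently.
        if key == 'edge_index_s':
            return self.x_s.size(0)
        if key == 'edge_index_t':
            return self.x_t.size(0)
        return super().__inc__(key, value, *args, **kwargs)

pair = PairData(x_s=torch.randn(5, 16), edge_index_s=torch.tensor([[0, 1], [1, 2]]),
                x_t=torch.randn(4, 16), edge_index_t=torch.tensor([[0, 1], [1, 3]]))

# follow_batch tells the loader which attributes need their own batch vectors.
loader = DataLoader([pair, pair], batch_size=2, follow_batch=['x_s', 'x_t'])
batch = next(iter(loader))
print(batch.edge_index_s)  # the second copy's source edges are shifted by 5
print(batch.x_s_batch)     # graph id of every source node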

Assignment

import os
import torch
import argparse
from tqdm import tqdm
from ogb.lsc import PCQM4MEvaluator
from torch_geometric.data import DataLoader
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR

from pcqm4m_data import MyPCQM4MDataset
from gin_graph import GINGraphPooling

from torch.utils.tensorboard import SummaryWriter

def parse_args():

    parser = argparse.ArgumentParser(description='Graph data mining with GNN')
    parser.add_argument('--task_name', type=str, default='GINGraphPooling',
                        help='task name')
    parser.add_argument('--device', type=int, default=0,
                        help='which gpu to use if any (default: 0)')
    parser.add_argument('--num_layers', type=int, default=5,
                        help='number of GNN message passing layers (default: 5)')
    parser.add_argument('--graph_pooling', type=str, default='sum',
                        help='graph pooling strategy mean or sum (default: sum)')
    parser.add_argument('--emb_dim', type=int, default=256,
                        help='dimensionality of hidden units in GNNs (default: 256)')
    parser.add_argument('--drop_ratio', type=float, default=0.,
                        help='dropout ratio (default: 0.)')
    parser.add_argument('--save_test', action='store_true')
    parser.add_argument('--batch_size', type=int, default=512,
                        help='input batch size for training (default: 512)')
    parser.add_argument('--epochs', type=int, default=100,
                        help='number of epochs to train (default: 100)')
    parser.add_argument('--weight_decay', type=float, default=0.00001,
                        help='weight decay')
    parser.add_argument('--early_stop', type=int, default=10,
                        help='early stop (default: 10)')
    parser.add_argument('--num_workers', type=int, default=4,
                        help='number of workers (default: 4)')
    parser.add_argument('--dataset_root', type=str, default="dataset",
                        help='dataset root')
    args = parser.parse_args()

    return args


def preparation(args):
    save_dir = os.path.join('saves', args.task_name)
    if os.path.exists(save_dir):
        for idx in range(1000):
            if not os.path.exists(save_dir + '=' + str(idx)):
                save_dir = save_dir + '=' + str(idx)
                break

    args.save_dir = save_dir
    os.makedirs(args.save_dir, exist_ok=True)
    args.device = torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu")
    args.output_file = open(os.path.join(args.save_dir, 'output'), 'a')
    print(args, file=args.output_file, flush=True)


def train(model, device, loader, optimizer, criterion_fn):
    model.train()
    loss_accum = 0

    for step, batch in enumerate(tqdm(loader)):
        batch = batch.to(device)
        pred = model(batch).view(-1,)
        optimizer.zero_grad()
        loss = criterion_fn(pred, batch.y)
        loss.backward()
        optimizer.step()
        loss_accum += loss.detach().cpu().item()

    return loss_accum / (step + 1)


def eval(model, device, loader, evaluator):
    model.eval()
    y_true = []
    y_pred = []

    with torch.no_grad():
        for _, batch in enumerate(tqdm(loader)):
            batch = batch.to(device)
            pred = model(batch).view(-1,)
            y_true.append(batch.y.view(pred.shape).detach().cpu())
            y_pred.append(pred.detach().cpu())

    y_true = torch.cat(y_true, dim=0)
    y_pred = torch.cat(y_pred, dim=0)
    input_dict = {"y_true": y_true, "y_pred": y_pred}
    return evaluator.eval(input_dict)["mae"]


def test(model, device, loader):
    model.eval()
    y_pred = []

    with torch.no_grad():
        for _, batch in enumerate(loader):
            batch = batch.to(device)
            pred = model(batch).view(-1,)
            y_pred.append(pred.detach().cpu())

    y_pred = torch.cat(y_pred, dim=0)
    return y_pred


def main(args):
    preparation(args)
    nn_params = {
        'num_layers': args.num_layers,
        'emb_dim': args.emb_dim,
        'drop_ratio': args.drop_ratio,
        'graph_pooling': args.graph_pooling
    }

    # automatic dataloading and splitting
    dataset = MyPCQM4MDataset(root=args.dataset_root)
    split_idx = dataset.get_idx_split()
    train_data = dataset[split_idx['train']]
    valid_data = dataset[split_idx['valid']]
    test_data = dataset[split_idx['test']]
    train_loader = DataLoader(train_data, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
    valid_loader = DataLoader(valid_data, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)
    test_loader = DataLoader(test_data, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)

    # automatic evaluator. takes dataset name as input
    evaluator = PCQM4MEvaluator()
    criterion_fn = torch.nn.MSELoss()

    device = args.device

    model = GINGraphPooling(**nn_params).to(device)

    num_params = sum(p.numel() for p in model.parameters())
    print(f'#Params: {num_params}', file=args.output_file, flush=True)
    print(model, file=args.output_file, flush=True)

    optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=args.weight_decay)
    scheduler = StepLR(optimizer, step_size=30, gamma=0.25)

    writer = SummaryWriter(log_dir=args.save_dir)
    not_improved = 0
    best_valid_mae = 9999
    for epoch in range(1, args.epochs + 1):
        print("=====Epoch {}".format(epoch), file=args.output_file, flush=True)
        print('Training...', file=args.output_file, flush=True)
        train_mae = train(model, device, train_loader, optimizer, criterion_fn)

        print('Evaluating...', file=args.output_file, flush=True)
        valid_mae = eval(model, device, valid_loader, evaluator)

        print({'Train': train_mae, 'Validation': valid_mae}, file=args.output_file, flush=True)

        writer.add_scalar('valid/mae', valid_mae, epoch)
        writer.add_scalar('train/mae', train_mae, epoch)

        if valid_mae < best_valid_mae:
            best_valid_mae = valid_mae
            if args.save_test:
                print('Saving checkpoint...', file=args.output_file, flush=True)
                checkpoint = {
                    'epoch': epoch, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(), 'best_val_mae': best_valid_mae, 'num_params': num_params
                }
                torch.save(checkpoint, os.path.join(args.save_dir, 'checkpoint.pt'))
                print('Predicting on test data...', file=args.output_file, flush=True)
                y_pred = test(model, device, test_loader)
                print('Saving test submission file...', file=args.output_file, flush=True)
                evaluator.save_test_submission({'y_pred': y_pred}, args.save_dir)

            not_improved = 0
        else:
            not_improved += 1
            if not_improved == args.early_stop:
                print(f"Have not improved for {not_improved} epoches.", file=args.output_file, flush=True)
                break

        scheduler.step()
        print(f'Best validation MAE so far: {best_valid_mae}', file=args.output_file, flush=True)

    writer.close()
    args.output_file.close()


if __name__ == "__main__":
    args = parse_args()
    main(args)
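
A hypothetical invocation, assuming the script above is saved as main.py (the flags are those defined in parse_args):

python main.py --task_name GINGraphPooling --device 0 --batch_size 512 --epochs 100 --save_test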

