Accelerating minibatch training in DGL by prefetching subgraph data to the GPU

-----2022.04 update------
DGL now ships an official prefetch feature, so you can use the official implementation directly: https://github.com/dmlc/dgl/pull/3665

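Introduction: Over the past couple of days I have been looking into how to speed up GNN training in DGL. Profiling with line_profiler showed that, apart from the forward and backward passes, the most expensive step is the data transfer between CPU and GPU (i.e., in minibatch training, copying the current batch's subgraph and its features and labels to the GPU). So I tried prefetching: while the GPU computes on the current batch, the data is copied from CPU to GPU in the background. The example used in this post is https://github.com/dmlc/dgl/tree/master/examples/pytorch/ogb_lsc/MAG240M; in this scenario, on my setup, training became roughly 15% faster.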
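A rough way to see what prefetching can buy is a simplified pipeline model. This is an assumption, not a measurement: one batch of lookahead, sampling already hidden by the DataLoader workers, and no contention between the copy and the GPU kernels. With N batches, per-batch host-to-device copy time t_copy and per-batch GPU compute time t_comp:

```latex
\begin{aligned}
T_{\text{serial}}   &= N\,(t_{\text{copy}} + t_{\text{comp}}) \\
T_{\text{prefetch}} &\approx t_{\text{copy}} + (N-1)\,\max(t_{\text{copy}}, t_{\text{comp}}) + t_{\text{comp}} \\
T_{\text{serial}} - T_{\text{prefetch}} &\approx (N-1)\,\min(t_{\text{copy}}, t_{\text{comp}})
\end{aligned}
```

When the copy is shorter than forward plus backward, as it is in the profiles of this example, the model suggests the copy can in principle be almost entirely hidden behind the computation.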

  First, let's look at where the time goes in train.py (note that, to match the prefetch version later on, pin_memory is set to True here; and to avoid OOM, the batch size is set to 512):

Wrote profile results to train.py.lprof
Timer unit: 1e-06 s

Total time: 2000.99 s
File: train.py
Function: train at line 128

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   128                                           @profile
   129                                           def train(args, dataset, g, feats, paper_offset):
   130         1         23.0     23.0      0.0      print('Loading masks and labels')
   131         1       2024.0   2024.0      0.0      train_idx = torch.LongTensor(dataset.get_idx_split('train')) + paper_offset
   132         1       2506.0   2506.0      0.0      valid_idx = torch.LongTensor(dataset.get_idx_split('valid')) + paper_offset
   133         1      28142.0  28142.0      0.0      label = dataset.paper_label
   134
   135         1         25.0     25.0      0.0      print('Initializing dataloader...')
   136         1         44.0     44.0      0.0      sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 25])
   137         1        887.0    887.0      0.0      train_collator = ExternalNodeCollator(g, train_idx, sampler, paper_offset, feats, label)
   138         1         70.0     70.0      0.0      valid_collator = ExternalNodeCollator(g, valid_idx, sampler, paper_offset, feats, label)
   139         1         11.0     11.0      0.0      train_dataloader = torch.utils.data.DataLoader(
   140         1          3.0      3.0      0.0          train_collator.dataset,
   141         1          9.0      9.0      0.0          batch_size=args.batch_size,
   142         1          1.0      1.0      0.0          shuffle=True,
   143         1          1.0      1.0      0.0          drop_last=False,
   144         1          1.0      1.0      0.0          collate_fn=train_collator.collate,
   145         1          1.0      1.0      0.0          num_workers=4,
   146         1        311.0    311.0      0.0          pin_memory=True
   147                                               )
   148         1          3.0      3.0      0.0      valid_dataloader = torch.utils.data.DataLoader(
   149         1          2.0      2.0      0.0          valid_collator.dataset,
   150         1          1.0      1.0      0.0          batch_size=args.batch_size,
   151         1          1.0      1.0      0.0          shuffle=True,
   152         1          2.0      2.0      0.0          drop_last=False,
   153         1          2.0      2.0      0.0          collate_fn=valid_collator.collate,
   154         1          1.0      1.0      0.0          num_workers=2,
   155         1         76.0     76.0      0.0          pin_memory=True
   156                                               )
   157
   158         1         23.0     23.0      0.0      print('Initializing model...')
   159         1   11117937.0 11117937.0      0.6      model = RGAT(dataset.num_paper_features, dataset.num_classes, 1024, 5, 2, 4, 0.5, 'paper').cuda()
   160         1       1948.0   1948.0      0.0      opt = torch.optim.Adam(model.parameters(), lr=0.001)
   161         1        209.0    209.0      0.0      sched = torch.optim.lr_scheduler.StepLR(opt, step_size=25, gamma=0.25)
   162
   163         1          2.0      2.0      0.0      best_acc = 0
   164
   165         2         14.0      7.0      0.0      for _ in range(args.epochs):
   166         1        624.0    624.0      0.0          model.train()
   167         1       3195.0   3195.0      0.0          with tqdm.tqdm(train_dataloader) as tq:
   168         1        104.0    104.0      0.0              torch.cuda.synchronize()
   169         1          3.0      3.0      0.0              t0 = time.perf_counter()
   170      2174  176648528.0  81255.1      8.8              for i, (input_nodes, output_nodes, mfgs) in enumerate(tq):
   171      2173    5748739.0   2645.5      0.3                  mfgs = [g.to('cuda') for g in mfgs]
   172
   173                                                           # t = mfgs[0].srcdata['x'][100]
   174                                                           # tt = mfgs[-1].dstdata['y'][5]
   175      2173  389817972.0 179391.6     19.5                  x = mfgs[0].srcdata['x']  # apart from GPU forward/backward, this line takes the most time
   176      2173     389554.0    179.3      0.0                  y = mfgs[-1].dstdata['y']
   177      2173  542893145.0 249835.8     27.1                  y_hat = model(mfgs, x)
   178      2173    9109582.0   4192.2      0.5                  loss = F.cross_entropy(y_hat, y)
   179      2173    4741064.0   2181.8      0.2                  opt.zero_grad()
   180      2173  435564957.0 200444.1     21.8                  loss.backward()
   181      2173   16753289.0   7709.8      0.8                  opt.step()
   182      2173     288915.0    133.0      0.0                  acc = (y_hat.argmax(1) == y).float().mean()
   183      2173  197078022.0  90694.0      9.8                  tq.set_postfix({'loss': '%.4f' % loss.item(), 'acc': '%.4f' % acc.item()}, refresh=False)

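  As the profile shows, apart from forward and backward, the most expensive line is x = mfgs[0].srcdata['x']. Reading feature x is not itself slow: DGL fetches features lazily, so the data is actually copied from CPU to GPU only on first access. You can verify this by accessing mfgs[0].srcdata['x'][100] just before it; once that line is added, it becomes the slow one instead.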
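A toy model of that lazy behavior shows why the cost lands on whichever line reads the data first. LazyColumn is hypothetical, not a DGL class, and DGL's real frame implementation differs; the list copy merely stands in for the host-to-device transfer.

```python
class LazyColumn:
    """Hypothetical toy (not DGL's implementation): a feature column whose
    device copy is deferred until the data is first read."""

    def __init__(self, data, device):
        self._host_data = data
        self._device = device        # recorded only; nothing is moved here
        self._device_data = None
        self.copies = 0              # how many times the "expensive" copy ran

    def _copy_to_device(self):
        # Stand-in for the real CPU -> GPU transfer.
        self.copies += 1
        return list(self._host_data)

    def __getitem__(self, idx):
        # The first access pays for the whole copy; later accesses reuse it.
        if self._device_data is None:
            self._device_data = self._copy_to_device()
        return self._device_data[idx]


col = LazyColumn([0.1, 0.2, 0.3], device="cuda")  # cheap, like g.to('cuda')
print(col.copies)  # 0
print(col[2])      # 0.3 (this read triggers the copy)
print(col.copies)  # 1
_ = col[0]
print(col.copies)  # still 1
```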
  Next I tried prefetching. Because DGL graphs do not offer this functionality, I settled for the next best thing and prefetch only the two Tensors: the features and the labels.
References: 1. https://zhuanlan.zhihu.com/p/66145913
2. https://zhuanlan.zhihu.com/p/72956595
How it works: https://github.com/NVIDIA/apex/issues/304#issuecomment-493562789
  I also found that DGL's dgl.dataloading.AsyncTransferer provides asynchronous transfer (for Tensors only), which makes this fairly easy to implement.
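The one-batch-lookahead pattern shared by the references above and the data_prefetcher class in the full code below can be sketched without a GPU. This is a hypothetical stdlib-only illustration, not DGL's API: the `transfer` callable stands in for a non-blocking host-to-device copy (the role AsyncTransferer.async_copy plays), and a Future's result() stands in for the handle's wait().

```python
import time
from concurrent.futures import ThreadPoolExecutor


class LookaheadPrefetcher:
    """Hypothetical sketch: start "copying" batch i+1 while batch i is consumed."""

    def __init__(self, loader, transfer):
        self._it = iter(loader)
        self._transfer = transfer            # stand-in for an async H2D copy
        self._pool = ThreadPoolExecutor(max_workers=1)
        self._future = None
        self._preload()

    def _preload(self):
        try:
            batch = next(self._it)           # already prepared by the loader
        except StopIteration:
            self._future = None
            self._pool.shutdown(wait=False)  # nothing left to copy
            return
        # Kick off the copy of the next batch without waiting for it.
        self._future = self._pool.submit(self._transfer, batch)

    def __iter__(self):
        return self

    def __next__(self):
        if self._future is None:
            raise StopIteration
        batch = self._future.result()  # block until this batch's copy is done
        self._preload()                # immediately start copying the next one
        return batch


def fake_transfer(batch):
    time.sleep(0.01)  # pretend this is the PCIe copy
    return [v * 10 for v in batch]


for b in LookaheadPrefetcher([[1, 2], [3], [4, 5]], fake_transfer):
    print(b)  # [10, 20], then [30], then [40, 50]
```

As in data_prefetcher, the loader itself is advanced synchronously; only the copy runs in the background, so the consumer's work on batch i overlaps with the transfer of batch i+1.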

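  With prefetching in place, the results are shown below: total time dropped from about 2000 s to about 1650 s, and the savings come mainly from the time previously spent on x = mfgs[0].srcdata['x']. A reminder: this does not work if pin_memory is set to False, because the asynchronous copy relies on the source tensors being in pinned memory.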

Wrote profile results to train_prefetch.py.lprof
Timer unit: 1e-06 s

Total time: 1654.36 s
File: train_prefetch.py
Function: train at line 172

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   172                                           @profile
   173                                           def train(args, dataset, g, feats, paper_offset):
   202         1         22.0     22.0      0.0      print('Initializing model...')
   203         1    6740234.0 6740234.0      0.4      model = RGAT(dataset.num_paper_features, dataset.num_classes, 1024, 5, 2, 4, 0.5, 'paper').cuda()
   204         1       1917.0   1917.0      0.0      opt = torch.optim.Adam(model.parameters(), lr=0.001)
   205         1        208.0    208.0      0.0      sched = torch.optim.lr_scheduler.StepLR(opt, step_size=25, gamma=0.25)
   206
   207         1          2.0      2.0      0.0      best_acc = 0
   208
   209
   210         2          9.0      4.5      0.0      for _ in range(args.epochs):
   211         1   10178442.0 10178442.0      0.6          train_prefetcher = data_prefetcher(train_dataloader, dev_id=0)
   212         1    6792441.0 6792441.0      0.4          valid_prefetcher = data_prefetcher(valid_dataloader, dev_id=0)
   213         1       1083.0   1083.0      0.0          model.train()
   214         1       1976.0   1976.0      0.0          with tqdm.tqdm(train_prefetcher) as tq:
   215         1      54357.0  54357.0      0.0              torch.cuda.synchronize()
   216         1          5.0      5.0      0.0              t0 = time.perf_counter()
   217      2174  145970461.0  67143.7      8.8              for i, (input_nodes, output_nodes, mfgs, input_x, target_y) in enumerate(tq):
   218      2173    8052025.0   3705.5      0.5                  mfgs = [g.to('cuda') for g in mfgs]
   219
   220      2173  562793527.0 258993.8     34.0                  y_hat = model(mfgs, input_x)
   221      2173   10833457.0   4985.5      0.7                  loss = F.cross_entropy(y_hat, target_y)
   222      2173    4943455.0   2274.9      0.3                  opt.zero_grad()
   223      2173  461658294.0 212452.0     27.9                  loss.backward()
   224      2173   20994316.0   9661.4      1.3                  opt.step()
   225      2173     300364.0    138.2      0.0                  acc = (y_hat.argmax(1) == target_y).float().mean()
   226      2173  211057970.0  97127.5     12.8                  tq.set_postfix({'loss': '%.4f' % loss.item(), 'acc': '%.4f' % acc.item()}, refresh=False)
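A quick check of the two reports, using only numbers copied from the profiler output above:

```python
# Numbers copied from the two line_profiler reports above (seconds).
total_before = 2000.99               # "Total time" of train.py
total_after = 1654.36                # "Total time" of train_prefetch.py
srcdata_line = 389817972.0 / 1e6     # x = mfgs[0].srcdata['x'] in train.py

saved = total_before - total_after
print(f"saved {saved:.2f} s")                                      # saved 346.63 s
print(f"time reduction {saved / total_before:.1%}")                # time reduction 17.3%
print(f"throughput gain {total_before / total_after - 1:.1%}")     # throughput gain 21.0%
print(f"lazy-copy time recovered {saved / srcdata_line:.1%}")      # lazy-copy time recovered 88.9%
```

The saving is somewhat smaller than the 389.8 s attributed to the srcdata line in the first report, which is consistent with forward (542.9 s to 562.8 s) and backward (435.6 s to 461.7 s) both taking a little longer in the prefetch run.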

  The complete code:


#!/usr/bin/env python
# coding: utf-8
import ogb
from ogb.lsc import MAG240MDataset, MAG240MEvaluator
import dgl
import torch
import numpy as np
import time
import tqdm
import dgl.function as fn
import dgl.nn as dglnn
import torch.nn as nn
import torch.nn.functional as F
import argparse


class data_prefetcher:
    def __init__(self, loader, dev_id):
        self.loader = iter(loader)
        self.dev_id = dev_id
        self.transfer = dgl.dataloading.AsyncTransferer(dev_id)
        self.preload()

    def __iter__(self):
        return self

    def preload(self):
        try:
            self.input_nodes, self.output_nodes, self.mfgs, self.input_x, self.target_y = next(
                self.loader)
        except StopIteration:
            self.input_nodes = None
            self.output_nodes = None
            self.mfgs = None
            self.input_x_future = None
            self.target_y_future = None
            return
        self.input_x_future = self.transfer.async_copy(self.input_x, self.dev_id)
        self.target_y_future = self.transfer.async_copy(self.target_y, self.dev_id)

    def __next__(self):
        input_nodes = self.input_nodes
        output_nodes = self.output_nodes
        mfgs = self.mfgs
        input_x_future = self.input_x_future
        target_y_future = self.target_y_future
        if input_x_future is not None:
            input_x = input_x_future.wait()
        else:
            raise StopIteration()
        if target_y_future is not None:
            target_y = target_y_future.wait()
        else:
            raise StopIteration()
        self.preload()
        return input_nodes, output_nodes, mfgs, input_x, target_y


class RGAT(nn.Module):
    def __init__(self, in_channels, out_channels, hidden_channels, num_etypes, num_layers, num_heads, dropout,
                 pred_ntype):
        super().__init__()
        self.convs = nn.ModuleList()
        self.norms = nn.ModuleList()
        self.skips = nn.ModuleList()

        self.convs.append(nn.ModuleList([
            dglnn.GATConv(in_channels, hidden_channels // num_heads, num_heads, allow_zero_in_degree=True)
            for _ in range(num_etypes)
        ]))
        self.norms.append(nn.BatchNorm1d(hidden_channels))
        self.skips.append(nn.Linear(in_channels, hidden_channels))
        for _ in range(num_layers - 1):
            self.convs.append(nn.ModuleList([
                dglnn.GATConv(hidden_channels, hidden_channels // num_heads, num_heads, allow_zero_in_degree=True)
                for _ in range(num_etypes)
            ]))
            self.norms.append(nn.BatchNorm1d(hidden_channels))
            self.skips.append(nn.Linear(hidden_channels, hidden_channels))

        self.mlp = nn.Sequential(
            nn.Linear(hidden_channels, hidden_channels),
            nn.BatchNorm1d(hidden_channels),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_channels, out_channels)
        )
        self.dropout = nn.Dropout(dropout)

        self.hidden_channels = hidden_channels
        self.pred_ntype = pred_ntype
        self.num_etypes = num_etypes

    def forward(self, mfgs, x):
        for i in range(len(mfgs)):
            mfg = mfgs[i]
            x_dst = x[:mfg.num_dst_nodes()]
            n_src = mfg.num_src_nodes()
            n_dst = mfg.num_dst_nodes()
            mfg = dgl.block_to_graph(mfg)
            x_skip = self.skips[i](x_dst)
            for j in range(self.num_etypes):
                subg = mfg.edge_subgraph(mfg.edata['etype'] == j, preserve_nodes=True)
                x_skip += self.convs[i][j](subg, (x, x_dst)).view(-1, self.hidden_channels)
            x = self.norms[i](x_skip)
            x = F.elu(x)
            x = self.dropout(x)
        return self.mlp(x)


class ExternalNodeCollator(dgl.dataloading.NodeCollator):
    def __init__(self, g, idx, sampler, offset, feats, label):
        super().__init__(g, idx, sampler)
        self.offset = offset
        self.feats = feats
        self.label = label

    def collate(self, items):
        input_nodes, output_nodes, mfgs = super().collate(items)
        # Copy input features
        # mfgs[0].srcdata['x'] = torch.FloatTensor(self.feats[input_nodes])
        # mfgs[-1].dstdata['y'] = torch.LongTensor(self.label[output_nodes - self.offset])
        input_x = torch.FloatTensor(self.feats[input_nodes])
        target_y = torch.LongTensor(self.label[output_nodes - self.offset])
        return input_nodes, output_nodes, mfgs, input_x, target_y


def print_memory_usage():
    import os
    import psutil
    process = psutil.Process(os.getpid())
    print("memory usage is {} GB".format(process.memory_info()[0] / 1024 / 1024 / 1024))


# @profile
def train(args, dataset, g, feats, paper_offset):
    print('Loading masks and labels')
    train_idx = torch.LongTensor(dataset.get_idx_split('train')) + paper_offset
    valid_idx = torch.LongTensor(dataset.get_idx_split('valid')) + paper_offset
    label = dataset.paper_label

    print('Initializing dataloader...')
    sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 25])
    train_collator = ExternalNodeCollator(g, train_idx, sampler, paper_offset, feats, label)
    valid_collator = ExternalNodeCollator(g, valid_idx, sampler, paper_offset, feats, label)
    train_dataloader = torch.utils.data.DataLoader(
        train_collator.dataset,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=False,
        collate_fn=train_collator.collate,
        num_workers=4,
        pin_memory=True  # must be True: the async copy needs pinned source tensors
    )
    valid_dataloader = torch.utils.data.DataLoader(
        valid_collator.dataset,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=False,
        collate_fn=valid_collator.collate,
        num_workers=2,
        pin_memory=True
    )

    print('Initializing model...')
    model = RGAT(dataset.num_paper_features, dataset.num_classes, 1024, 5, 2, 4, 0.5, 'paper').cuda()
    opt = torch.optim.Adam(model.parameters(), lr=0.001)
    sched = torch.optim.lr_scheduler.StepLR(opt, step_size=25, gamma=0.25)

    best_acc = 0


    for _ in range(args.epochs):
        # The dataloaders must be wrapped in fresh prefetchers every epoch
        train_prefetcher = data_prefetcher(train_dataloader, dev_id=0)
        valid_prefetcher = data_prefetcher(valid_dataloader, dev_id=0)
        model.train()
        with tqdm.tqdm(train_prefetcher) as tq:
            torch.cuda.synchronize()
            t0 = time.perf_counter()
            for i, (input_nodes, output_nodes, mfgs, input_x, target_y) in enumerate(tq):
                mfgs = [g.to('cuda') for g in mfgs]

                y_hat = model(mfgs, input_x)
                loss = F.cross_entropy(y_hat, target_y)
                opt.zero_grad()
                loss.backward()
                opt.step()
                acc = (y_hat.argmax(1) == target_y).float().mean()
                tq.set_postfix({'loss': '%.4f' % loss.item(), 'acc': '%.4f' % acc.item()}, refresh=False)

        model.eval()
        correct = total = 0
        for i, (input_nodes, output_nodes, mfgs, input_x, target_y) in enumerate(tqdm.tqdm(valid_prefetcher)):
            with torch.no_grad():
                mfgs = [g.to('cuda') for g in mfgs]
                x = input_x
                y = target_y
                y_hat = model(mfgs, x)
                correct += (y_hat.argmax(1) == y).sum().item()
                total += y_hat.shape[0]
        acc = correct / total
        print('Validation accuracy:', acc)

        sched.step()

        if best_acc < acc:
            best_acc = acc
            print('Updating best model...')
            torch.save(model.state_dict(), args.model_path)


def test(args, dataset, g, feats, paper_offset):
    print('Loading masks and labels...')
    valid_idx = torch.LongTensor(dataset.get_idx_split('valid')) + paper_offset
    test_idx = torch.LongTensor(dataset.get_idx_split('test')) + paper_offset
    label = dataset.paper_label

    print('Initializing data loader...')
    sampler = dgl.dataloading.MultiLayerNeighborSampler([160, 160])
    valid_collator = ExternalNodeCollator(g, valid_idx, sampler, paper_offset, feats, label)
    valid_dataloader = torch.utils.data.DataLoader(
        valid_collator.dataset,
        batch_size=16,
        shuffle=False,
        drop_last=False,
        collate_fn=valid_collator.collate,
        num_workers=2
    )
    test_collator = ExternalNodeCollator(g, test_idx, sampler, paper_offset, feats, label)
    test_dataloader = torch.utils.data.DataLoader(
        test_collator.dataset,
        batch_size=16,
        shuffle=False,
        drop_last=False,
        collate_fn=test_collator.collate,
        num_workers=4
    )

    print('Loading model...')
    model = RGAT(dataset.num_paper_features, dataset.num_classes, 1024, 5, 2, 4, 0.5, 'paper').cuda()
    model.load_state_dict(torch.load(args.model_path))

    model.eval()
    correct = total = 0
    for i, (input_nodes, output_nodes, mfgs, input_x, target_y) in enumerate(tqdm.tqdm(valid_dataloader)):
        with torch.no_grad():
            mfgs = [g.to('cuda') for g in mfgs]
            # The collator now returns features/labels separately instead of setting srcdata/dstdata
            x = input_x.cuda()
            y = target_y.cuda()
            y_hat = model(mfgs, x)
            correct += (y_hat.argmax(1) == y).sum().item()
            total += y_hat.shape[0]
    acc = correct / total
    print('Validation accuracy:', acc)
    evaluator = MAG240MEvaluator()
    y_preds = []
    for i, (input_nodes, output_nodes, mfgs, input_x, target_y) in enumerate(tqdm.tqdm(test_dataloader)):
        with torch.no_grad():
            mfgs = [g.to('cuda') for g in mfgs]
            x = input_x.cuda()  # test labels are not needed for prediction
            y_hat = model(mfgs, x)
            y_preds.append(y_hat.argmax(1).cpu())
    evaluator.save_test_submission({'y_pred': torch.cat(y_preds)}, args.submission_path)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--rootdir', type=str, default='.', help='Directory to download the OGB dataset.')
    parser.add_argument('--graph-path', type=str, default='./graph.dgl', help='Path to the graph.')
    parser.add_argument('--full-feature-path', type=str, default='./full.npy',
                        help='Path to the features of all nodes.')
    parser.add_argument('--epochs', type=int, default=1, help='Number of epochs.')
    parser.add_argument('--batch-size', type=int, default=512)
    parser.add_argument('--model-path', type=str, default='./model.pt', help='Path to store the best model.')
    parser.add_argument('--submission-path', type=str, default='./results', help='Submission directory.')
    args = parser.parse_args()

    dataset = MAG240MDataset(root=args.rootdir)

    print('Loading graph')
    (g,), _ = dgl.load_graphs(args.graph_path)
    g = g.formats(['csc'])

    print('Loading features')
    paper_offset = dataset.num_authors + dataset.num_institutions
    num_nodes = paper_offset + dataset.num_papers
    num_features = dataset.num_paper_features
    feats = np.memmap(args.full_feature_path, mode='r', dtype='float16', shape=(num_nodes, num_features))

    if args.epochs != 0:
        train(args, dataset, g, feats, paper_offset)
    # test(args, dataset, g, feats, paper_offset)
