-----2022.04 update------
DGL has recently added official prefetch support, so you can now use the built-in implementation directly: https://github.com/dmlc/dgl/pull/3665
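For reference, a hedged sketch of what the official version roughly looks like in DGL >= 0.8 (kwarg names as I understand the reworked API, and it assumes the features/labels are stored on the graph as ndata['x']/ndata['y']; the names g and train_idx follow this post's example — consult the PR and the DGL docs for the authoritative form):

# hedged sketch: DGL >= 0.8 built-in feature prefetching, not the code profiled below
sampler = dgl.dataloading.NeighborSampler(
    [15, 25],
    prefetch_node_feats=['x'],   # copy input features to the GPU asynchronously
    prefetch_labels=['y'],       # copy labels asynchronously
)
dataloader = dgl.dataloading.DataLoader(
    g, train_idx, sampler,
    device='cuda', batch_size=512, shuffle=True,
    drop_last=False, num_workers=4,
)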
TL;DR: While looking into how to speed up GNN training in DGL, I found with the line_profiler tool that, apart from forward and backward, the biggest cost is data transfer between CPU and GPU (i.e., copying the current mini-batch's subgraphs and their features and labels to the GPU during mini-batch training). So I tried prefetching: while the GPU computes on the current batch, the next batch's data is transferred from CPU to GPU in parallel. The example used throughout is https://github.com/dmlc/dgl/tree/master/examples/pytorch/ogb_lsc/MAG240M; in my environment this gives roughly a 15% speedup.
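For context, this is roughly how the profiles in this post were produced with line_profiler (the @profile decorator is injected as a builtin by kernprof; only the file name train.py is from this example):

# mark the function to be profiled; `profile` is injected by kernprof at runtime
@profile
def train(args, dataset, g, feats, paper_offset):
    ...

# then, from a shell:
#   kernprof -l train.py
#   python -m line_profiler train.py.lprof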
First, let's look at the main costs in train.py (note: pin_memory is set to True here to match the prefetch version later, and the batch size is set to 512 to avoid OOM):
Wrote profile results to train.py.lprof
Timer unit: 1e-06 s
Total time: 2000.99 s
File: train.py
Function: train at line 128
Line # Hits Time Per Hit % Time Line Contents
==============================================================
128 @profile
129 def train(args, dataset, g, feats, paper_offset):
130 1 23.0 23.0 0.0 print('Loading masks and labels')
131 1 2024.0 2024.0 0.0 train_idx = torch.LongTensor(dataset.get_idx_split('train')) + paper_offset
132 1 2506.0 2506.0 0.0 valid_idx = torch.LongTensor(dataset.get_idx_split('valid')) + paper_offset
133 1 28142.0 28142.0 0.0 label = dataset.paper_label
134
135 1 25.0 25.0 0.0 print('Initializing dataloader...')
136 1 44.0 44.0 0.0 sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 25])
137 1 887.0 887.0 0.0 train_collator = ExternalNodeCollator(g, train_idx, sampler, paper_offset, feats, label)
138 1 70.0 70.0 0.0 valid_collator = ExternalNodeCollator(g, valid_idx, sampler, paper_offset, feats, label)
139 1 11.0 11.0 0.0 train_dataloader = torch.utils.data.DataLoader(
140 1 3.0 3.0 0.0 train_collator.dataset,
141 1 9.0 9.0 0.0 batch_size=args.batch_size,
142 1 1.0 1.0 0.0 shuffle=True,
143 1 1.0 1.0 0.0 drop_last=False,
144 1 1.0 1.0 0.0 collate_fn=train_collator.collate,
145 1 1.0 1.0 0.0 num_workers=4,
146 1 311.0 311.0 0.0 pin_memory=True
147 )
148 1 3.0 3.0 0.0 valid_dataloader = torch.utils.data.DataLoader(
149 1 2.0 2.0 0.0 valid_collator.dataset,
150 1 1.0 1.0 0.0 batch_size=args.batch_size,
151 1 1.0 1.0 0.0 shuffle=True,
152 1 2.0 2.0 0.0 drop_last=False,
153 1 2.0 2.0 0.0 collate_fn=valid_collator.collate,
154 1 1.0 1.0 0.0 num_workers=2,
155 1 76.0 76.0 0.0 pin_memory=True
156 )
157
158 1 23.0 23.0 0.0 print('Initializing model...')
159 1 11117937.0 11117937.0 0.6 model = RGAT(dataset.num_paper_features, dataset.num_classes, 1024, 5, 2, 4, 0.5, 'paper').cuda()
160 1 1948.0 1948.0 0.0 opt = torch.optim.Adam(model.parameters(), lr=0.001)
161 1 209.0 209.0 0.0 sched = torch.optim.lr_scheduler.StepLR(opt, step_size=25, gamma=0.25)
162
163 1 2.0 2.0 0.0 best_acc = 0
164
165 2 14.0 7.0 0.0 for _ in range(args.epochs):
166 1 624.0 624.0 0.0 model.train()
167 1 3195.0 3195.0 0.0 with tqdm.tqdm(train_dataloader) as tq:
168 1 104.0 104.0 0.0 torch.cuda.synchronize()
169 1 3.0 3.0 0.0 t0 = time.perf_counter()
170 2174 176648528.0 81255.1 8.8 for i, (input_nodes, output_nodes, mfgs) in enumerate(tq):
171 2173 5748739.0 2645.5 0.3 mfgs = [g.to('cuda') for g in mfgs]
172
173 # t = mfgs[0].srcdata['x'][100]
174 # tt = mfgs[-1].dstdata['y'][5]
175 2173 389817972.0 179391.6 19.5 x = mfgs[0].srcdata['x'] # apart from the GPU forward/backward passes, this is the most expensive line
176 2173 389554.0 179.3 0.0 y = mfgs[-1].dstdata['y']
177 2173 542893145.0 249835.8 27.1 y_hat = model(mfgs, x)
178 2173 9109582.0 4192.2 0.5 loss = F.cross_entropy(y_hat, y)
179 2173 4741064.0 2181.8 0.2 opt.zero_grad()
180 2173 435564957.0 200444.1 21.8 loss.backward()
181 2173 16753289.0 7709.8 0.8 opt.step()
182 2173 288915.0 133.0 0.0 acc = (y_hat.argmax(1) == y).float().mean()
183 2173 197078022.0 90694.0 9.8 tq.set_postfix({'loss': '%.4f' % loss.item(), 'acc': '%.4f' % acc.item()}, refresh=False)
As the profile shows, apart from forward and backward, the most expensive line is x = mfgs[0].srcdata['x']. It is not the feature read itself that is slow: feature access in DGL is lazy, and the data is only actually copied from CPU to GPU on first access. You can verify this by accessing mfgs[0].srcdata['x'][100] beforehand (the commented-out line in the profile above); once that line is added, it becomes the one that takes a long time instead.
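A minimal sketch of how to verify this inside the training loop (the timing pattern with torch.cuda.synchronize() is my addition, not part of the original example):

torch.cuda.synchronize()
t0 = time.perf_counter()
t = mfgs[0].srcdata['x'][100]   # first access: triggers the real CPU-to-GPU copy
torch.cuda.synchronize()
print('first access: %.4f s' % (time.perf_counter() - t0))

t0 = time.perf_counter()
x = mfgs[0].srcdata['x']        # subsequent access: data is already on the GPU
torch.cuda.synchronize()
print('full read:    %.4f s' % (time.perf_counter() - t0))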
Next, let's try prefetching. Since DGL graphs provide no such facility, we settle for the next best thing and prefetch only the two plain tensors, the features and the labels. For background, see:
1. https://zhuanlan.zhihu.com/p/66145913
2. https://zhuanlan.zhihu.com/p/72956595
The underlying mechanism is explained at https://github.com/NVIDIA/apex/issues/304#issuecomment-493562789.
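The gist of those references: copy the next batch to the GPU on a separate CUDA stream while the default stream computes on the current batch. A generic PyTorch-only sketch of the idea (the class name CudaPrefetcher is mine and it is independent of DGL; it assumes the loader yields tuples of CPU tensors that are pinned, i.e. pin_memory=True):

import torch

class CudaPrefetcher:
    """Overlap host-to-device copies with compute: while the GPU works on
    batch i, batch i+1 is copied on a separate CUDA stream."""
    def __init__(self, loader, device='cuda'):
        self.loader = iter(loader)
        self.device = device
        self.stream = torch.cuda.Stream()   # side stream for the copies
        self.next_batch = None
        self.preload()

    def preload(self):
        try:
            batch = next(self.loader)
        except StopIteration:
            self.next_batch = None
            return
        with torch.cuda.stream(self.stream):
            # non_blocking=True only overlaps if the CPU tensors are pinned
            self.next_batch = [t.to(self.device, non_blocking=True) for t in batch]

    def __iter__(self):
        return self

    def __next__(self):
        if self.next_batch is None:
            raise StopIteration
        # make the default stream wait for the pending copies
        torch.cuda.current_stream().wait_stream(self.stream)
        batch = self.next_batch
        for t in batch:
            t.record_stream(torch.cuda.current_stream())  # keep the allocator safe
        self.preload()
        return batch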
I also found that dgl.dataloading.AsyncTransferer offers asynchronous transfer (for tensors only), which makes the implementation fairly straightforward.
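In isolation, the API looks like this (a minimal sketch based on the DGL version used in this post, mirroring how the full code below calls it; AsyncTransferer has since been superseded by the built-in prefetching mentioned at the top, and the source tensor must be in pinned memory):

transfer = dgl.dataloading.AsyncTransferer(0)   # target GPU 0
x_cpu = torch.randn(1000, 128).pin_memory()     # must be pinned
future = transfer.async_copy(x_cpu, 0)          # returns immediately
# ... do other work while the copy proceeds ...
x_gpu = future.wait()                           # block until the copy finishes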
With prefetching in place, the results are as follows: total time drops from roughly 2000 s to roughly 1650 s, and the savings come almost entirely from the x = mfgs[0].srcdata['x'] line. As a reminder, this does not work if pin_memory is set to False.
Wrote profile results to train_prefetch.py.lprof
Timer unit: 1e-06 s
Total time: 1654.36 s
File: train_prefetch.py
Function: train at line 172
Line # Hits Time Per Hit % Time Line Contents
==============================================================
172 @profile
173 def train(args, dataset, g, feats, paper_offset):
202 1 22.0 22.0 0.0 print('Initializing model...')
203 1 6740234.0 6740234.0 0.4 model = RGAT(dataset.num_paper_features, dataset.num_classes, 1024, 5, 2, 4, 0.5, 'paper').cuda()
204 1 1917.0 1917.0 0.0 opt = torch.optim.Adam(model.parameters(), lr=0.001)
205 1 208.0 208.0 0.0 sched = torch.optim.lr_scheduler.StepLR(opt, step_size=25, gamma=0.25)
206
207 1 2.0 2.0 0.0 best_acc = 0
208
209
210 2 9.0 4.5 0.0 for _ in range(args.epochs):
211 1 10178442.0 10178442.0 0.6 train_prefetcher = data_prefetcher(train_dataloader, dev_id=0)
212 1 6792441.0 6792441.0 0.4 valid_prefetcher = data_prefetcher(valid_dataloader, dev_id=0)
213 1 1083.0 1083.0 0.0 model.train()
214 1 1976.0 1976.0 0.0 with tqdm.tqdm(train_prefetcher) as tq:
215 1 54357.0 54357.0 0.0 torch.cuda.synchronize()
216 1 5.0 5.0 0.0 t0 = time.perf_counter()
217 2174 145970461.0 67143.7 8.8 for i, (input_nodes, output_nodes, mfgs, input_x, target_y) in enumerate(tq):
218 2173 8052025.0 3705.5 0.5 mfgs = [g.to('cuda') for g in mfgs]
219
220 2173 562793527.0 258993.8 34.0 y_hat = model(mfgs, input_x)
221 2173 10833457.0 4985.5 0.7 loss = F.cross_entropy(y_hat, target_y)
222 2173 4943455.0 2274.9 0.3 opt.zero_grad()
223 2173 461658294.0 212452.0 27.9 loss.backward()
224 2173 20994316.0 9661.4 1.3 opt.step()
225 2173 300364.0 138.2 0.0 acc = (y_hat.argmax(1) == target_y).float().mean()
226 2173 211057970.0 97127.5 12.8 tq.set_postfix({'loss': '%.4f' % loss.item(), 'acc': '%.4f' % acc.item()}, refresh=False)
The full code:
#!/usr/bin/env python
# coding: utf-8
import argparse
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm

import dgl
import dgl.function as fn
import dgl.nn as dglnn
from ogb.lsc import MAG240MDataset, MAG240MEvaluator
class data_prefetcher:
    """Wraps a DataLoader so that the feature/label tensors of the *next*
    batch are copied to the GPU asynchronously while the current batch is
    being computed. The underlying DataLoader must use pin_memory=True."""
    def __init__(self, loader, dev_id):
        self.loader = iter(loader)
        self.dev_id = dev_id
        self.transfer = dgl.dataloading.AsyncTransferer(dev_id)
        self.preload()

    def __iter__(self):
        return self

    def preload(self):
        # Fetch the next batch and start its host-to-device copies without blocking.
        try:
            self.input_nodes, self.output_nodes, self.mfgs, self.input_x, self.target_y = next(
                self.loader)
        except StopIteration:
            self.input_nodes = None
            self.output_nodes = None
            self.mfgs = None
            self.input_x_future = None
            self.target_y_future = None
            return
        self.input_x_future = self.transfer.async_copy(self.input_x, self.dev_id)
        self.target_y_future = self.transfer.async_copy(self.target_y, self.dev_id)

    def __next__(self):
        input_nodes = self.input_nodes
        output_nodes = self.output_nodes
        mfgs = self.mfgs
        input_x_future = self.input_x_future
        target_y_future = self.target_y_future
        if input_x_future is not None:
            input_x = input_x_future.wait()   # block until the async copy is done
        else:
            raise StopIteration()
        if target_y_future is not None:
            target_y = target_y_future.wait()
        else:
            raise StopIteration()
        self.preload()                        # immediately start on the next batch
        return input_nodes, output_nodes, mfgs, input_x, target_y
class RGAT(nn.Module):
def __init__(self, in_channels, out_channels, hidden_channels, num_etypes, num_layers, num_heads, dropout,
pred_ntype):
super().__init__()
self.convs = nn.ModuleList()
self.norms = nn.ModuleList()
self.skips = nn.ModuleList()
self.convs.append(nn.ModuleList([
dglnn.GATConv(in_channels, hidden_channels // num_heads, num_heads, allow_zero_in_degree=True)
for _ in range(num_etypes)
]))
self.norms.append(nn.BatchNorm1d(hidden_channels))
self.skips.append(nn.Linear(in_channels, hidden_channels))
for _ in range(num_layers - 1):
self.convs.append(nn.ModuleList([
dglnn.GATConv(hidden_channels, hidden_channels // num_heads, num_heads, allow_zero_in_degree=True)
for _ in range(num_etypes)
]))
self.norms.append(nn.BatchNorm1d(hidden_channels))
self.skips.append(nn.Linear(hidden_channels, hidden_channels))
self.mlp = nn.Sequential(
nn.Linear(hidden_channels, hidden_channels),
nn.BatchNorm1d(hidden_channels),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(hidden_channels, out_channels)
)
self.dropout = nn.Dropout(dropout)
self.hidden_channels = hidden_channels
self.pred_ntype = pred_ntype
self.num_etypes = num_etypes
    def forward(self, mfgs, x):
        for i in range(len(mfgs)):
            mfg = mfgs[i]
            x_dst = x[:mfg.num_dst_nodes()]
            mfg = dgl.block_to_graph(mfg)
            # Linear skip connection, plus one GATConv per edge type applied to
            # the subgraph containing only that type's edges.
            x_skip = self.skips[i](x_dst)
            for j in range(self.num_etypes):
                subg = mfg.edge_subgraph(mfg.edata['etype'] == j, preserve_nodes=True)
                x_skip += self.convs[i][j](subg, (x, x_dst)).view(-1, self.hidden_channels)
            x = self.norms[i](x_skip)
            x = F.elu(x)
            x = self.dropout(x)
        return self.mlp(x)
class ExternalNodeCollator(dgl.dataloading.NodeCollator):
def __init__(self, g, idx, sampler, offset, feats, label):
super().__init__(g, idx, sampler)
self.offset = offset
self.feats = feats
self.label = label
    def collate(self, items):
        input_nodes, output_nodes, mfgs = super().collate(items)
        # The original example attached features/labels to the MFGs:
        # mfgs[0].srcdata['x'] = torch.FloatTensor(self.feats[input_nodes])
        # mfgs[-1].dstdata['y'] = torch.LongTensor(self.label[output_nodes - self.offset])
        # Instead, return them as separate tensors so the DataLoader can pin
        # them and the prefetcher can copy them asynchronously.
        input_x = torch.FloatTensor(self.feats[input_nodes])
        target_y = torch.LongTensor(self.label[output_nodes - self.offset])
        return input_nodes, output_nodes, mfgs, input_x, target_y
def print_memory_usage():
import os
import psutil
process = psutil.Process(os.getpid())
print("memory usage is {} GB".format(process.memory_info()[0] / 1024 / 1024 / 1024))
# @profile
def train(args, dataset, g, feats, paper_offset):
print('Loading masks and labels')
train_idx = torch.LongTensor(dataset.get_idx_split('train')) + paper_offset
valid_idx = torch.LongTensor(dataset.get_idx_split('valid')) + paper_offset
label = dataset.paper_label
print('Initializing dataloader...')
sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 25])
train_collator = ExternalNodeCollator(g, train_idx, sampler, paper_offset, feats, label)
valid_collator = ExternalNodeCollator(g, valid_idx, sampler, paper_offset, feats, label)
train_dataloader = torch.utils.data.DataLoader(
train_collator.dataset,
batch_size=args.batch_size,
shuffle=True,
drop_last=False,
collate_fn=train_collator.collate,
num_workers=4,
        pin_memory=True  # must be set to True, otherwise the async prefetch cannot work
)
valid_dataloader = torch.utils.data.DataLoader(
valid_collator.dataset,
batch_size=args.batch_size,
shuffle=True,
drop_last=False,
collate_fn=valid_collator.collate,
num_workers=2,
pin_memory=True
)
print('Initializing model...')
model = RGAT(dataset.num_paper_features, dataset.num_classes, 1024, 5, 2, 4, 0.5, 'paper').cuda()
opt = torch.optim.Adam(model.parameters(), lr=0.001)
sched = torch.optim.lr_scheduler.StepLR(opt, step_size=25, gamma=0.25)
best_acc = 0
for _ in range(args.epochs):
        # the dataloader must be wrapped in a fresh prefetcher at the start of every epoch
train_prefetcher = data_prefetcher(train_dataloader, dev_id=0)
valid_prefetcher = data_prefetcher(valid_dataloader, dev_id=0)
model.train()
with tqdm.tqdm(train_prefetcher) as tq:
torch.cuda.synchronize()
t0 = time.perf_counter()
for i, (input_nodes, output_nodes, mfgs, input_x, target_y) in enumerate(tq):
mfgs = [g.to('cuda') for g in mfgs]
y_hat = model(mfgs, input_x)
loss = F.cross_entropy(y_hat, target_y)
opt.zero_grad()
loss.backward()
opt.step()
acc = (y_hat.argmax(1) == target_y).float().mean()
tq.set_postfix({'loss': '%.4f' % loss.item(), 'acc': '%.4f' % acc.item()}, refresh=False)
model.eval()
correct = total = 0
for i, (input_nodes, output_nodes, mfgs, input_x, target_y) in enumerate(tqdm.tqdm(valid_prefetcher)):
with torch.no_grad():
mfgs = [g.to('cuda') for g in mfgs]
x = input_x
y = target_y
y_hat = model(mfgs, x)
correct += (y_hat.argmax(1) == y).sum().item()
total += y_hat.shape[0]
acc = correct / total
print('Validation accuracy:', acc)
sched.step()
if best_acc < acc:
best_acc = acc
print('Updating best model...')
torch.save(model.state_dict(), args.model_path)
def test(args, dataset, g, feats, paper_offset):
print('Loading masks and labels...')
valid_idx = torch.LongTensor(dataset.get_idx_split('valid')) + paper_offset
test_idx = torch.LongTensor(dataset.get_idx_split('test')) + paper_offset
label = dataset.paper_label
print('Initializing data loader...')
sampler = dgl.dataloading.MultiLayerNeighborSampler([160, 160])
valid_collator = ExternalNodeCollator(g, valid_idx, sampler, paper_offset, feats, label)
valid_dataloader = torch.utils.data.DataLoader(
valid_collator.dataset,
batch_size=16,
shuffle=False,
drop_last=False,
collate_fn=valid_collator.collate,
num_workers=2
)
test_collator = ExternalNodeCollator(g, test_idx, sampler, paper_offset, feats, label)
test_dataloader = torch.utils.data.DataLoader(
test_collator.dataset,
batch_size=16,
shuffle=False,
drop_last=False,
collate_fn=test_collator.collate,
num_workers=4
)
print('Loading model...')
model = RGAT(dataset.num_paper_features, dataset.num_classes, 1024, 5, 2, 4, 0.5, 'paper').cuda()
model.load_state_dict(torch.load(args.model_path))
model.eval()
correct = total = 0
    for i, (input_nodes, output_nodes, mfgs, input_x, target_y) in enumerate(tqdm.tqdm(valid_dataloader)):
        with torch.no_grad():
            mfgs = [g.to('cuda') for g in mfgs]
            # collate() now returns features/labels separately, so move them manually
            x = input_x.cuda()
            y = target_y.cuda()
            y_hat = model(mfgs, x)
            correct += (y_hat.argmax(1) == y).sum().item()
            total += y_hat.shape[0]
acc = correct / total
print('Validation accuracy:', acc)
evaluator = MAG240MEvaluator()
y_preds = []
    for i, (input_nodes, output_nodes, mfgs, input_x, target_y) in enumerate(tqdm.tqdm(test_dataloader)):
        with torch.no_grad():
            mfgs = [g.to('cuda') for g in mfgs]
            x = input_x.cuda()   # test labels are hidden, so only the features are used
            y_hat = model(mfgs, x)
            y_preds.append(y_hat.argmax(1).cpu())
evaluator.save_test_submission({'y_pred': torch.cat(y_preds)}, args.submission_path)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--rootdir', type=str, default='.', help='Directory to download the OGB dataset.')
parser.add_argument('--graph-path', type=str, default='./graph.dgl', help='Path to the graph.')
parser.add_argument('--full-feature-path', type=str, default='./full.npy',
help='Path to the features of all nodes.')
parser.add_argument('--epochs', type=int, default=1, help='Number of epochs.')
parser.add_argument('--batch-size', type=int, default=512)
parser.add_argument('--model-path', type=str, default='./model.pt', help='Path to store the best model.')
parser.add_argument('--submission-path', type=str, default='./results', help='Submission directory.')
args = parser.parse_args()
dataset = MAG240MDataset(root=args.rootdir)
print('Loading graph')
(g,), _ = dgl.load_graphs(args.graph_path)
g = g.formats(['csc'])
print('Loading features')
paper_offset = dataset.num_authors + dataset.num_institutions
num_nodes = paper_offset + dataset.num_papers
num_features = dataset.num_paper_features
feats = np.memmap(args.full_feature_path, mode='r', dtype='float16', shape=(num_nodes, num_features))
if args.epochs != 0:
train(args, dataset, g, feats, paper_offset)
# test(args, dataset, g, feats, paper_offset)