利用 DataPrefetcher 加速
device = torch.device("cuda")


class DataPrefetcher:
    """Prefetch DataLoader batches onto the GPU using a side CUDA stream.

    While the training loop consumes the current batch, the next batch's
    host-to-device copy is already in flight on ``self.stream``, hiding
    transfer latency behind compute.

    NOTE(review): for the async copy to truly overlap, the DataLoader should
    produce pinned-memory tensors (``pin_memory=True``) — confirm at the
    call site.
    """

    def __init__(self, loader):
        # Drive the loader manually so we can stay one batch ahead.
        self.loader = iter(loader)
        # Dedicated stream: the H2D copies issued on it can overlap with
        # kernels running on the default stream.
        self.stream = torch.cuda.Stream()
        self.preload()

    def preload(self):
        """Fetch the next batch and launch its async copy to the GPU.

        Sets ``self.batch`` to ``None`` when the loader is exhausted.
        """
        try:
            self.batch = next(self.loader)
        except StopIteration:
            # End of epoch: signal exhaustion to next().
            self.batch = None
            return
        with torch.cuda.stream(self.stream):
            for k in self.batch:
                # 'meta' presumably holds non-tensor bookkeeping; every
                # other value is assumed to be a tensor — TODO confirm
                # against the DataLoader's collate function.
                if k != 'meta':
                    self.batch[k] = self.batch[k].to(
                        device=device, non_blocking=True
                    )

    def next(self):
        """Return the already-transferred batch, or ``None`` when exhausted."""
        # Block the default stream until the prefetch copy has finished,
        # so the returned tensors are safe to use.
        torch.cuda.current_stream().wait_stream(self.stream)
        batch = self.batch
        self.preload()  # immediately start copying the following batch
        return batch
详情参考:https://tianws.github.io/skill/2019/08/27/gpu-volatile/
1310

被折叠的 条评论
为什么被折叠?



