PyTorch 和 TF 在 TPU 训练的处理上有一个明显的不同:PyTorch 缺少 steps_per_execution 这个参数。简单来说,TF 可以一次喂给 TPU 多个 step 的数据,而 PyTorch XLA 不可以。前两天我刚提了这个 bug,得到回复后知道该怎么修了;结果官方有人在 18 个小时内就提交了修复。
Anyway,这段代码是核心:
from __future__ import division
from __future__ import print_function
from six import iteritems, itervalues
import threading
import torch
import torch_xla
import torch_xla.utils.keyd_queue as kq
import torch_xla.utils.utils as xu
import torch_xla.core.xla_model as xm
class PerDeviceQueue(object):
    """Per-device pair of bounded queues used by the parallel loader.

    Holds one queue fed by the host-side loader thread and one queue of
    items staged for the target device.
    """

    def __init__(self, device, loader_prefetch_size, device_prefetch_size):
        """Create the queue pair for a single device.

        Args:
            device: The device this queue pair belongs to.
            loader_prefetch_size: Max number of items buffered from the loader.
            device_prefetch_size: Max number of items staged for the device.
        """
        self.device = device
        # Bounded so the loader thread blocks instead of buffering unboundedly.
        self.loader_queue = kq.Queue(maxsize=loader_prefetch_size)
        self.queue = kq.Queue(maxsize=device_prefetch_size)
class PerDeviceLoader(object):
    """Iterator over a parallel loader's items for one device.

    Wraps the shared loader and yields the items destined for `device`,
    issuing `xm.mark_step()` once every `experimental_steps_per_execution`
    steps so the pending XLA graph is flushed to the device in batches
    (mimicking TF's `steps_per_execution`) instead of on every step.
    """

    def __init__(self, loader, device, experimental_steps_per_execution=1):
        """Args:
            loader: The parent loader; must provide `next_item(device)` and
                `per_device_samples()`.
            device: The device whose items this iterator yields.
            experimental_steps_per_execution: Number of steps between
                consecutive `xm.mark_step()` calls (default 1 = every step).
        """
        self._loader = loader
        self._device = device
        self._steps_count = 0
        self._experimental_steps_per_execution = experimental_steps_per_execution

    def __iter__(self):
        return self

    def __next__(self):
        return self.next()

    def __len__(self):
        # Delegates to the parent loader's per-device sample count.
        return self._loader.per_device_samples()

    def next(self):
        """Return the next item for this device.

        Raises:
            StopIteration: When the underlying loader is exhausted.
        """
        if self._steps_count % self._experimental_steps_per_execution == 0:
            # Flush the accumulated XLA graph once every N steps.
            xm.mark_step()
            # Reset so the counter stays small over long epochs.
            self._steps_count = 0
        # BUGFIX: the counter must advance on EVERY step. The original code
        # only incremented in the `else` branch, so after a mark_step the
        # counter stayed at 0 and `0 % N == 0` triggered mark_step again on
        # the very next call -- i.e. mark_step fired every step for any N > 1.
        self._steps_count += 1
        item = self._loader.next_item(self._device)
        if item is None:
            raise StopIteration
        return item