关于PyTorch XLA的问题以及为啥现在开源代码都要拼手速

PyTorch和TF在处理TPU训练上有一个明显的不同,那就是PyTorch缺少steps_per_execution这个参数。简单来说,TF可以一次喂给TPU一堆东西,而PyTorch XLA不可以。前两天,刚刚提了这个bug。得到回复知道怎么修了,结果人家官方有个人18个小时提出来了。

Anyway,这段代码是核心:

from __future__ import division
from __future__ import print_function

from six import iteritems, itervalues
import threading
import torch
import torch_xla
import torch_xla.utils.keyd_queue as kq
import torch_xla.utils.utils as xu
import torch_xla.core.xla_model as xm


class PerDeviceQueue(object):

  def __init__(self, device, loader_prefetch_size, device_prefetch_size):
    self.device = device
    self.loader_queue = kq.Queue(maxsize=loader_prefetch_size)
    self.queue = kq.Queue(maxsize=device_prefetch_size)


class PerDeviceLoader(object):

  def __init__(self, loader, device, experimental_steps_per_execution=1):
    self._loader = loader
    self._device = device
    self._steps_count = 0
    self._experiment_steps_per_execution = experimental_steps_per_execution

  def __iter__(self):
    return self

  def __next__(self):
    return self.next()

  def __len__(self):
    return self._loader.per_device_samples()

  def next(self):
    if self._steps_count % self._experiment_steps_per_execution == 0:
      xm.mark_step()
      self._steps_count = 0
    else:
      self._steps_count += 1
    item = self._loader.next_item(self._device)
    if item is None:
      raise StopIteration
   
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值