【Pytorch】RuntimeError: arguments are located on different GPUs

0x00 前言

Pytorch里使用optimizer的时候,由于其会记录step等信息,
有时会希望将optimizer的内容记录下来,以备之后继续使用,
那么自然而然的会想到使用API中自带的
torch.save(object, path)
torch.load(path)

再配合上
optimizer.state_dict()
optimizer.load_state_dict(obj)
来实现这一需求了~

于是,大家自然而然地会自信满满敲出如下这样的语句——

torch.save(optimizer.state_dict(), path)
optimizer.load_state_dict(torch.load(path))

并收获如下的Error——

RuntimeError                              Traceback (most recent call last)
<ipython-input-160-19f8d61b5e53> in <module>()
     37         optimizer.zero_grad()
     38         loss.backward()
---> 39         optimizer.step()
     40         print(model.state_dict()['linear_layer.weight'])
     41 

/usr/local/anaconda2/lib/python2.7/site-packages/torch/optim/adam.pyc in step(self, closure)
     63 
     64                 # Decay the first and second moment running average coefficient
---> 65                 exp_avg.mul_(beta1).add_(1 - beta1, grad)
     66                 exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
     67 

RuntimeError: arguments are located on different GPUs at /opt/conda/conda-bld/pytorch_1503966894950/work/torch/lib/THC/generated/../generic/THCTensorMathPointwise.cu:215

0x01 解决方案

二话不说上解决方案是我的习惯

# Load from dict
optimizer.load_state_dict(check_point['optim'])

# Load from file
optimizer.load_state_dict(torch.load(optim_path))

# Add this
for state in optimizer.state.values():
    for k, v in state.items():
        print (type(v))
        if torch.is_tensor(v):
            state[k] = v.cuda(cuda_id)

0x02 原理解释

然后在慢慢的讲为啥子~
首先,这个方案是我在Issue中翻看到的:
Thanks to pytorch/issues/2830

可以这么理解,举例说明,虽说你之前是放在GPU3上的,数据类型叫做 cuda.Tensor(GPU 3),
但是天晓得你这个GPU3是哪台机器上的GPU3哦,机器问了一下GPU3:是不是你家的啊,
GPU3看了一眼计算完被打扫干净的战场,已经空无一物——“不是吧,我家没人啊”,
然后就委婉的拒绝了它。

所以,我们可以对load完毕的optimizer逐个询问,只要是个tensor,我们就再把它介绍给GPU3一次~

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

糖果天王

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值