Sequential 类的设备迁移

之前因为RNN模块 没有export 方法,直接用了 cpickle 强行保存。现在要载入保存的数据,用于inference。需要解决训练时的context和 载入时 device不一致的问题。
找了下,发现ParameterDict里面有个 reset_ctx可以用:

import mxnet as mx                                                                                                                                           
import numpy as np
nn = mx.gluon.nn
net = nn.Sequential()
net.add(\
    ¦   nn.Dense(10))
ctx = mx.cpu()
_x = np.random.randint(0,256,(5,199))
x = mx.nd.array(_x)
net.initialize()


y= net(x)
print y
print('cross to gpu device...')
ctx = mx.gpu()
x = x.as_in_context( ctx )
try:
    y = net(x)
except:
    print 'forward failed, try reset_ctx for ParameterDict...'
    net.collect_params().reset_ctx( ctx )
y= net(x)
print y
print 'test ok'

转载于:https://www.cnblogs.com/chenyliang/p/9493448.html

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值