之前因为RNN模块 没有export 方法,直接用了 cpickle 强行保存。现在要载入保存的数据,用于inference。需要解决训练时的context和 载入时 device不一致的问题。
找了下,发现ParameterDict里面有个 reset_ctx可以用:
import mxnet as mx
import numpy as np
nn = mx.gluon.nn
net = nn.Sequential()
net.add(\
¦ nn.Dense(10))
ctx = mx.cpu()
_x = np.random.randint(0,256,(5,199))
x = mx.nd.array(_x)
net.initialize()
y= net(x)
print y
print('cross to gpu device...')
ctx = mx.gpu()
x = x.as_in_context( ctx )
try:
y = net(x)
except:
print 'forward failed, try reset_ctx for ParameterDict...'
net.collect_params().reset_ctx( ctx )
y= net(x)
print y
print 'test ok'