模型训练好之后,我们希望把它保存下来;以后需要对新数据进行预测时,直接加载已保存的模型即可使用。具体做法如下。
下面以线性回归为例进行说明:
1. 训练并保存模型。
import mxnet as mx
import numpy as np
import logging

# Log at DEBUG level so fit() progress and the checkpoint message are printed.
logging.getLogger().setLevel(logging.DEBUG)

# Training data: 100 random points in [0, 1)^2 with target y = x0 + 2 * x1,
# so a perfect fit has weights [1, 2] and bias 0.
train_data = np.random.uniform(0, 1, [100, 2])
train_label = train_data[:, 0] + 2 * train_data[:, 1]
batch_size = 1
# Number of passes over the training data; actually passed to fit() below.
num_epoch = 50
# Epoch number written into the checkpoint file name ("test-0005.params").
# It is only a tag; the loading script calls load_checkpoint('test', 5),
# so change both together.
checkpoint_epoch = 5

# Evaluation data; labels follow the same formula as the training labels.
eval_data = np.array([[7, 2], [6, 10], [12, 2]])
eval_label = np.array([11, 26, 16])

# Both iterators name their label 'lin_reg_label' so it matches the Module's label_names.
train_iter = mx.io.NDArrayIter(train_data, train_label, batch_size, shuffle=True,
                               label_name='lin_reg_label')
eval_iter = mx.io.NDArrayIter(eval_data, eval_label, batch_size, shuffle=False,
                              label_name='lin_reg_label')

# Network: one fully connected unit feeding a linear-regression output layer.
X = mx.sym.Variable('data')
Y = mx.sym.Variable('lin_reg_label')
fully_connected_layer = mx.sym.FullyConnected(data=X, name='fc1', num_hidden=1)
lro = mx.sym.LinearRegressionOutput(data=fully_connected_layer, label=Y, name="lro")

model = mx.mod.Module(
    symbol=lro,                     # network structure
    data_names=['data'],
    label_names=['lin_reg_label'],
)
model.fit(
    train_iter,
    eval_iter,
    optimizer_params={'learning_rate': 0.005, 'momentum': 0.9},
    num_epoch=num_epoch,
    eval_metric='mse',
)

# Print explicitly: in a plain script (unlike a notebook) bare expression results are discarded.
print(model.predict(eval_iter).asnumpy())
metric = mx.metric.MSE()
print(model.score(eval_iter, metric))

# Inspect the learned parameters (get_params() returns (arg_params, aux_params)).
arg_params, _ = model.get_params()
print(arg_params.keys())  # names of all weights
fc_weight = arg_params['fc1_weight']
fc_bias = arg_params['fc1_bias']
print(fc_weight.asnumpy())  # actual values
print(fc_bias.asnumpy())

# Save the model: 'test' is the file-name prefix; checkpoint_epoch becomes the
# numeric suffix of the .params file.
model.save_checkpoint('test', checkpoint_epoch)
运行结果为(节选,省略了训练日志等输出):
dict_keys(['fc1_weight', 'fc1_bias'])
[[ 0.99999714 1.99999332]]
INFO:root:Saved checkpoint to "test-0005.params"
保存下来的文件分别是:
test-symbol.json(网络结构)
test-0005.params(模型参数;文件名中的 0005 即传给 save_checkpoint 的 epoch 编号,格式为“前缀-四位编号.params”)
2. 加载模型并使用。
import mxnet as mx
import numpy as np

batch_size = 1
# Same inputs as the training script's evaluation set. The labels add 0.1 to each
# value, but they are not used here: the module is bound below without label shapes.
eval_data = np.array([[7, 2], [6, 10], [12, 2]])
eval_label = np.array([11.1, 26.1, 16.1])
eval_iter = mx.io.NDArrayIter(eval_data, eval_label, batch_size, shuffle=False)

# Load the model written by save_checkpoint('test', 5) in the training script.
sym, arg_params, aux_params = mx.model.load_checkpoint('test', 5)

# The training script ran on the default CPU context; switch to mx.gpu() only when a
# GPU-enabled MXNet build and a GPU are available (mx.gpu() fails otherwise).
ctx = mx.cpu()
mod = mx.mod.Module(symbol=sym, context=ctx,
                    data_names=['data'], label_names=['lin_reg_label'])
# Take the input shape from the iterator rather than hard-coding (1, 2), so changing
# batch_size does not break binding.
mod.bind(for_training=False, data_shapes=eval_iter.provide_data)
mod.set_params(arg_params, aux_params)

# Use the model: predict over every batch of eval_iter, one output row per input row.
predictions = mod.predict(eval_iter)
print(predictions.asnumpy())
运行结果为:
[[ 10.99997139]
[ 25.9999218 ]
[ 15.99995708]]