一、机器学习模型机器学习模型使用pickle来保存和加载
# 引用
import pickle
# 读取模型
with open('./model/xgb2.pkl', 'rb') as f:
xgb1 = pickle.load(f)
# 使用模型
xgb1.fit(...)
# 保存模型
with open('./model/xgb3.pkl', 'wb') as f:
pickle.dump(xgb1, f)
二、深度学习模型使用tensorflow的save()方法保存和加载模型
import tensorflow as tf
# 训练
...
# 创建Saver()节点
saver = tf.train.Saver()
with tf.Session() as sess:
# 恢复sess
ckpt = tf.train.get_checkpoint_state('./ckpt/')
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess, ckpt.model_checkpoint_path)
else:
sess.run(tf.global_variables_initializer())
...
save_path = saver.save(sess, "./ckpt/my_model_final.ckpt")