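When training a model, and especially when running cross-validation on the training set, you usually want to save the trained model and then evaluate it on an independent test set. This post shows how to save a trained model in Python and reuse it later.
Model persistence (the pickle and joblib modules)
You can use either Python's built-in pickle module or joblib (which older scikit-learn releases bundled as sklearn.externals.joblib).
1. Using the pickle module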
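import pickle
from sklearn import svm
from sklearn import datasets

# Train an SVM classifier on the iris dataset
clf = svm.SVC()
iris = datasets.load_iris()
X, y = iris.data, iris.target
clf.fit(X, y)

# pickle.dumps returns bytes, so the file must be opened in binary mode ('wb')
s = pickle.dumps(clf)
with open('svm.pkl', 'wb') as f:
    f.write(s)

# Read the bytes back (binary mode 'rb') and restore the model
with open('svm.pkl', 'rb') as f2:
    s2 = f2.read()
clf2 = pickle.loads(s2)
print(clf2.score(X, y))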
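pickle also provides pickle.dump and pickle.load, which write to and read from an open binary file directly, so the intermediate byte string is not needed. A minimal sketch, assuming the same iris/SVC setup; the file name svm_model.pkl is illustrative, and a temporary directory is used so the example cleans up after itself.

```python
import os
import pickle
import tempfile

from sklearn import datasets, svm

iris = datasets.load_iris()
X, y = iris.data, iris.target
clf = svm.SVC().fit(X, y)  # fit() returns the fitted estimator itself

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "svm_model.pkl")  # illustrative file name
    # pickle.dump serializes the model straight into the binary file
    with open(path, "wb") as f:
        pickle.dump(clf, f)
    # pickle.load reads it back into a new, independent estimator object
    with open(path, "rb") as f:
        clf2 = pickle.load(f)

# The restored model should make exactly the same predictions as the original
same = (clf2.predict(X) == clf.predict(X)).all()
print(same)
```

Only unpickle files from sources you trust: loading a pickle can execute arbitrary code.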
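2. Using joblib
joblib is more efficient than pickle for models that hold large NumPy arrays. Unlike pickle, it has no dumps/loads pair that works with byte strings; it is used with a file, typically one on disk.
import joblib  # sklearn.externals.joblib was removed in scikit-learn 0.23; use the standalone package
joblib.dump(clf, 'filename.pkl')
clf = joblib.load('filename.pkl')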
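A self-contained joblib round trip, assuming the standalone joblib package is installed (current scikit-learn depends on it). joblib.dump returns a list with the name of the file it wrote, and its optional compress argument (0 to 9) trades save/load speed for a smaller file; the level 3 and the file name here are arbitrary choices.

```python
import os
import tempfile

import joblib
import numpy as np
from sklearn import datasets, svm

iris = datasets.load_iris()
X, y = iris.data, iris.target
clf = svm.SVC().fit(X, y)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "filename.pkl")
    # dump returns a list containing the path of the file it created
    written = joblib.dump(clf, path, compress=3)
    clf_loaded = joblib.load(path)

print(written)
# The reloaded model should predict exactly like the original one
print(np.array_equal(clf_loaded.predict(X), clf.predict(X)))
```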
3. A complete example
scikit-learn models can be persisted with joblib; import it first (on current versions this is the standalone joblib package):
import joblib
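If the same script has to run on both current and very old scikit-learn installations, a guarded import is one common way to handle the move of joblib out of scikit-learn. This is a compatibility sketch, not something the original post uses.

```python
try:
    # Standalone package: the only option on scikit-learn >= 0.23
    import joblib
except ImportError:
    # Fallback for old scikit-learn releases that still bundled joblib
    from sklearn.externals import joblib

print(joblib.__version__)
```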
Saving a model
joblib.dump saves a model to a local file; in the code below, clf is the trained classifier.
import joblib
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC

def test_save_model(self):
    # self.is_exist and self.make_submission are helper methods of the
    # enclosing test class; they are not shown in the original post.
    model_save_path = "./model_save/"
    train_X = [[0, 0], [1, 1]]
    train_y = [0, 1]

    print("Start LR method ...")
    print("LR Train classifier...")
    clf = LogisticRegression()
    clf.fit(train_X, train_y)

    print("LR Model save...")
    save_path_name = model_save_path + "lr_" + "train_model.m"
    self.is_exist(model_save_path, save_path_name)
    joblib.dump(clf, save_path_name)
    clf = joblib.load(save_path_name)

    print("LR Predict...")
    pred = clf.predict_proba(train_X)
    submit_csv_name = model_save_path + "lr" + "_submission.csv"
    # First column of predict_proba: predicted probability of class 0
    self.make_submission(pred[:, 0], submit_csv_name)

    print("Start SVM method ...")
    # Train
    print("SVM Train classifier...")
    clf = SVC()
    clf.fit(train_X, train_y)
    # Save, then reload from disk
    print("SVM Model save...")
    save_path_name = model_save_path + "svm_" + "train_model.m"
    self.is_exist(model_save_path, save_path_name)
    joblib.dump(clf, save_path_name)
    clf = joblib.load(save_path_name)
    # Predict with the reloaded model
    print("SVM Predict...")
    pred = clf.predict(train_X)
    submit_csv_name = model_save_path + "svm" + "_submission.csv"
    self.make_submission(pred, submit_csv_name)

    # Score the reloaded SVM on a slightly different toy set
    test_X = [[0, 1], [1, 1]]
    test_y = [0, 1]
    print(clf.score(test_X, test_y, sample_weight=None))
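The method above relies on two helpers of the author's test class, is_exist and make_submission, whose code the post does not show. Here is a standalone sketch of the same train/save/load/predict flow under two assumptions: that is_exist only needs to make sure the output directory exists (replaced by os.makedirs(..., exist_ok=True)), and that make_submission writes predictions to a CSV file (replaced by a hypothetical write_predictions helper built on numpy.savetxt).

```python
import os
import tempfile

import joblib
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC


def write_predictions(values, csv_path):
    # Hypothetical stand-in for the post's make_submission helper
    np.savetxt(csv_path, np.asarray(values), delimiter=",")


def save_load_predict(model, name, model_dir, X, y):
    # Stand-in for is_exist: create the output directory if it is missing
    os.makedirs(model_dir, exist_ok=True)
    model_path = os.path.join(model_dir, name + "_train_model.m")
    model.fit(X, y)
    joblib.dump(model, model_path)       # save to disk
    restored = joblib.load(model_path)   # load it back
    pred = restored.predict(X)
    write_predictions(pred, os.path.join(model_dir, name + "_submission.csv"))
    return restored, pred


train_X = [[0, 0], [1, 1]]
train_y = [0, 1]
with tempfile.TemporaryDirectory() as tmp:
    model_dir = os.path.join(tmp, "model_save")
    _, lr_pred = save_load_predict(
        LogisticRegression(), "lr", model_dir, train_X, train_y)
    _, svm_pred = save_load_predict(SVC(), "svm", model_dir, train_X, train_y)
    files = sorted(os.listdir(model_dir))

print(files)
print(lr_pred, svm_pred)
```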
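Loading a model back from disk
Load the saved model with joblib's load function:
clf = joblib.load("train_model.m")
You can then evaluate it on the test set:
clf.predict(test_X)  # test_X is the feature matrix of the test set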
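Tying this back to the workflow described at the start (cross-validate on the training set, save the model, then test it on an independent test set), here is a minimal end-to-end sketch. The 70/30 split, random_state=0, 5-fold cross-validation and the temporary directory are illustrative assumptions, not the original author's setup.

```python
import os
import tempfile

import joblib
from sklearn import datasets, svm
from sklearn.model_selection import cross_val_score, train_test_split

iris = datasets.load_iris()
# Hold out a test set that is never used for training or cross-validation
train_X, test_X, train_y, test_y = train_test_split(
    iris.data, iris.target, test_size=0.3, random_state=0, stratify=iris.target
)

clf = svm.SVC()
# 5-fold cross-validation on the training set only (works on clones of clf)
cv_scores = cross_val_score(clf, train_X, train_y, cv=5)
print("CV accuracy: %.3f" % cv_scores.mean())

# Fit on the whole training set, then persist the model
clf.fit(train_X, train_y)
with tempfile.TemporaryDirectory() as tmp:
    model_path = os.path.join(tmp, "train_model.m")
    joblib.dump(clf, model_path)

    # Later, e.g. in a separate script: reload and evaluate on the test set
    loaded = joblib.load(model_path)
    pred = loaded.predict(test_X)
    test_acc = loaded.score(test_X, test_y)

print("Test accuracy: %.3f" % test_acc)
```

Fitting only after cross_val_score matters here: cross_val_score trains clones of the estimator, so clf itself is still unfitted until clf.fit is called.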
Reference: https://blog.csdn.net/Dream_angel_Z/article/details/47175373