sklearn模型的保存和加载API
import joblib  # 旧写法 from sklearn.externals import joblib 已在 scikit-learn 0.23 中移除
- 保存: joblib.dump(estimator, 'test.pkl')
- 加载: estimator = joblib.load('test.pkl')
- 注意:
- 1.保存文件,后缀名是 *.pkl
- 2.加载模型时需要通过一个变量进行承接
线性回归的模型保存加载案例
# 导入模块
from sklearn.datasets import load_boston
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LinearRegression, SGDRegressor, RidgeCV, Ridge
from sklearn.metrics import mean_squared_error
import joblib
# 取消警告
import warnings
warnings.filterwarnings('ignore')
def dump_load():
    """Train a ridge regression model, persist it with joblib, and evaluate it.

    Loads the Boston housing dataset, holds out 20% of it as a test split,
    standardizes the features, fits a ``RidgeCV`` estimator, saves it to
    ``./data/test.pkl``, reloads it from that file, and then prints the
    reloaded model's intercept, coefficients, test-set predictions and
    mean squared error.

    :return: None
    """
    import os

    model_path = "./data/test.pkl"

    # 1. Load the data.
    # NOTE(review): load_boston is not available in recent scikit-learn
    # releases (removed in 1.2) -- confirm the pinned version.
    boston = load_boston()

    # 2. Basic data handling.
    # 2.1 Split into train and test sets.
    x_train, x_test, y_train, y_test = train_test_split(
        boston.data, boston.target, random_state=22, test_size=0.2
    )

    # 3. Feature engineering: standardization.
    transfer = StandardScaler()
    x_train = transfer.fit_transform(x_train)
    # Reuse the statistics learned from the training split. Re-fitting on the
    # test split would scale the two sets differently and leak test data
    # into the evaluation.
    x_test = transfer.transform(x_test)

    # 4. Machine learning: ridge regression with cross-validated alpha.
    # 4.1 Train the model.
    estimator = RidgeCV(alphas=(0.001, 0.01, 0.1, 1, 10, 100))
    estimator.fit(x_train, y_train)

    # 4.2 Save the model; make sure the target directory exists first so
    # the dump does not fail on a fresh checkout.
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    joblib.dump(estimator, model_path)

    # 4.3 Load the model back; everything below uses the reloaded estimator.
    estimator = joblib.load(model_path)
    print("这个模型的偏置是:\n", estimator.intercept_)
    print("这个模型的系数是:\n", estimator.coef_)

    # 5. Model evaluation.
    # 5.1 Predictions on the test split.
    y_pre = estimator.predict(x_test)
    print("预测值是:\n", y_pre)

    # 5.2 Mean squared error.
    ret = mean_squared_error(y_test, y_pre)
    print("均方误差:\n", ret)
# Run the save/load demo.
dump_load()
这个模型的偏置是: 22.579702970297042 这个模型的系数是: [-0.63459715 0.94773129 -0.31920213 0.88529186 -1.74178644 2.79245589 -0.20323347 -3.01367932 1.95216601 -1.14280797 -1.60928268 0.90244836 -3.67372353] 预测值是: [27.70161604 30.69198547 20.83293806 31.37993352 19.06886905 18.46499899 20.76434753 18.2087364 18.48509569 31.9167732 20.57736653 26.88095143 15.13391718 19.38341927 35.95783246 18.2489176 8.30254898 17.59236763 29.48841298 23.38417563 18.41239891 32.99997871 28.11719645 17.2524964 33.93017182 25.76855298 34.07662551 26.12377425 18.98332919 13.70504553 29.75130577 14.08245618 36.66838496 9.42992218 15.39829098 16.38059949 8.13674968 19.23355789 38.86737442 27.71325959 24.28681416 17.02344926 38.03981891 6.57716147 21.20739634 24.15864691 19.22136751 20.02583079 15.72281764 26.27039269 8.89642718 26.59251088 29.12520923 16.78458923 8.68635424 34.47348662 31.56930862 21.30448809 16.47327242 20.52293343 22.84054023 23.25324333 19.32834221 37.0459777 24.29366892 19.26138453 13.19474709 6.9174831 40.91101695 20.93104583 16.36981652 21.14627646 39.65318001 20.56130518 35.80757432 26.57356025 20.10938353 19.71277372 24.15586543 21.65389104 30.50234679 19.15218487 22.61455148 30.60026088 26.32261595 20.39646938 28.25519618 20.53674089 26.04721145 18.4195752 24.63530895 22.69959017 19.19323207 19.25652691 14.6268253 17.69199436 23.75279201 16.00619317 20.24243978 26.13593095 20.44423499 17.41745467] 均方误差: 20.935600218333263