Python数据分析:scikit-learn基础(一)
基本步骤:
-
加载示例数据集
- digits
-
在训练集上训练模型
- SVM模型
- LR模型
- .fit() 训练模型
-
在测试集上测试模型
- .predict() 进行预测
-
保存模型
- pickle.dumps()
-
加载模型预测
加载示例数据集
from sklearn import datasets
digits = datasets.load_digits()
# 查看数据集digits
print(digits.data)
print(digits.data.shape)
print(digits.target_names)
print(digits.target)
运行:
训练模型
# 手动划分训练集、测试集
n_test = 100 # 测试样本个数
train_X = digits.data[:-n_test, :]
train_y = digits.target[:-n_test]
test_X = digits.data[-n_test:, :]
y_true = digits.target[-n_test:]
# 选择SVM模型
from sklearn import svm
svm_model = svm.SVC(gamma=0.001, C=100.)
# 训练模型
svm_model.fit(train_X, train_y)
运行:
# 选择LR模型
from sklearn.linear_model import LogisticRegression
lr_model = LogisticRegression()
# 训练模型
lr_model.fit(train_X, train_y)
运行:
测试模型:
y_pred_svm = svm_model.predict(test_X)
y_pred_lr = lr_model.predict(test_X)
# 查看结果
from sklearn.metrics import accuracy_score
print('SVM结果:', accuracy_score(y_true, y_pred_svm))
print('LR结果:', accuracy_score(y_true, y_pred_lr))
运行:
保存模型
import pickle
#保存模型
with open('svm_model.pkl', 'wb') as f:
pickle.dump(svm_model, f)
加载模型预测
import numpy as np
# 重新加载模型进行预测
with open('svm_model.pkl', 'rb') as f:
model = pickle.load(f)
random_samples_index = np.random.randint(0, 1796, 5)
random_samples = digits.data[random_samples_index, :]
random_targets = digits.target[random_samples_index]
random_predict = model.predict(random_samples)
print(random_predict)
print(random_targets)
运行: