# 一、GaussianNB on the handwritten-digits dataset: fit, evaluate, learning curve
import numpy as np
import pandas as pd
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import GaussianNB
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import learning_curve
from sklearn.model_selection import ShuffleSplit
# Load the handwritten-digits dataset and take a first look at it.
digits = load_digits()
X, y = digits.data, digits.target
print(pd.DataFrame(X).head())
print(np.unique(y))

# Hold out 30% of the samples for evaluation.
Xtrain, Xtest, ytrain, ytest = train_test_split(X, y, test_size=0.3, random_state=42)
print("Train:", Xtrain.shape, "Test:", Xtest.shape)

# Fit a Gaussian naive Bayes classifier and report test-set accuracy.
gnb = GaussianNB()
gnb.fit(Xtrain, ytrain)
acc = gnb.score(Xtest, ytest)
print("Accuracy:", acc)

# Each row of predict_proba is a distribution over the 10 digit classes.
prob = gnb.predict_proba(Xtest)
print(prob.shape, prob[0, :].sum())  # (540, 10); each row sums to 1

# predict() is the argmax of predict_proba(): every prediction agrees.
y_pred = gnb.predict(Xtest)
print((prob.argmax(axis=1) == y_pred).sum())  # 540

cm = confusion_matrix(ytest, y_pred)
print("Confusion Matrix", cm)

# Learning curve over 50 random 80/20 splits.
gnb = GaussianNB()
cv = ShuffleSplit(n_splits=50, test_size=0.2, random_state=42)
train_sizes, train_scores, test_scores = learning_curve(gnb, X, y, cv=cv, n_jobs=4)
print(train_sizes)  # number of training samples used at each curve point
print(np.mean(train_scores, axis=1), np.mean(test_scores, axis=1))
# 二、Learning curves of five classifiers, comparing test score and fit time
import datetime
from time import time
import numpy as np
import matplotlib.pyplot as plt
from sklearn.naive_bayes import GaussianNB
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.datasets import load_digits
from sklearn.model_selection import learning_curve # 导入画学习曲线的类
from sklearn.model_selection import ShuffleSplit # 导入设定交叉验证模式的类
def plot_learning_curve(model, title, X, y,
                        ax,          # axes (subplot) to draw on
                        ylim=None,   # optional (low, high) y-axis limits
                        cv=None,     # cross-validation strategy
                        n_jobs=None  # parallel workers for learning_curve
                        ):
    """Draw mean train/test score curves for ``model`` on the given axes.

    Returns the axes so the call can be chained.
    """
    sizes, tr_scores, te_scores = learning_curve(model, X, y, cv=cv, n_jobs=n_jobs)
    ax.set_title(title)
    if ylim is not None:
        ax.set_ylim(*ylim)
    ax.set_xlabel("Training examples")
    ax.set_ylabel("Score")
    ax.grid()  # background grid for readability
    # Average the per-split scores at every training size.
    ax.plot(sizes, tr_scores.mean(axis=1), 'o-', color="r", label="Training score")
    ax.plot(sizes, te_scores.mean(axis=1), 'o-', color="g", label="Test score")
    ax.legend(loc="best")
    return ax
digits = load_digits()
X, y = digits.data, digits.target

titles = ["Naive Bayes", "DecisionTree", "SVM", "RandomForest", "Logistic"]
models = [GaussianNB(),
          DecisionTreeClassifier(),
          SVC(gamma=0.001),
          RandomForestClassifier(n_estimators=50),
          LogisticRegression(C=.1, multi_class="auto", solver='lbfgs', max_iter=5000)]
cv = ShuffleSplit(n_splits=50, test_size=0.2, random_state=42)

# One subplot per model; time each model's learning-curve computation.
fig, axes = plt.subplots(1, 5, figsize=(30, 6))
for index, (title, model) in enumerate(zip(titles, models)):
    time_start = time()
    plot_learning_curve(model, title, X, y, ax=axes[index], ylim=[0.7, 1.05], cv=cv, n_jobs=4)
    elapsed = time() - time_start
    # BUG FIX: the original passed the elapsed *duration* to
    # datetime.fromtimestamp(), which interprets it as an epoch timestamp in
    # the local timezone — the formatted minutes/seconds are wrong in non-UTC
    # (notably half-hour-offset) zones. Format the elapsed seconds directly.
    minutes, seconds = divmod(elapsed, 60)
    micros = int(round((seconds - int(seconds)) * 1_000_000))
    print("{}:{:02d}:{:02d}:{:06d}".format(title, int(minutes), int(seconds), micros))
plt.show()
# Compare each model's fit time and score from the output above.
# 三、Comparing classifiers with log loss
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import GaussianNB
from sklearn.svm import SVC
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import log_loss
from sklearn.metrics import brier_score_loss # 布里尔分数
digits = load_digits()  # handwritten-digits dataset
X, y = digits.data, digits.target
Xtrain, Xtest, ytrain, ytest = train_test_split(X, y, test_size=0.3, random_state=42)

# Gaussian naive Bayes exposes real per-class probabilities.
gnb = GaussianNB()
gnb.fit(Xtrain, ytrain)
gnb_prob = gnb.predict_proba(Xtest)
print(gnb_prob.shape)  # (540, 10)

# Logistic regression also provides predict_proba.
lr = LogisticRegression(C=1., solver='lbfgs', max_iter=5000, multi_class="auto")
lr.fit(Xtrain, ytrain)
lr_prob = lr.predict_proba(Xtest)
print(lr_prob.shape)  # (540, 10)

# SVC has no predict_proba by default; rescale its decision-function margins
# into pseudo-probabilities instead.
svc = SVC(kernel="linear", gamma=1)
svc.fit(Xtrain, ytrain)
svc_df = svc.decision_function(Xtest)
svc_prob = (svc_df - svc_df.min()) / (svc_df.max() - svc_df.min())
# BUG FIX: log_loss requires each row to be a probability distribution
# (non-negative entries summing to 1). The global min-max rescaling above
# does not guarantee that, and recent scikit-learn versions reject such
# input. Normalize each row to sum to 1.
svc_prob = svc_prob / svc_prob.sum(axis=1, keepdims=True)
print(svc_prob.shape)  # (540, 10)

# Compare the models: lower log loss is better.
print(log_loss(ytest, gnb_prob), log_loss(ytest, lr_prob), log_loss(ytest, svc_prob))
# brier_score_loss only supports binary problems, so it raises on this
# 10-class task:
# print(brier_score_loss(y_true=ytest, y_prob=gnb_prob)) # error