一、GridSearchCV
- 将网格搜索和交叉验证放在一起进行。
- 网格搜索用于超参数调优。
- 交叉验证用于模型泛化性能验证,交叉验证不会提高模型精度。
from statistics import mean
import joblib
import pandas as pd
import seaborn as sns
from sklearn.svm import SVC
from sklearn import metrics
import datetime
from imblearn.over_sampling import SMOTE
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score
from sklearn.metrics import precision_recall_curve
import matplotlib.pyplot as plt
from sklearn.model_selection import KFold
from sklearn.model_selection import GridSearchCV
import xgboost as xgb
from sklearn import tree
from sklearn.utils import stats
from sklearn import ensemble
from sklearn import svm
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.model_selection import cross_val_score
from sklearn.utils import shuffle
from sklearn.model_selection import RepeatedKFold
# Initialize repeated K-fold cross-validation.
# GridSearchCV's `scoring` defaults to the estimator's score (accuracy for
# classifiers); pass scoring='f1_macro' for macro-averaged F1 instead.
rbk = RepeatedKFold(n_splits=5, n_repeats=1, random_state=12)
# Hyperparameter grid for the SVM. NOTE: the original snippet referenced
# `param_grid` without ever defining it, which raises NameError at runtime.
param_grid = {
    "C": [0.1, 1, 10, 100],
    "gamma": ["scale", 0.01, 0.1, 1],
    "kernel": ["rbf"],
}
# Run the grid search with cross-validation and fit on the training split.
# X_train / Y_train are assumed to be prepared earlier in the pipeline —
# TODO confirm against the calling notebook/script.
clf_svm = GridSearchCV(
    svm.SVC(
        class_weight='balanced',
        decision_function_shape='ovo',
        probability=True,
    ),
    param_grid,
    scoring="accuracy",
    cv=rbk,
)
clf_svm.fit(X_train, Y_train)
二、cross_val_score
一般用于获取每折的交叉验证的得分,进而得知模型的一般泛化性能。然后根据这个得分为模型选择合适的超参数,通常需要编写循环手动完成交叉验证过程。
from statistics import mean
import joblib
import pandas as pd
import seaborn as sns
from sklearn.svm import SVC
from sklearn import metrics
import datetime
from imblearn.over_sampling import SMOTE
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score
from sklearn.metrics import precision_recall_curve
import matplotlib.pyplot as plt
from sklearn.model_selection import KFold
from sklearn.model_selection import GridSearchCV
import xgboost as xgb
from sklearn import tree
from sklearn.utils import stats
from sklearn import ensemble
from sklearn import svm
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.model_selection import cross_val_score
from sklearn.utils import shuffle
from sklearn.model_selection import RepeatedKFold
# Number of cross-validation folds; change to 10 for 10-fold CV.
n = 5
SKF = KFold(n_splits=n, shuffle=True, random_state=42)
# Per-fold cross-validation scores for the previously constructed estimator
# `clf_svm` on features X and labels y (defined elsewhere — TODO confirm).
score_SVM = cross_val_score(clf_svm, X, y, cv=SKF)
print('——>交叉验证分数', score_SVM)
# Confidence interval: mean and 2*std, e.g. "Accuracy: 0.98 (+/- 0.03)".
# print("——>10折交叉验证 Mean Accuracy: %0.4f (+/- %0.4f)" % (score_SVM.mean(), score_SVM.std() * 2))
print("——>5折交叉验证 Mean Accuracy: %0.4f" % (score_SVM.mean()))
# Fix: repeat the mean `n` times (was a hard-coded 5, which breaks the plot
# with a length mismatch whenever `n` is changed, e.g. to 10).
mean_score = [score_SVM.mean()] * n
plt.plot(range(1, n + 1), score_SVM, label='K-Score')
plt.plot(range(1, n + 1), mean_score, label='MeanScore')
plt.legend()
# plt.savefig('../figureResult/train_svm/{}-fold.jpg'.format(n), dpi=800)
plt.show()
三、总结
GridSearchCV :
除了自行完成交叉验证外,还返回了最优的超参数及对应的最优模型
所以相对于cross_val_score来说,GridSearchCV在使用上更为方便;但是对于细节理解上,手动实现循环调用cross_val_score会更好些。
cross_val_score :
一般用于获取每折的交叉验证的得分,然后根据这个得分为模型选择合适的超参数,通常需要编写循环手动完成交叉验证过程。