- 交叉验证(CrossValidation)是常用的机器学习训练手段,可以有效检验一个模型的泛化能力。交叉验证需要将原始数据集平等地划分为若干份,例如常用的10折交叉验证,10-folds CV 指的是将数据集分为10份,然后进行10次训练,每次取出一份数据作为测试集,剩下的作为训练集,得到10个模型,最终将10个模型的预测值做一个平均。
具体python代码如下:
def plot_cross_val(rf4, train_x, train_y,cv_num,path_out):
import matplotlib.pyplot as plt
from sklearn.model_selection import cross_val_score
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号
evaluate_vars = ['roc_auc','precision','recall','f1']
fig = plt.figure()
for plot_num in range(len(evaluate_vars)):
ax1 = fig.add_subplot(2, 2, plot_num + 1)
plt.subplots_adjust(left=None, bottom=None, right=None, top=None,
wspace=0.3, hspace=0.5)
try:
scores = cross_val_score(rf4, train_x, train_y, cv=cv_num, scoring=evaluate_vars[plot_num])
except ValueError:
scores = np.zeros(10)
plt.plot(range(10), scores)
plt.xlabel('num of cv')
plt.ylabel(evaluate_vars[plot_num])
plt.xticks(np.arange(0, 10, 1),fontsize=6)
plt.yticks(np.arange(0, 1.1, 0.2),fontsize=8)
plt.show()
tt = 'plot of ' + str(evaluate_vars[plot_num])
ax1.set_title(tt,fontsize=10)
plt.savefig(path_out, bbox_inches='tight', dpi=300) # bbox_inches='tight'帮助删除图片空白部分
plt.show()
if __name__ == '__main__':
path_out = 'E:/program'
plot_cross_val(rf4, train_x, train_y,10,path_out)
效果如下: