"""
@author: Vincnet_Sheng
@file: sklearn_cross_validation-2.py
@time: 2018/1/4 0004 下午 8:45
# -*- coding: utf-8 -*-
"""
# target: 1) 相同gamma值情况下,test_error和train_error相对于train集大小的变化曲线
# (gamma是SVM.SVC算法中的核函数参数,关系到是否overfitting,应该类似于多项式中的最高次数,)
# 2)改变gamma值,继续观察二者的变化曲线
from sklearn.model_selection import learning_curve
from sklearn.datasets import load_digits
from sklearn.svm import SVC
import matplotlib.pyplot as plt
import numpy as np
# Load the digits dataset (8x8 handwritten-digit images, 10 classes).
digits = load_digits()
X = digits.data
y = digits.target

# learning_curve: for each requested training-set size, fit SVC on that
# fraction of the data and score it with 10-fold cross-validation.
# gamma is the RBF-kernel width parameter of SVC; larger values tend to
# overfit (conceptually similar to raising the degree of a polynomial fit).
train_sizes, train_loss, test_loss = learning_curve(
    SVC(gamma=0.01), X, y, cv=10,           # cv=10: ten different train/test splits
    scoring='neg_mean_squared_error',       # sklearn reports MSE negated (higher is better)
    train_sizes=[0.1, 0.25, 0.5, 0.75, 1])  # fractions of the training set to evaluate

# Negate the (negative) scores back into positive MSE losses,
# averaged over the cross-validation folds (axis=1 = folds).
train_loss_mean = -np.mean(train_loss, axis=1)
test_loss_mean = -np.mean(test_loss, axis=1)

# Plot training vs. cross-validation loss as a function of training-set size.
plt.plot(train_sizes, train_loss_mean, 'o-', color='r',
         label='Training')
plt.plot(train_sizes, test_loss_mean, 'o-', color='g',
         label='Cross_Validation')
plt.xlabel('Training examples')
plt.ylabel('Loss')
plt.legend(loc='best')
plt.show()
# NOTE(review): removed a stray trailing token `output` that raised
# NameError at runtime after the plot window closed.