在乳腺癌数据上探索核函数的性质:
from sklearn.preprocessing import StandardScaler
from sklearn.datasets import load_breast_cancer
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import numpy as np
from time import time
import datetime
# Load the breast-cancer dataset and separate features from labels.
data = load_breast_cancer()
X = data.data
y = data.target
# Notebook-style inspection: these bare expressions only display a value in
# an interactive session and have no effect when run as a script.
X.shape
np.unique(y)

# Hold out 30% of the samples for evaluation; the seed keeps the split fixed.
Xtrain, Xtest, Ytrain, Ytest = train_test_split(X, y, test_size=0.3, random_state=420)

# The list gets its own name so the loop variable does not overwrite it.
kernels = ['linear', 'poly', 'rbf', 'sigmoid']
for kernel in kernels:
    time0 = time()
    # Fit one SVC per kernel on the unscaled features and score it on the test split.
    clf = SVC(kernel=kernel, gamma="auto", degree=1, cache_size=5000).fit(Xtrain, Ytrain)
    print("The accuracy under kernel %s is %f" % (kernel, clf.score(Xtest, Ytest)))
    # Interpret the elapsed seconds in UTC: a naive fromtimestamp() converts to
    # local time, which skews the minute field in zones whose offset is not a
    # whole number of hours.
    elapsed = datetime.datetime.fromtimestamp(time() - time0, tz=datetime.timezone.utc)
    print('需要花费的时间为:', elapsed.strftime('%M:%S:%f'))
对数据进行无量纲化(标准化):
# Split the features first so the scaler used for this comparison is fitted
# on the training rows only; fitting it on all rows before splitting would
# let test-set statistics leak into training.
raw_train, raw_test, Ytrain, Ytest = train_test_split(X, y, test_size=0.3, random_state=420)
scaler = StandardScaler().fit(raw_train)
Xtrain = scaler.transform(raw_train)
Xtest = scaler.transform(raw_test)
# Later code fits directly on the module-level X, so keep it standardized on
# the full data set exactly as before.
X = StandardScaler().fit_transform(X)

# The list gets its own name so the loop variable does not overwrite it.
kernels = ['linear', 'poly', 'rbf', 'sigmoid']
for kernel in kernels:
    time0 = time()
    # Fit one SVC per kernel on the standardized features and score it on the test split.
    clf = SVC(kernel=kernel, gamma="auto", degree=1, cache_size=5000).fit(Xtrain, Ytrain)
    print("The accuracy under kernel %s is %f" % (kernel, clf.score(Xtest, Ytest)))
    # Interpret the elapsed seconds in UTC: a naive fromtimestamp() converts to
    # local time, which skews the minute field in zones whose offset is not a
    # whole number of hours.
    elapsed = datetime.datetime.fromtimestamp(time() - time0, tz=datetime.timezone.utc)
    print('需要花费的时间为:', elapsed.strftime('%M:%S:%f'))
# SVM 执行之前,先进行数据的无量纲化
可以看到,准确率和效率都有了明显的提升。
对核函数进行参数调优:
核函数相关参数:
对rbf,高斯核进行调优:
# Candidate gamma values: 50 points evenly spaced on a log scale, 1e-10 .. 1e1.
gamma_range = np.logspace(-10, 1, 50)

# Test-split accuracy of an RBF-kernel SVC for each candidate gamma.
score = []
for gamma in gamma_range:
    clf = SVC(kernel='rbf', gamma=gamma, cache_size=5000)
    clf.fit(Xtrain, Ytrain)
    score.append(clf.score(Xtest, Ytest))

# Report the best accuracy and the first gamma that reaches it.
best_score = max(score)
best_gamma = gamma_range[score.index(best_score)]
print(best_score, best_gamma)

# Accuracy as a function of gamma.
plt.plot(gamma_range, score)
plt.show()
对多项式核(poly)进行调优:
from sklearn.model_selection import StratifiedShuffleSplit, GridSearchCV

# Grid-search gamma and coef0 for a degree-1 polynomial-kernel SVC and time the search.
time0 = time()

# 50 log-spaced gamma values crossed with 10 evenly spaced coef0 values in [0, 5].
gamma_range = np.logspace(-10, 1, 50)
coef0_range = np.linspace(0, 5, 10)
param_grid = dict(gamma=gamma_range, coef0=coef0_range)

# Five stratified random splits, each holding out 30% of the samples; fixed seed.
cv = StratifiedShuffleSplit(n_splits=5, test_size=0.3, random_state=420)
grid = GridSearchCV(SVC(kernel='poly', degree=1, cache_size=5000), param_grid=param_grid, cv=cv)
grid.fit(X, y)

print('网格搜索中最优的参数为%s,最高的准确率为%f' % (grid.best_params_, grid.best_score_))
# Interpret the elapsed seconds in UTC: a naive fromtimestamp() converts to
# local time, which skews the minute field in zones whose offset is not a
# whole number of hours.
elapsed = datetime.datetime.fromtimestamp(time() - time0, tz=datetime.timezone.utc)
print('消耗的时间为:', elapsed.strftime('%M:%S:%f'))
重要参数C: