很多情况下,线性SVM就非常好用。但是当数据为非线性时,此时就要考虑非线性SVM。
本文主要是考虑如何对非线性SVM使用。
那么,处理非线性SVM有两种方式:
1. 增加新的复杂的特征,这对于某些问题是非常好用的。实质上,在以前写过的多项式回归就是基于这样的思想来处理非线性复杂问题,这里同样是这样的思想。
2.利用多项式的核技巧 (SVC) (这里它产生的结果就像跟方式1是一样的,但是实质上并没有添加任何的特征)
方式1:
from sklearn.datasets import make_moons
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import PolynomialFeatures
from sklearn.pipeline import Pipeline
from sklearn.svm import LinearSVC
import numpy as np
import matplotlib.pyplot as plt
#构造球型数据集
(X,y)=make_moons(200,noise=0.2)
#利用pipeline机制
poly_svm_clf=Pipeline((
("poly_feature",PolynomialFeatures(degree=3)),
("scaler",StandardScaler()),
("svm_clf",LinearSVC(C=10,loss="hinge"))
))
poly_svm_clf.fit(X,y)
xx, yy = np.meshgrid(np.arange(-2,3,0.01), np.arange(-1,2,0.01))
y_new=poly_svm_clf.predict(np.c_[xx.ravel(),yy.ravel()])
plt.contourf(xx, yy, y_new.reshape(xx.shape),cmap="PuBu")
plt.scatter(X[:,0],X[:,1],marker="o",c=y)
plt.axis([-1.5,2.5,-1.0,1.5])
结果为;
结果分析:上述结果表明,仍然有少量样本的错误分类,但是大体上能够分成两个类别。另外你可以利用这个训练好的模型进行相应的数据点的预测:poly_svm_clf.predict()
poly_svm_clf.predict([[0,0.5]])
array([0], dtype=int64)
预测结果表明,确实它属于上面红色点的那个类别。
方式2:
方式2的主要步骤是抛弃了上面的新特征的产生,主要利用核技巧,具体利用Scikit-learn的SVC包
from sklearn.svm import SVC
poly_kernel_svm=Pipeline((
("scaler",StandardScaler()),
("svm_clf",SVC(kernel="poly",degree=3,coef0=1,C=5))
))
poly_kernel_svm.fit(X,y)
xx, yy = np.meshgrid(np.arange(-2,3,0.01), np.arange(-1,2,0.01))
y_new=poly_kernel_svm.predict(np.c_[xx.ravel(),yy.ravel()])
plt.contourf(xx, yy, y_new.reshape(xx.shape),cmap="PuBu")
plt.scatter(X[:,0],X[:,1],marker="o",c=y)
plt.axis([-1.5,2.5,-1.0,1.5])
结果表明与上面的方式1大致上是相同的。
另外我们也可以控制SVC函数里的参数,让其更充分的拟合
from sklearn.svm import SVC
poly_kernel_svm=Pipeline((
("scaler",StandardScaler()),
("svm_clf",SVC(kernel="poly",degree=10,coef0=1,C=5))
))
poly_kernel_svm.fit(X,y)
xx, yy = np.meshgrid(np.arange(-2,3,0.01), np.arange(-1,2,0.01))
y_new=poly_kernel_svm.predict(np.c_[xx.ravel(),yy.ravel()])
plt.contourf(xx, yy, y_new.reshape(xx.shape),cmap="PuBu")
plt.scatter(X[:,0],X[:,1],marker="o",c=y)
plt.axis([-1.5,2.5,-1.0,1.5])
发现了没,通过更改SVC的degree=10. 使用了一个10阶多项式核训练SVM分类器,结果比上述拟合更充分,但是要避免其过拟合。
最后,不要直接运行方式2的代码,会报错.