1. 线性SVM决策过程可视化
1、导入需要的模块
from sklearn.datasets import make_blobs
from sklearn.svm import SVC
import matplotlib.pyplot as plt
import numpy as np
2、实例化数据集,可视化数据
# Two Gaussian blobs (500 points, cluster_std=0.6) as the toy classification problem.
X, y = make_blobs(
    n_samples=500,
    centers=2,
    random_state=0,
    cluster_std=0.6,
)
# X.T unpacks into the first and second feature columns.
plt.scatter(*X.T, c=y, s=50, cmap='rainbow')
ax = plt.gca()
# Hide tick marks; only the relative layout of the classes matters here.
ax.set_xticks([])
ax.set_yticks([])
plt.show()
3、制作网格
# Build a 30 x 30 grid of points spanning the limits of the blob plot's axes;
# the SVM decision function is evaluated on this grid in the next step.
xlim = ax.get_xlim()
ylim = ax.get_ylim()
axisx = np.linspace(xlim[0], xlim[1], 30)
axisy = np.linspace(ylim[0], ylim[1], 30)
# 'ij' indexing: axisx[i, j] is the i-th x value and axisy[i, j] the j-th y value.
axisx, axisy = np.meshgrid(axisx, axisy, indexing="ij")
# Flatten to an (n_points, 2) array of (x, y) coordinates, one row per grid point.
xy = np.vstack([axisx.ravel(), axisy.ravel()]).T
# Draw the grid itself as tiny dots. No `c` values are given, so no colormap is
# passed; matplotlib would ignore `cmap` here and emit a warning.
plt.scatter(xy[:, 0], xy[:, 1], s=1)
4、建模,计算决策边界并画出等高线
# Fit a linear-kernel SVM and evaluate its decision function on the grid points.
clf = SVC(kernel="linear")
clf.fit(X, y)
Z = clf.decision_function(xy)
Z = Z.reshape(axisx.shape)

plt.scatter(X[:, 0], X[:, 1], c=y, s=50, cmap="rainbow")
ax = plt.gca()
# Level 0 is the separating boundary (solid); levels -1 and +1 are the margins (dashed).
ax.contour(
    axisx, axisy, Z,
    levels=[-1, 0, 1],
    linestyles=["--", "-", "--"],
    colors="k",
    alpha=0.5,
)
# Pin the view to the limits captured when the grid was built.
ax.set_xlim(xlim)
ax.set_ylim(ylim)
5、将绘图过程包装成函数
def plot_svc_decision_function(model, ax=None, resolution=30):
    """Contour the decision boundary and margins of a fitted 2-D classifier.

    ``model.decision_function`` is evaluated on a ``resolution x resolution``
    grid spanning the current limits of ``ax``. The levels -1, 0 and 1 are
    drawn as dashed, solid and dashed black lines. The original axis limits
    are re-applied afterwards.

    Parameters
    ----------
    model : object
        Fitted estimator whose ``decision_function`` accepts an
        ``(n_points, 2)`` array, e.g. ``SVC(kernel="linear")``.
    ax : matplotlib Axes, optional
        Axes to draw on; defaults to ``plt.gca()``.
    resolution : int, default 30
        Number of grid points along each axis.
    """
    if ax is None:
        ax = plt.gca()
    xlim = ax.get_xlim()
    ylim = ax.get_ylim()

    xs = np.linspace(xlim[0], xlim[1], resolution)
    ys = np.linspace(ylim[0], ylim[1], resolution)
    # 'ij' indexing: grid_x[i, j] == xs[i], grid_y[i, j] == ys[j].
    grid_x, grid_y = np.meshgrid(xs, ys, indexing="ij")
    points = np.column_stack([grid_x.ravel(), grid_y.ravel()])
    values = model.decision_function(points).reshape(grid_x.shape)

    ax.contour(grid_x, grid_y, values,
               colors="k", levels=[-1, 0, 1], alpha=0.5,
               linestyles=["--", "-", "--"])
    # Drawing the contour must not change the view the caller set up.
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
# Same picture as above, drawn through the reusable helper.
clf = SVC(kernel="linear")
clf.fit(X, y)
plt.scatter(X[:, 0], X[:, 1], c=y, s=50, cmap="rainbow")
plot_svc_decision_function(clf)
------------------------------------------------------------------
6、处理非线性数据时
# This section needs data that a straight line cannot split; the blobs above
# are not such data. Concentric circles: an inner ring (radius factor=0.1)
# inside an outer ring (radius 1), with some noise.
from sklearn.datasets import make_circles

X, y = make_circles(100, factor=0.1, noise=.1, random_state=0)

# Radial lift r = exp(-||x||^2): close to 1 near the origin, smaller further
# out, so the inner ring gets larger r than the outer ring.
r = np.exp(-(X ** 2).sum(1))
# 100 evenly spaced values over the observed range of r (not used further in
# these notes).
rlim = np.linspace(r.min(), r.max(), 100)
from mpl_toolkits import mplot3d
def plot_3D(elev=30, azim=30, X=X, y=y):
    """Scatter 2-D samples in 3-D, using r = exp(-||x||^2) as the height.

    Parameters
    ----------
    elev, azim : float
        Elevation and azimuth passed to ``ax.view_init``.
    X : array
        Samples; columns 0 and 1 are the horizontal coordinates. Defaults to
        the module-level ``X`` at definition time.
    y : array
        Class labels used for colouring. Defaults to the module-level ``y``.
    """
    # Derive the height from the X actually passed in, so a caller-supplied X
    # gets a matching r instead of the module-level one.
    r = np.exp(-(X ** 2).sum(1))
    # projection='3d' relies on mpl_toolkits.mplot3d having been imported.
    ax = plt.subplot(projection='3d')
    ax.scatter3D(X[:, 0], X[:, 1], r, c=y, s=50, cmap='rainbow')
    ax.view_init(elev=elev, azim=azim)
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_zlabel('r')
    plt.show()
from ipywidgets import interact,fixed
# Interactive viewer: a dropdown for the elevation (0 or 30), and a slider for
# the azimuth over -180..180. X and y are wrapped in fixed() so they are passed
# through unchanged instead of being turned into widgets.
interact(plot_3D, elev=[0, 30], azim=(-180, 180), X=fixed(X), y=fixed(y))
plt.show()
2.探索核函数在不同数据集上的表现
1、导入所需要的库和模块
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from sklearn import svm
from sklearn.datasets import make_circles , make_moons , make_blobs , make_classification
------------------------------------------------------------------
2、创建数据集,定义核函数的选择
# Four 2-D toy datasets with 100 samples each; every entry is an (X, y) pair.
n_samples = 100
datasets = [
    make_moons(n_samples=n_samples, noise=0.2, random_state=0),
    make_circles(n_samples=n_samples, noise=0.2, factor=0.5, random_state=0),
    make_blobs(n_samples=n_samples, centers=2, random_state=5),
    make_classification(n_samples=n_samples, n_features=2, n_informative=2,
                        n_redundant=0, random_state=5),
]
# Kernels to compare. Capitalised because the plotting code below reads
# `Kernel` and reuses lower-case `kernel` as its inner loop variable.
Kernel = ['linear', 'poly', 'rbf', 'sigmoid']
------------------------------------------------------------------
3、构建子图
# One row per dataset; column 0 shows the raw data, then one column per kernel.
nrows, ncols = len(datasets), len(Kernel) + 1
fig, axes = plt.subplots(nrows, ncols, figsize=(20, 16))
4、开始进行子图循环
for ds_cnt, (X, Y) in enumerate(datasets):
    # Column 0: the raw dataset, titled only on the top row.
    ax = axes[ds_cnt, 0]
    if ds_cnt == 0:
        ax.set_title("Input data")
    ax.scatter(X[:, 0], X[:, 1], c=Y, zorder=10, cmap=plt.cm.Paired, edgecolors='k')
    ax.set_xticks(())
    ax.set_yticks(())

    # Remaining columns: one SVC per kernel, fitted and scored on this dataset.
    for est_idx, kernel in enumerate(Kernel):
        ax = axes[ds_cnt, est_idx + 1]
        clf = svm.SVC(kernel=kernel, gamma=2)
        clf.fit(X, Y)
        score = clf.score(X, Y)  # accuracy on the same data the model was fit on

        # Data points, plus hollow rings drawn at the support vectors.
        ax.scatter(X[:, 0], X[:, 1], c=Y, zorder=10, cmap=plt.cm.Paired, edgecolors='k')
        sv = clf.support_vectors_
        ax.scatter(sv[:, 0], sv[:, 1], s=50, facecolors='none', zorder=10, edgecolors='k')

        # 200 x 200 evaluation grid, padded by 0.5 beyond the data range.
        x_min, x_max = X[:, 0].min() - .5, X[:, 0].max() + .5
        y_min, y_max = X[:, 1].min() - .5, X[:, 1].max() + .5
        XX, YY = np.mgrid[x_min:x_max:200j, y_min:y_max:200j]
        grid = np.c_[XX.ravel(), YY.ravel()]
        Z = clf.decision_function(grid).reshape(XX.shape)

        # Shade by the sign of the decision function; draw boundary (0) and margins (+/-1).
        ax.pcolormesh(XX, YY, Z > 0, cmap=plt.cm.Paired)
        ax.contour(XX, YY, Z, colors=['k', 'k', 'k'], linestyles=['--', '-', '--'],
                   levels=[-1, 0, 1])
        ax.set_xticks(())
        ax.set_yticks(())
        if ds_cnt == 0:
            ax.set_title(kernel)

        # Score in the lower-right corner, with the leading zero dropped (e.g. ".95").
        ax.text(0.95, 0.06, f"{score:.2f}".lstrip('0'),
                size=15,
                bbox=dict(boxstyle='round', alpha=0.8, facecolor='white'),
                transform=ax.transAxes,
                horizontalalignment='right')

plt.tight_layout()
plt.show()