sklearn导包可视化决策边界和超平面
import numpy as np
import matplotlib.pyplot as plt
import os
import matplotlib
%matplotlib inline
# Set the font sizes used for axis labels and for x/y tick labels.
plt.rcParams.update({
    'axes.labelsize': 14,
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
})
import warnings
# Suppress every warning for the rest of the session.
warnings.filterwarnings('ignore')
from sklearn.svm import SVC
from sklearn import datasets
# Load the iris dataset; keep the labels and only feature columns 2 and 3
# (two features are taken).
iris = datasets.load_iris()
y = iris['target']
X = iris['data'][:, (2, 3)]
# Show both array shapes, then print the full label vector.
display(X.shape, y.shape)
print(y)
(150, 2)
(150,)
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2
2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
2 2]
有三个类别,这里只拿两个类别来展示二分类
# Boolean mask selecting the samples labelled 0 or 1 (two of the classes).
index = np.logical_or(y == 0, y == 1)
X_train, y_train = X[index], y[index]
display(X_train.shape, y_train.shape, y_train)
(100, 2)
(100,)
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
# C: how tolerant the soft margin is
svm_model = SVC(kernel='linear',C=0.5)
svm_model.fit(X_train,y_train)
C值越大,分类器会减少误分类,但最终间隔会较小
C值较小,最终间隔会较大,但会有较多的样本点出现在间隔内
决策边界:
$$x_0*w_0+x_1*w_1+b=0$$
$$x_1 = \frac{-x_0*w_0 - b}{w_1}$$
#画线
def plot_svc_decision_boundary(svm_cls, xmin, xmax, sv=True):
    """Plot a fitted linear model's decision boundary and its two margin lines.

    The boundary is the line ``w0*x0 + w1*x1 + b = 0`` solved for ``x1``.
    The dashed lines are that boundary shifted by ``+1/w1`` and ``-1/w1``
    along the ``x1`` axis. Everything is drawn through ``plt``.

    Args:
        svm_cls: Fitted model exposing ``coef_`` and ``intercept_`` (and
            ``support_vectors_`` when ``sv`` is true). Only the first row
            of ``coef_`` is used.
        xmin: Left end of the ``x0`` range to draw over.
        xmax: Right end of the ``x0`` range to draw over.
        sv: When true, also highlight the support vectors.
    """
    weights = svm_cls.coef_[0]
    bias = svm_cls.intercept_

    # 200 evenly spaced x0 values between the requested limits.
    xs = np.linspace(xmin, xmax, 200)

    # x1 on the decision boundary: x1 = -w0/w1 * x0 - b/w1
    slope = -weights[0] / weights[1]
    boundary = slope * xs - bias / weights[1]

    # Offset along x1 between the boundary and each dashed line.
    gap = 1 / weights[1]
    upper = boundary + gap
    lower = boundary - gap

    # Support vectors are drawn first, as large filled markers.
    if sv:
        vectors = svm_cls.support_vectors_
        plt.scatter(vectors[:, 0], vectors[:, 1], s=180, facecolors="#FFAAAA")

    # Solid boundary, then the upper and lower dashed lines.
    for ys, fmt in ((boundary, 'k-'), (upper, 'k--'), (lower, 'k--')):
        plt.plot(xs, ys, fmt, linewidth=2)
plt.figure(figsize=(14, 4))
# Draw the decision boundary and margin lines over x0 in [0, 5.5].
plot_svc_decision_boundary(svm_model, 0, 5.5)
# Draw the samples: class 1 as blue squares, then class 0 as yellow squares.
for label, marker in ((1, 'bs'), (0, 'ys')):
    mask = y == label
    plt.plot(X[:, 0][mask], X[:, 1][mask], marker)
# Fix the visible range: x from 0 to 5.5, y from 0 to 2.
plt.axis([0, 5.5, 0, 2])