Principle
The primal optimization problem of the linear support vector machine:

$$\min_{w,b,\xi}\quad \frac{1}{2}\|w\|^{2}+C\sum_{i=1}^{N}\xi_{i}$$

$$\text{s.t.}\quad y_{i}(w\cdot x_{i}+b)\ge 1-\xi_{i},\quad i=1,2,\cdots,N$$

$$\xi_{i}\ge 0,\quad i=1,2,\cdots,N$$
This is equivalent to the unconstrained optimization problem:

$$\min_{w,b}\quad C\sum_{i=1}^{N}\max\bigl(0,\,1-y_{i}(w\cdot x_{i}+b)\bigr)+\frac{1}{2}\|w\|^{2}$$
The first term is the hinge loss function, and C is the penalty coefficient: the strength of regularization is inversely proportional to C.
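The equivalence is easy to check numerically. Below is a minimal sketch of the hinge-loss form of the objective (the function name and toy data are of my own choosing, not from the original):

```python
import numpy as np

def svm_objective(w, b, X, y, C=1.0):
    """Primal objective: C * sum of hinge losses + (1/2)||w||^2."""
    margins = y * (X @ w + b)               # y_i (w . x_i + b)
    hinge = np.maximum(0.0, 1.0 - margins)  # hinge loss per sample
    return C * hinge.sum() + 0.5 * np.dot(w, w)

# Two linearly separable points on either side of x1 = 0
X = np.array([[1.0, 0.0], [-1.0, 0.0]])
y = np.array([1.0, -1.0])
w = np.array([1.0, 0.0])

# Both margins equal exactly 1, so the hinge terms vanish and
# only the regularization term ||w||^2 / 2 remains
print(svm_objective(w, 0.0, X, y))  # 0.5
```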
sklearn implementation
Model:
sklearn.svm.LinearSVC(penalty='l2', loss='squared_hinge', *, dual=True, tol=0.0001, C=1.0, multi_class='ovr', fit_intercept=True, intercept_scaling=1, class_weight=None, verbose=0, random_state=None, max_iter=1000)
Main parameters:
- penalty: {'l1', 'l2'}, default='l2' — the norm used in the regularization term
- loss: {'hinge', 'squared_hinge'}, default='squared_hinge' — 'squared_hinge' is the square of the hinge loss
- dual: bool, default=True — prefer dual=False when n_samples > n_features
- C: float, default=1.0 — the strength of the regularization is inversely proportional to C
- max_iter: int, default=1000 — the maximum number of iterations to be run
Main methods:
- decision_function(X): returns the confidence scores for the samples
- fit(X_train, y_train): fits the model to the training data
- score(X_test, y_test): returns the mean accuracy on the given test data
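A minimal sketch of these three steps on a toy one-dimensional dataset (the data and settings here are illustrative, not from the original):

```python
import numpy as np
from sklearn.svm import LinearSVC

# Four points on a line, the first two labeled 0 and the last two labeled 1
X = np.array([[0.0], [1.0], [2.0], [3.0]])
y = np.array([0, 0, 1, 1])

model = LinearSVC(C=1.0, max_iter=1000)
model.fit(X, y)                    # fit(X_train, y_train)
print(model.score(X, y))           # mean accuracy; 1.0 on this separable data
print(model.decision_function(X))  # signed confidence scores, one per sample
```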
Binary classification
import numpy as np
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
%matplotlib notebook
def create_data():  # take the first two features and the label of the iris data for binary classification
    iris = load_iris()
    df = pd.DataFrame(iris.data, columns=iris.feature_names)
    df['label'] = iris.target
    df.columns = ['sepal length', 'sepal width', 'petal length', 'petal width', 'label']
    data = np.array(df.iloc[:100, [0, 1, -1]])
    for i in range(len(data)):
        if data[i, -1] == 0:
            data[i, -1] = -1
    return data[:, :2], data[:, -1]
X, y = create_data()
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25)
plt.scatter(X[:50, 0], X[:50, 1], label='-1')
plt.scatter(X[50:, 0], X[50:, 1], label='1')
plt.legend()
from sklearn.svm import LinearSVC
model = LinearSVC()
model.fit(X_train,y_train)
model.score(X_test,y_test)
1.0
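The score reported above is simply mean accuracy, i.e. the fraction of test samples where predict() agrees with the true label. A quick sketch verifying this (random_state fixed only for reproducibility):

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.svm import LinearSVC

# Same setup as above: first two iris features, first two classes relabeled to -1/1
iris = load_iris()
X, y = iris.data[:100, :2], np.where(iris.target[:100] == 0, -1, 1)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=0)

model = LinearSVC(max_iter=10000).fit(X_train, y_train)
# score() equals the fraction of correct predictions
print(model.score(X_test, y_test) == np.mean(model.predict(X_test) == y_test))  # True
```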
# plot the separating line
w = model.coef_
b = model.intercept_
plt.scatter(X[:50, 0], X[:50, 1], label='-1')
plt.scatter(X[50:, 0], X[50:, 1], label='1')
x = np.linspace(3.0, 8.0, 20)
y = -(w[0][0] * x + b[0]) / w[0][1]
plt.plot(x,y)
plt.legend()
Multiclass classification using the one-vs-rest approach
from sklearn.datasets import load_iris
import pandas as pd
iris = load_iris()
X = iris.data
y = iris.target
feature_names = iris.feature_names
data = pd.DataFrame(X,columns=feature_names)
data['labels'] = y
data
| | sepal length (cm) | sepal width (cm) | petal length (cm) | petal width (cm) | labels |
|---|---|---|---|---|---|
0 | 5.1 | 3.5 | 1.4 | 0.2 | 0 |
1 | 4.9 | 3.0 | 1.4 | 0.2 | 0 |
2 | 4.7 | 3.2 | 1.3 | 0.2 | 0 |
3 | 4.6 | 3.1 | 1.5 | 0.2 | 0 |
4 | 5.0 | 3.6 | 1.4 | 0.2 | 0 |
5 | 5.4 | 3.9 | 1.7 | 0.4 | 0 |
6 | 4.6 | 3.4 | 1.4 | 0.3 | 0 |
7 | 5.0 | 3.4 | 1.5 | 0.2 | 0 |
8 | 4.4 | 2.9 | 1.4 | 0.2 | 0 |
9 | 4.9 | 3.1 | 1.5 | 0.1 | 0 |
10 | 5.4 | 3.7 | 1.5 | 0.2 | 0 |
11 | 4.8 | 3.4 | 1.6 | 0.2 | 0 |
12 | 4.8 | 3.0 | 1.4 | 0.1 | 0 |
13 | 4.3 | 3.0 | 1.1 | 0.1 | 0 |
14 | 5.8 | 4.0 | 1.2 | 0.2 | 0 |
15 | 5.7 | 4.4 | 1.5 | 0.4 | 0 |
16 | 5.4 | 3.9 | 1.3 | 0.4 | 0 |
17 | 5.1 | 3.5 | 1.4 | 0.3 | 0 |
18 | 5.7 | 3.8 | 1.7 | 0.3 | 0 |
19 | 5.1 | 3.8 | 1.5 | 0.3 | 0 |
20 | 5.4 | 3.4 | 1.7 | 0.2 | 0 |
21 | 5.1 | 3.7 | 1.5 | 0.4 | 0 |
22 | 4.6 | 3.6 | 1.0 | 0.2 | 0 |
23 | 5.1 | 3.3 | 1.7 | 0.5 | 0 |
24 | 4.8 | 3.4 | 1.9 | 0.2 | 0 |
25 | 5.0 | 3.0 | 1.6 | 0.2 | 0 |
26 | 5.0 | 3.4 | 1.6 | 0.4 | 0 |
27 | 5.2 | 3.5 | 1.5 | 0.2 | 0 |
28 | 5.2 | 3.4 | 1.4 | 0.2 | 0 |
29 | 4.7 | 3.2 | 1.6 | 0.2 | 0 |
... | ... | ... | ... | ... | ... |
120 | 6.9 | 3.2 | 5.7 | 2.3 | 2 |
121 | 5.6 | 2.8 | 4.9 | 2.0 | 2 |
122 | 7.7 | 2.8 | 6.7 | 2.0 | 2 |
123 | 6.3 | 2.7 | 4.9 | 1.8 | 2 |
124 | 6.7 | 3.3 | 5.7 | 2.1 | 2 |
125 | 7.2 | 3.2 | 6.0 | 1.8 | 2 |
126 | 6.2 | 2.8 | 4.8 | 1.8 | 2 |
127 | 6.1 | 3.0 | 4.9 | 1.8 | 2 |
128 | 6.4 | 2.8 | 5.6 | 2.1 | 2 |
129 | 7.2 | 3.0 | 5.8 | 1.6 | 2 |
130 | 7.4 | 2.8 | 6.1 | 1.9 | 2 |
131 | 7.9 | 3.8 | 6.4 | 2.0 | 2 |
132 | 6.4 | 2.8 | 5.6 | 2.2 | 2 |
133 | 6.3 | 2.8 | 5.1 | 1.5 | 2 |
134 | 6.1 | 2.6 | 5.6 | 1.4 | 2 |
135 | 7.7 | 3.0 | 6.1 | 2.3 | 2 |
136 | 6.3 | 3.4 | 5.6 | 2.4 | 2 |
137 | 6.4 | 3.1 | 5.5 | 1.8 | 2 |
138 | 6.0 | 3.0 | 4.8 | 1.8 | 2 |
139 | 6.9 | 3.1 | 5.4 | 2.1 | 2 |
140 | 6.7 | 3.1 | 5.6 | 2.4 | 2 |
141 | 6.9 | 3.1 | 5.1 | 2.3 | 2 |
142 | 5.8 | 2.7 | 5.1 | 1.9 | 2 |
143 | 6.8 | 3.2 | 5.9 | 2.3 | 2 |
144 | 6.7 | 3.3 | 5.7 | 2.5 | 2 |
145 | 6.7 | 3.0 | 5.2 | 2.3 | 2 |
146 | 6.3 | 2.5 | 5.0 | 1.9 | 2 |
147 | 6.5 | 3.0 | 5.2 | 2.0 | 2 |
148 | 6.2 | 3.4 | 5.4 | 2.3 | 2 |
149 | 5.9 | 3.0 | 5.1 | 1.8 | 2 |
150 rows × 5 columns
from sklearn.svm import LinearSVC
model = LinearSVC(penalty='l2',loss='hinge',C=5.0,max_iter=1000)
model
LinearSVC(C=5.0, class_weight=None, dual=True, fit_intercept=True,
intercept_scaling=1, loss='hinge', max_iter=1000, multi_class='ovr',
penalty='l2', random_state=None, tol=0.0001, verbose=0)
# multiclass classification with a linear SVM: three one-vs-rest binary classifiers are trained
from sklearn.model_selection import train_test_split
X_train,X_test,y_train,y_test = train_test_split(X,y,test_size=0.3)
model.fit(X_train,y_train)
model.score(X_test,y_test)
0.91111111111111109
model.coef_  # the weight matrix w (3 x 4)
array([[ 0.09122994, 0.68795258, -0.89771321, -0.46878047],
[ 0.63845111, -2.53659412, 0.56076395, -2.09711552],
[-0.93766893, -1.60206571, 1.61178485, 3.44670505]])
model.intercept_  # the intercepts b
array([ 0.02498646, 3.58805562, -2.9032515 ])
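These two attributes fully determine the confidence scores shown below: decision_function(X) is the linear map X @ coef_.T + intercept_, one column per one-vs-rest classifier. A sanity-check sketch (refit on the full data, with a larger max_iter to help convergence):

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.svm import LinearSVC

iris = load_iris()
X, y = iris.data, iris.target

model = LinearSVC(max_iter=10000)
model.fit(X, y)

# Recompute the confidence scores by hand from coef_ and intercept_
manual = X @ model.coef_.T + model.intercept_
print(np.allclose(manual, model.decision_function(X)))  # True
```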
print(model.decision_function(X_test))  # confidence scores per class; the class with the highest score is predicted
y_test
[[ 7.10150829e-01 7.26777649e-01 -7.67818099e+00]
[ -2.63877237e+00 1.34370998e+00 1.81745975e-01]
[ 1.61727503e+00 -2.58289007e+00 -1.03215239e+01]
[ -1.91526895e+00 -2.76549708e-01 -1.98082316e+00]
[ -3.19493369e+00 -1.71267056e+00 1.95280407e+00]
[ -2.01050042e+00 -3.60258900e-01 -2.19276841e+00]
[ -2.51660439e+00 -1.27416100e+00 6.05174156e-01]
[ -4.29851520e+00 9.64553122e-01 3.08328234e+00]
[ -1.69251111e+00 -1.03050471e-01 -2.18095999e+00]
[ 1.50066057e+00 -1.87798823e+00 -1.00020827e+01]
[ -1.75373537e+00 1.57268128e+00 -1.70777554e+00]
[ -2.27403231e+00 4.99195848e-01 -1.39738111e+00]
[ -3.69443294e+00 1.00181579e+00 2.36162361e+00]
[ 2.06559741e+00 -3.18851211e+00 -1.18432668e+01]
[ -2.09932222e+00 -1.48025719e+00 -1.00726573e+00]
[ 1.13697130e+00 -9.16788189e-01 -9.07591398e+00]
[ -2.09122537e+00 7.82175918e-01 -8.17325781e-01]
[ -1.86566083e+00 3.05463735e-03 -2.13893179e+00]
[ 1.07090611e+00 -5.93235984e-01 -8.72914554e+00]
[ 1.60721089e+00 -1.98578120e+00 -1.04131929e+01]
[ -3.70431871e+00 -1.99146410e+00 2.17055764e+00]
[ -3.55621482e+00 -1.23301770e+00 3.12680656e+00]
[ 1.37984905e+00 -1.42238598e+00 -9.93160127e+00]
[ -2.14042657e+00 -4.83622474e-01 -1.18963172e+00]
[ -2.64474161e+00 -6.66677804e-01 1.60943687e+00]
[ 1.62272680e+00 -1.92798377e+00 -1.07872886e+01]
[ -2.29899315e+00 -1.73886667e-01 -8.68246668e-01]
[ 1.06544596e+00 -7.33021571e-01 -9.10226928e+00]
[ -1.93257380e+00 -1.06519391e+00 -1.60785352e+00]
[ 1.01489669e+00 -3.51671937e-01 -9.12959650e+00]
[ -3.76293090e+00 -5.08358729e-01 2.82980981e+00]
[ -3.38338703e+00 -1.89056278e+00 3.14605024e+00]
[ 2.03790650e+00 -3.93148742e+00 -1.15006942e+01]
[ -1.72407717e+00 -5.43885397e-01 -2.14861910e+00]
[ -3.48953042e+00 -2.76896386e+00 3.02674585e+00]
[ -2.38813687e+00 -1.84532494e+00 3.78527906e-01]
[ 1.49247035e+00 -2.08766661e+00 -1.05617683e+01]
[ -3.21108558e+00 4.72167391e-01 4.02415283e-01]
[ -4.75628250e+00 9.54890080e-01 4.76006397e+00]
[ -3.05742817e+00 -6.82355356e-01 2.10153361e+00]
[ -1.83975047e+00 -4.99937538e-01 -1.64374203e+00]
[ -1.65202598e+00 1.12708763e-01 -2.24530173e+00]
[ 1.36706321e+00 -1.41029061e+00 -9.37094374e+00]
[ -1.81059256e+00 4.22444570e-01 -1.92391668e+00]
[ 1.12511823e+00 -1.05052609e+00 -9.16870896e+00]]
array([0, 1, 0, 1, 2, 1, 2, 2, 1, 0, 1, 1, 2, 0, 1, 0, 1, 1, 0, 0, 2, 2, 0,
1, 2, 0, 1, 0, 1, 0, 2, 2, 0, 1, 2, 1, 0, 2, 2, 2, 1, 1, 0, 1, 0])