概述
支持向量机与超平面
在了解svm算法之前,我们首先需要了解一下线性分类器这个概念。比如给定一系列的数据样本,每个样本都有对应的一个标签。为了使得描述更加直观,我们采用二维平面进行解释,高维空间原理也是一样。
举个例子,假设在一个二维线性可分的数据集中,如下图图A所示,我们要找到一个超平面把两组数据分开,这时,我们认为线性回归的直线或逻辑回归的直线也能够做这个分类,这条直线可以是图一B中的直线,也可以是图一C中的直线,或者图一D中的直线,但哪条直线才最好呢,也就是说哪条直线能够达到最好的泛化能力呢?那就是一个能使两类之间的空间大小最大的一个超平面。
这个超平面在二维平面上看到的就是一条直线,在三维空间中就是一个平面…,因此,我们把这个划分数据的决策边界统称为超平面。离这个超平面最近的点就叫做支持向量,点到超平面的距离叫间隔。支持向量机就是要使超平面和支持向量之间的间隔尽可能的大,这样超平面才可以将两类样本准确的分开,而保证间隔尽可能的大就是保证我们的分类器误差尽可能的小,尽可能的健壮。
点到超平面的距离公式
最大间隔的优化模型
松弛变量(slack variable)
由上一节的分析我们知道实际中很多样本数据都不能够用一个超平面把数据完全分开。如果数据集中存在噪点的话,那么在求超平的时候就会出现很大问题。从下图中课看出其中一个蓝点偏差太大,如果把它作为支持向量的话所求出来的margin就会比不算入它时要小得多。更糟糕的情况是如果这个蓝点落在了红点之间那么就找不出超平面了。
算法
(1)线性可分支持向量机
(2)线性支持向量机
(3)非线性支持向量机
实验步骤
1 安装并引入必要的库
!pip install numpy==1.16.0
!pip install scikit-learn==0.22.1
from sklearn.datasets import load_digits
from sklearn import svm
2 digits数据集的介绍
digits数据集由1797个8x8图像组成,如下图所示。
这些图像是由手写数字转换成的图片格式。
我们可以使用这些数据来训练我们的机器,以进一步确定其他特定数字形式的8x8图像!
听起来就像我们在对数据做分类 !
首先我们需要从sklearn导入数据集并对数据集命名:
digits = load_digits()
现在我们来看看digits数据集的类型和数据。 该类型应该是’Bunch’,它是一个类似于字典的对象,特别适用于加载sklearn内部示例数据集:
print (type(digits))
print (digits.data)
实际上,你不需要创建’Bunch’类型。 但他们提供了大量有用的信息来帮助初学者学习。
让我们来看看这个数据集的描述(description)了解更多信息!
print (digits.DESCR)
通过调用目标(target)字段,我们可以看到分类每个图像的类别(categories)。 一个数字与每个digit的分类相关联。 目标字段获取这些数字,其中每个数字都映射到target_names中的名称:
print (digits.target)
现在我们打印出target_names, 我们可以找出数据被归类为:
print (digits.target_names)
需要注意的一个重要信息是数据被存储为numpy数据类型,它是一个多维数组(ndarray)。
print (type(digits.data))
print (type(digits.target))
print (type(digits.target_names))
现在让我们确认数据和target的形状
注意: 数据的形状是一个元组,其中第一个字段是观测值的数量,第二个字段是属性的数量。
print (digits.data.shape)
print (digits.target.shape)
3 拟合预测
我们可以为命名数据和目标,以用于训练机器!
X = digits.data
y = digits.target
拟合
1 sklearn.svm.SVC()
clf = svm.SVC(gamma=0.001, C=100)
clf.fit(X,y)
预测(predict)最后一位数字8:
print('Prediction: %.2f' % clf.predict(digits.data[-1].reshape(1, -1)))
print('Actual: %.2f' % y[-1])
2 sklearn.svm.LinearSVC()
cls = svm.LinearSVC()
cls.fit(X,y)
预测(predict)最后一位数字8:
print('Prediction: %.2f' % cls.predict(digits.data[-1].reshape(1, -1)))
print('Actual: %.2f' % y[-1])