本文代码参考博主twilight0402的文章
关于sklearn中SVM使用的文章很多,但其可视化部分往往给出代码但解释通常都不是很详细,本文主要对sklearn训练SVM后的可视化过程进行阐述,故对SVM的使用不做赘述,可以参考其他文章。
1.导入数据:本文使用sklearn中的moon、circle生成数据
import numpy as np
from sklearn import svm, datasets
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
# 生成弯月形数据
Data = datasets.make_moons(n_samples=100, noise=0.2, random_state=11)
# # 生成环形数据
# Data = datasets.make_circles(n_samples=100, noise=0.1, factor=0.1, random_state=11)
X = Data[0] # 特征 (100, 2)
y = Data[1] # 标签,0或1 (100,)
2.构造SVM模型:参考文章,这里使用pipeline构造包含数据归一化和SVM的模型
def creatSVM(kernel='rbf', C=10, gamma=10):
"""
:param kernel: SVM使用的核函数,本文以rbf核为例,常用于处理分线性分类问题
:param C: 目标函数惩罚项系数
:param gamma: rbf核函数参数。实际训练模型时C和gamma需要调参
"""
return Pipeline([
("scaler", StandardScaler()),
("SVM", svm.SVC(kernel=kernel, gamma=gamma, C=C))
])
mySVM = creatSVM() # 初始化模型
mySVM.fit(X, y) # 训练SVM模型
3.模型可视化:
SVM的可视化主要是希望画出模型分界面(决策函数),当kernel=‘linear’时(即线性SVM)可以直接得到分界面(wx+b=0)参数.coef_[0](对应w)、.intercept_[0](对应b),由此可以计算并绘制分界面。但使用其他核函数时由于实际上进行了特征变换,故无法直接得到分界面参数。
此时通常采用模拟的方法绘制近似分界面。具体来说,使用训练好的SVM计算指定区域内所有点(当然在算法中实际为抽样的有限个点)的预测结果,利用绘制等高线的函数绘制以输入特征为底面,预测结果为高度的等高线图,对应不同高度(预测结果)的点的分界面即为近似的分界面。
x1_min, x1_max = X[:, 0].min() - 1, X[:, 0].max() + 1
x2_min, x2_max = X[:, 1].min()-1, X[:, 1].max() + 1
# 获得绘图边界,这里没有区分训练数据或测试数据,根据实际需求选择即可
h = (x1_max - x1_min) / 100
# h为采样点间隔,可以自己设定
xx, yy = np.meshgrid(np.arange(x1_min, x1_max, h), np.arange(x2_min, x2_max, h))
# 由meshgrid函数生成对应区域内所有点的横纵坐标,xx、yy均为尺寸为(M, N)的二维矩阵,分别对应区域内所有点的横坐标和所有点的纵坐标,同时也是区域内所有样本的第一维特征和第二维特征
z = mySVM.predict(np.c_[xx.ravel(), yy.ravel()])
# 由训练好的SVM预测区域内所有样本的结果。由于xx、yy尺寸均为(M,N),通过.ravel拉平并通过.c_组合,尺寸变为(M*N, 2),相当于M*N个具有两维特征的样本,输出z尺寸为(M*N,)
z = z.reshape(xx.shape)
# 将输出尺寸也转变为(M, N)以和横纵坐标对应绘制等高线图
plt.contourf(xx, yy, z, cmap=plt.cm.ocean, alpha=0.6)
# 绘制等高线图
plt.scatter(X[y == 0, 0], X[y == 0, 1])
plt.scatter(X[y == 1, 0], X[y == 1, 1])
# 标记数据中各样本
plt.title('Visualization of SVM with RBF kernel')
plt.xlabel('x1')
plt.ylabel('x2')
plt.show()
绘制结果如下,由于本文是二分类问题,途中只有两种颜色,对应不同分类结果。不同颜色的等高线界面即为近似分界面